
Reproducing Diffusion-LM

Mark Henry

Li et al. (2022) present Diffusion-LM, a diffusion model for text. Autoregressive models generate text one token at a time. Diffusion models instead start with a sequence of random noise and iteratively refine, or "denoise," it into the desired tokens. (Google has a great explanation and demo on the Gemini Diffusion webpage.)

Diffusion text models still lag behind autoregressive models in coherence and usefulness, but reproducing Diffusion-LM taught me how a diffusion text model actually works. My model is BERT-based and trained on WikiText, and training ran on my own RTX 5070.

My model shows several desirable qualities: high denoising fidelity, graceful degradation as the noise level increases, and generation of recognizable tokens.

[Figure: line graph of denoising performance versus timestep. Cosine similarity starts at 99.82% for low noise (t=0-10) and decreases, slowly at first, to 73.88% at maximum noise (t=1900).]
Denoising performance of the trained model. At t=0-10 the model achieves excellent denoising fidelity, 99.82% cosine similarity, and its guesses become less accurate as noise is added. This is the expected result, and I'm very happy with it.
Diffusion-LM Denoising Performance Across Noise Schedule
Timestep | Noise % | Cosine similarity | Std. dev. | Quality
t=0 | 0.0% | 0.9983 | ±0.0002 | 🟢 Excellent
t=1 | 2.4% | 0.9982 | ±0.0001 | 🟢 Excellent
t=5 | 5.1% | 0.9984 | ±0.0001 | 🟢 Excellent
t=10 | 7.1% | 0.9981 | ±0.0001 | 🟢 Excellent
t=50 | 15.8% | 0.9967 | ±0.0003 | 🟢 Excellent
t=100 | 22.4% | 0.9956 | ±0.0005 | 🟢 Excellent
t=500 | 50.0% | 0.9888 | ±0.0016 | 🟢 Excellent
t=1000 | 70.7% | 0.9769 | ±0.0037 | 🟢 Excellent
t=1500 | 86.6% | 0.9486 | ±0.0097 | 🟢 Excellent
t=1900 | 97.5% | 0.7388 | ±0.0319 | 🟡 Good
Denoising performance, tabular format
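The Noise % column appears to match Diffusion-LM's "sqrt" noise schedule. In that schedule, the fraction of the clean signal kept at step t is alpha_bar(t) = 1 - sqrt(t/T + s).

The sketch below is an assumption, not code from this project. It makes three choices:

- It uses T = 2000 and s = 1e-4.
- It reads Noise % as 1 - alpha_bar(t).
- It adds a hypothetical `add_noise` helper that shows how a latent is noised to a given timestep.

Under those choices it reproduces every row except t=0, where the formula gives 1.0% rather than 0.0%. Treat the match as suggestive rather than confirmed.

```python
import math
import random

# Assumed settings: Diffusion-LM's sqrt schedule with T = 2000 steps and s = 1e-4.
# These come from the paper's description, not from this project's code.
T = 2000
S = 1e-4


def alpha_bar(t: int) -> float:
    """Fraction of the clean signal's variance that survives at timestep t (0 <= t < T)."""
    return 1.0 - math.sqrt(t / T + S)


def noise_fraction(t: int) -> float:
    """1 - alpha_bar(t): a guess at what the table's 'Noise %' column reports."""
    return 1.0 - alpha_bar(t)


def add_noise(x0: list[float], t: int, rng: random.Random) -> list[float]:
    """Hypothetical helper: sample x_t ~ N(sqrt(alpha_bar) * x0, (1 - alpha_bar) * I)."""
    a = alpha_bar(t)
    return [math.sqrt(a) * v + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for v in x0]


if __name__ == "__main__":
    for t in (1, 5, 10, 50, 100, 500, 1000, 1500, 1900):
        print(f"t={t:<5} noise={100 * noise_fraction(t):.1f}%")
    print(add_noise([0.5, -1.2, 0.3], t=1000, rng=random.Random(0)))
```

The small offset s keeps alpha_bar(t) slightly below 1 even at t=0, so under this schedule a little noise is always present.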
Final generated text: 'trump visits session tvo ability la when apparently worse defend construction pond having geographic wheelhood swedish magazine explosion within girl ab forget cruz developers uniqueulation associated ashley mcourown 1920 interviewged anti septfula marcus supports utility springfield destructiveoga, change, university unique cell system twitter thinkingons en attack option hopefully children 40 consumers entrance'
Example output, generated from pure random noise with 100 denoising steps. Obviously, it's incoherent.

Core concepts of Diffusion-LM

The central idea of Diffusion-LM is its loss function, which has three terms. During training, this loss is backpropagated through both the token embeddings and the transformer layers, so the embeddings are learned together with the denoiser.

  1. The model takes in noisy latents (the continuous vectors that stand in for tokens) and outputs slightly less noisy latents. During training, we take clean latents, add noise to them, and reward the model for recovering the originals. Correspondingly, the first term of the loss function is the diffusion loss: the mean squared error (MSE) between the model's guess and the original latents.
  2. The model would love for the token embeddings to sit very far apart relative to the noise applied to them. That would make denoising trivial: just round each latent to the nearest token embedding. To prevent this, a second loss term, called the prior loss, penalizes the embeddings for drifting too far apart from each other. (Another way of looking at it: the prior loss ensures that the applied noise stays large relative to the spacing between embeddings.)
  3. The opposite situation the model would like to create is one where all the embeddings sit right next to each other. Then it scores well on the diffusion loss by predicting nearly the same vector over and over. To prevent this, the third term is a cross-entropy loss that penalizes weaksauce predictions, i.e. latents that do not decode confidently to the correct token. This reconstruction loss encourages the model to make bold predictions and keeps the embeddings from clumping together.
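To make the three terms concrete, here is a toy, pure-Python sketch of the combined objective on 2-D embeddings. It is an illustration under stated assumptions, not this project's implementation:

- The denoiser's prediction is passed in directly; there is no transformer.
- The prior term is the squared norm of the mean of q(x_T | x_0), an assumption based on the Diffusion-LM paper.
- The cross-entropy decodes the model's prediction using dot-product logits against the embedding table.
- The three terms are weighted equally.

The choice of which latent the cross-entropy decodes, and the equal weighting, are both assumptions. `toy_diffusion_lm_loss` and its helpers are hypothetical names.

```python
import math

Vector = list[float]


def mean_sq(vectors: list[Vector]) -> float:
    """Mean of the squared entries across a list of vectors."""
    vals = [v for vec in vectors for v in vec]
    return sum(v * v for v in vals) / len(vals)


def log_softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def toy_diffusion_lm_loss(emb: list[Vector], tokens: list[int],
                          x0_pred: list[Vector], alpha_bar_T: float):
    """Hypothetical three-term loss; returns (total, (diffusion, prior, rounding))."""
    x0 = [emb[w] for w in tokens]  # clean latents are just the token embeddings

    # 1. Diffusion loss: MSE between the denoiser's guess and the clean latents.
    diff = [[p - q for p, q in zip(a, b)] for a, b in zip(x0_pred, x0)]
    l_diffusion = mean_sq(diff)

    # 2. Prior loss (assumed form): q(x_T | x_0) has mean sqrt(alpha_bar_T) * x_0,
    #    so its squared norm is alpha_bar_T * |x_0|^2, which penalizes large embeddings.
    l_prior = alpha_bar_T * mean_sq(x0)

    # 3. Rounding/reconstruction loss: cross-entropy of decoding each predicted
    #    latent to its true token, using dot products with emb as logits.
    l_round = 0.0
    for w, vec in zip(tokens, x0_pred):
        logits = [sum(a * b for a, b in zip(vec, e)) for e in emb]
        l_round -= log_softmax(logits)[w]
    l_round /= len(tokens)

    return l_diffusion + l_prior + l_round, (l_diffusion, l_prior, l_round)


if __name__ == "__main__":
    tokens = [0, 2, 1]
    spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
    clumped = [[0.1, 0.1]] * 3  # every token shares the same embedding
    huge = [[10 * v for v in e] for e in spread]
    for name, emb in [("spread", spread), ("clumped", clumped), ("huge", huge)]:
        perfect = [emb[w] for w in tokens]  # pretend the denoiser is perfect
        _, parts = toy_diffusion_lm_loss(emb, tokens, perfect, alpha_bar_T=0.01)
        print(name, [round(p, 3) for p in parts])
```

The demo holds the prediction perfect so the diffusion term is zero. Then the three embedding tables behave differently:

- The clumped table pins the rounding loss at log 3, the cost of a uniform guess.
- The scaled-up table drives the rounding loss toward zero but pays for it in the prior term.
- The spread table sits in between.

That is the tug-of-war between the prior and rounding terms described in the list above.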

So two loss components control the "spread" of the embeddings, and one component rewards correct denoising.

The final important concept is that latents in continuous space are rounded off to the nearest token embedding in the final step of denoising. This bridges the gap between the continuous and discrete domains.
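Here is a minimal sketch of that final rounding step. It assumes Euclidean distance as the meaning of "nearest"; taking the argmax of dot-product logits over the vocabulary is a common alternative. The three-word vocabulary and the 2-D embeddings are made up for illustration. A real model does the same search over the full vocabulary as one batched matrix operation.

```python
def round_to_tokens(latents: list[list[float]], emb: list[list[float]]) -> list[int]:
    """Return, for each continuous latent, the id of the nearest token embedding."""

    def sq_dist(u: list[float], v: list[float]) -> float:
        return sum((a - b) ** 2 for a, b in zip(u, v))

    return [min(range(len(emb)), key=lambda i: sq_dist(x, emb[i])) for x in latents]


vocab = ["the", "cat", "sat"]                     # hypothetical toy vocabulary
emb = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]      # one embedding per vocab entry
latents = [[0.9, 0.1], [-0.8, -1.2], [0.2, 0.7]]  # pretend output of the last denoising step

ids = round_to_tokens(latents, emb)
print(ids, " ".join(vocab[i] for i in ids))  # prints: [0, 2, 1] the sat cat
```

Every latent snaps to some vocabulary entry, so the output always consists of real tokens, even when the sequence as a whole is incoherent.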

Code

My code is available at https://github.com/mark-henry/text-diffusion.