# Generalized Interpolating Discrete Diffusion
By Dimitri von Rütte, Janis Fluri, Yuhui Ding, Antonio Orvieto, Bernhard Schölkopf, Thomas Hofmann
We present Generalized Interpolating Discrete Diffusion (GIDD), a novel framework for training discrete diffusion models. GIDD generalizes the popular masked diffusion paradigm (MDM) to any diffusion process that can be written as a linear interpolation between the data distribution and some (time-varying) mixing distribution. We demonstrate the flexibility of GIDD by training models on a hybrid diffusion process that combines masking and uniform noise. The model is therefore trained not only to "fill in the blanks" (i.e., the masked tokens) but also to assess the correctness of already-filled-in tokens and, if necessary, replace implausible ones with more plausible alternatives. We show that GIDD models trained on hybrid noise achieve better sample quality (generative PPL) than mask-only models, and that they can identify and correct their own mistakes in generated samples through a self-correction step. This repository contains all training and evaluation code needed to reproduce the results in the paper.
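As a rough intuition, below is a minimal sketch of such a hybrid forward process. This is an illustration, not the repo's implementation: the `hybrid_noise` helper, the fixed `p_u`, and the linear noise schedule are assumptions, and GIDD itself allows the mixing distribution to vary over time.

```python
import torch

def hybrid_noise(tokens, t, vocab_size, mask_id, p_u=0.2):
    """Corrupt `tokens` at noise level t in [0, 1]: each token survives
    with probability (1 - t); otherwise it is resampled from the mixing
    distribution, i.e. a uniformly random token with probability p_u and
    the mask token with probability (1 - p_u)."""
    corrupt = torch.rand(tokens.shape) < t
    use_uniform = torch.rand(tokens.shape) < p_u
    noise = torch.where(
        use_uniform,
        torch.randint_like(tokens, vocab_size),  # uniform replacement
        torch.full_like(tokens, mask_id),        # masking
    )
    return torch.where(corrupt, noise, tokens)

# Example: corrupt a batch of GPT-2 token ids at noise level t = 0.5,
# using an extra vocabulary slot as the mask token.
# tokens = torch.randint(0, 50257, (2, 16))
# noisy = hybrid_noise(tokens, t=0.5, vocab_size=50257, mask_id=50257)
```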
## Pretrained Checkpoints
Our trained checkpoints are available under the following links. All models were trained on 131B tokens from the OpenWebText dataset using the GPT-2 tokenizer; `p_u` denotes the fraction of uniform noise in the hybrid diffusion process, with `p_u = 0.0` corresponding to mask-only diffusion.
| Model | Small (169.6M) | Base (424.5M) |
|---|---|---|
| GIDD+ (p_u = 0.0) | `dvruette/gidd-small-p_unif-0.0` | `dvruette/gidd-base-p_unif-0.0` |
| GIDD+ (p_u = 0.1) | `dvruette/gidd-small-p_unif-0.1` | `dvruette/gidd-base-p_unif-0.1` |
| GIDD+ (p_u = 0.2) | `dvruette/gidd-small-p_unif-0.2` | `dvruette/gidd-base-p_unif-0.2` |
## Use the Model
- Install the GIDD repo:

```bash
pip install git+https://github.com/dvruette/gidd
```
- To quickly download a trained model and play around with it, the `GiddPipeline` class is the most convenient entry point:
```python
from gidd import GiddPipeline

# Download a pretrained model from HuggingFace
pipe = GiddPipeline.from_pretrained("dvruette/gidd-base-p_unif-0.2", trust_remote_code=True)

# Generate samples
texts = pipe.generate(num_samples=4, num_inference_steps=128)

# Run self-correction step
corrected_texts = pipe.self_correction(texts, num_inference_steps=128, early_stopping=True, temperature=0.1)

print(corrected_texts)
```
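The self-correction step builds on the hybrid-noise training: because the model has also learned to replace uniformly corrupted tokens, it can resample tokens it considers implausible even in a fully unmasked sequence.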