Generalized Interpolating Discrete Diffusion

By Dimitri von Rütte, Janis Fluri, Yuhui Ding, Antonio Orvieto, Bernhard Schölkopf, Thomas Hofmann

[arXiv] · [Open In Colab] · [GitHub]


We present Generalized Interpolating Discrete Diffusion (GIDD), a novel framework for training discrete diffusion models. GIDD can be seen as a generalization of the popular masked diffusion model (MDM) paradigm to any diffusion process that can be written as a linear interpolation between the data distribution and some (time-varying) mixing distribution. We demonstrate the flexibility of GIDD by training models on a hybrid diffusion process that combines masking and uniform noise. The model is therefore trained not only to "fill in the blanks" (i.e. the masked tokens), but also to assess the correctness of already-filled-in tokens and, if necessary, replace incorrect tokens with more plausible ones. We show that GIDD models trained on hybrid noise achieve better sample quality (generative PPL) than mask-only models, and that they are able to identify and correct their own mistakes in generated samples through a self-correction step. This repository contains all training and evaluation code needed to reproduce the results in the paper.
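
Schematically, such a process has marginals of the form q_t(z_t | x) = α_t x + (1 − α_t) π_t, where x is the (one-hot) data distribution and π_t is the mixing distribution; for the hybrid process, π_t blends the mask token with a uniform distribution over the vocabulary. Below is a minimal, illustrative sketch of sampling from such a marginal. The function name and all token ids are hypothetical, and the fixed uniform weight `p_uniform` is a simplification (in GIDD the mixing distribution itself is time-dependent):

```python
import torch

def sample_z_t(x, alpha_t, p_uniform, mask_id, vocab_size):
    """Illustrative sketch (not the paper's exact parameterization):
    sample z_t ~ q_t(. | x) by keeping each clean token w.p. alpha_t
    and otherwise drawing from the mixing distribution, which is
    uniform w.p. p_uniform and the mask token otherwise."""
    keep = torch.rand(x.shape, device=x.device) < alpha_t
    use_uniform = torch.rand(x.shape, device=x.device) < p_uniform
    uniform_tokens = torch.randint(vocab_size, x.shape, device=x.device)
    noise = torch.where(use_uniform, uniform_tokens, torch.full_like(x, mask_id))
    return torch.where(keep, x, noise)

# Toy example with hypothetical token ids (GPT-2-sized vocab plus a [MASK] id).
x = torch.tensor([[464, 3290, 318, 922]])
z_t = sample_z_t(x, alpha_t=0.6, p_uniform=0.1, mask_id=50257, vocab_size=50257)
```

At `p_uniform = 0.0` this reduces to the usual mask-only (MDM) corruption, which is why masked diffusion is a special case of GIDD.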

Pretrained Checkpoints

Our trained checkpoints are available via the links below. All of them were trained on 131B tokens from the OpenWebText dataset with the GPT-2 tokenizer.

| Model | Small (169.6M) | Base (424.5M) |
|---|---|---|
| GIDD+ (p_u = 0.0) | `dvruette/gidd-small-p_unif-0.0` | `dvruette/gidd-base-p_unif-0.0` |
| GIDD+ (p_u = 0.1) | `dvruette/gidd-small-p_unif-0.1` | `dvruette/gidd-base-p_unif-0.1` |
| GIDD+ (p_u = 0.2) | `dvruette/gidd-small-p_unif-0.2` | `dvruette/gidd-base-p_unif-0.2` |

Use the Model

  1. Install the GIDD repo:

```bash
pip install git+https://github.com/dvruette/gidd
```
  2. To quickly download a trained model and play around with it, the `GiddPipeline` class is the most convenient entry point:
```python
from gidd import GiddPipeline

# Download a pretrained model from HuggingFace
pipe = GiddPipeline.from_pretrained("dvruette/gidd-small-p_unif-0.1", trust_remote_code=True)

# Generate samples
texts = pipe.generate(num_samples=4, num_inference_steps=128)

# Run self-correction step
corrected_texts = pipe.self_correction(texts, num_inference_steps=128, early_stopping=True, temperature=0.1)

print(corrected_texts)
```
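
To see which tokens the self-correction step actually replaced, a quick word-level diff of the two sample lists can help. This snippet uses only the standard-library `difflib` (nothing GIDD-specific) and assumes the pipeline returns parallel lists of plain Python strings, matching the calls above:

```python
import difflib

for before, after in zip(texts, corrected_texts):
    if before != after:
        # Word-level unified diff with two words of context on each side.
        diff = difflib.unified_diff(before.split(), after.split(), lineterm="", n=2)
        print("\n".join(diff))
```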