---
license: apache-2.0
language:
- en
datasets:
- Skylion007/openwebtext
metrics:
- perplexity
- mauve
---

# Self-Distillation Through Time (SDTT)
SDTT is a distillation method for diffusion language models. Recent diffusion language models such as [SEDD](https://huggingface.co/louaaron/sedd-small) or [MDLM](https://huggingface.co/kuleshov-group/mdlm-owt) achieve strong results.
However, their non-causal architecture prevents them from using KV-caching, which makes sampling from them slow. We therefore devise a novel distillation method that reduces the inference latency of discrete diffusion models.
After distillation, we can sample up to 8x faster than GPT-2, even though GPT-2 uses KV-caching. Find more details below and on [our GitHub repo](https://github.com/jdeschena/sdtt).

## Using SDTT
- We released 3 groups of models:
    1. The **baseline students**, distilled with the `kld`, `mse` and `tvd` objectives from a teacher trained for 1M steps.
    2. The **students from the scaling experiments**, with sizes `sm`, `md` and `large`, distilled from teachers trained for 400k steps.
    3. The **teachers from the scaling experiments**, with sizes `sm`, `md` and `large`, before any distillation.
- To load those models, first install our code:
```bash
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
```
- You can then import our models, sample and evaluate them:

#### Load the baseline students
```python
from sdtt import load_small_student
student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round
```
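The `loss` argument selects the objective the student was distilled with. As a rough illustration, assuming `kld`, `mse` and `tvd` stand for the Kullback-Leibler divergence, the mean squared error and the total variation distance between the teacher's and the student's per-token output distributions, the three quantities can be computed as below. This is a hypothetical pure-Python sketch, not code from the `sdtt` package: the direction of the KL term and whether the MSE is taken on probabilities (as here) or on logits are assumptions.

```python
import math
from typing import Sequence

# Hypothetical helpers, not part of the sdtt package. Each compares the teacher's
# and the student's categorical distributions over the vocabulary at one position.

def kld(teacher: Sequence[float], student: Sequence[float], eps: float = 1e-12) -> float:
    """KL(teacher || student); the KL direction is an assumption."""
    return sum(
        p * (math.log(p + eps) - math.log(q + eps))
        for p, q in zip(teacher, student)
        if p > 0  # terms with p = 0 contribute nothing
    )

def mse(teacher: Sequence[float], student: Sequence[float]) -> float:
    """Mean squared error between the probability vectors (could be logits instead)."""
    return sum((p - q) ** 2 for p, q in zip(teacher, student)) / len(teacher)

def tvd(teacher: Sequence[float], student: Sequence[float]) -> float:
    """Total variation distance: half the L1 distance between the distributions."""
    return 0.5 * sum(abs(p - q) for p, q in zip(teacher, student))

teacher = [0.5, 0.5, 0.0]
student = [0.25, 0.75, 0.0]
print(kld(teacher, student))  # 0.5 * ln(4/3), about 0.1438
print(mse(teacher, student))  # 0.125 / 3, about 0.0417
print(tvd(teacher, student))  # 0.25
```

All three are zero when the student matches the teacher exactly; in a training step, such per-position values would typically be averaged over positions and the batch before backpropagating through the student.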

#### Load the students from the scaling experiment
```python
from sdtt import load_scaling_student
student = load_scaling_student(size="sm", round=7)  # load small student after the last distillation round
student = load_scaling_student(size="md", round=1)   # load medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load large student after the third distillation round
```

#### Load the teachers from the scaling experiment
```python
from sdtt import load_scaling_teacher
teacher = load_scaling_teacher(size="sm")  # load small teacher
teacher = load_scaling_teacher(size="md")  # load medium teacher
teacher = load_scaling_teacher(size="large")  # load large teacher
```

#### Sample from the pretrained models
```python
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load model, see above
model.cuda()  # put model on gpu

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare a prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Overwrite the first `prompt_len` tokens of every sample with the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return the projected tensor

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn
)

cond_text = model.tokenizer.batch_decode(tokens)
```
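The projection trick can be illustrated without a GPU. The toy below is a hypothetical pure-Python sketch, not the SDTT sampler: it starts from a fully masked sequence, unmasks an equal share of the remaining positions at every step, and draws random tokens in place of the denoiser's predictions. Assuming the real sampler likewise calls `project_fn` on the current sequence at every step, projecting before each step keeps the prompt fixed so that only the remaining positions get filled in. The `MASK = -1` id, the `toy_sample` function and the uniform unmasking schedule are assumptions made for this illustration.

```python
import math
import random
from typing import Callable, List, Optional

MASK = -1  # hypothetical id of the mask token in this toy

Tokens = List[List[int]]

def toy_sample(
    n_samples: int,
    seq_len: int,
    num_steps: int,
    vocab_size: int,
    project_fn: Optional[Callable[[Tokens], Tokens]] = None,
    seed: int = 0,
) -> Tokens:
    """Toy masked-diffusion sampler: random tokens stand in for model predictions."""
    rng = random.Random(seed)
    x = [[MASK] * seq_len for _ in range(n_samples)]
    for step in range(num_steps):
        if project_fn is not None:
            x = project_fn(x)  # clamp the prompt before each denoising step
        steps_left = num_steps - step
        for row in x:
            masked = [i for i, tok in enumerate(row) if tok == MASK]
            # Unmask an equal share of the still-masked positions; the last step unmasks all.
            k = math.ceil(len(masked) / steps_left)
            for i in rng.sample(masked, k):
                row[i] = rng.randrange(vocab_size)
    return x

prompt = [7, 3, 9]  # stand-in for tokenized prompt ids

def project_fn(x: Tokens) -> Tokens:
    # Same idea as the tensor version above: overwrite the prompt positions
    for row in x:
        row[: len(prompt)] = prompt
    return x

samples = toy_sample(n_samples=4, seq_len=10, num_steps=3, vocab_size=50, project_fn=project_fn)
for row in samples:
    print(row)
```

Because the prompt positions are never masked when a step runs, they are never selected for unmasking, so every sample keeps the prompt as its prefix.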

For more details, please see our GitHub repository: [SDTT](https://github.com/jdeschena/sdtt)

## Model Details
Our small checkpoints are distilled from the [MDLM](https://github.com/kuleshov-group/mdlm) checkpoints. We also release medium (424M) and large (863M) checkpoints that we pretrained ourselves.

## Citation

Please cite our work using the BibTeX entry below:

**BibTeX:**

```
@article{deschenaux2024autoregressionfastllmsselfdistillation,
        title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
        author={Deschenaux, Justin and Gulcehre, Caglar},
        year={2024},
        eprint={2410.21035},
        archivePrefix={arXiv},
        primaryClass={cs.LG},
        url={https://arxiv.org/abs/2410.21035}, 
}
```

## Contact
Justin Deschenaux ([email protected])