---
license: afl-3.0
language: en
tags:
- t5
datasets:
- wikipedia
---

# chunked T5 - small (cT5-small)

GitHub: https://github.com/mtreviso/chunked-t5

A T5 model that uses a new loss in which a special end-of-chunk token `</c>` is appended after sentinel tokens.
The decoder has to predict the full input with masked tokens followed by `</c>`.
This allows much faster auto-regressive generation, since the decoder can predict multiple tokens in parallel.

For example, for the input `the quick brown fox jumps over the lazy dog`:
```
encoder: the <extra_id_0> fox jumps <extra_id_1> the lazy dog

T5 decoder : <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
cT5 decoder: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2>
```
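
To make the target format explicit, here is a minimal sketch (illustrative only, not code from the repository) that turns a T5-style target token sequence into the corresponding cT5-style one by closing each chunk with `</c>`:

```python
def to_ct5_target(t5_target_tokens):
    """Insert '</c>' before every sentinel except the first, closing each chunk."""
    ct5_target = []
    for i, tok in enumerate(t5_target_tokens):
        if tok.startswith("<extra_id_") and i > 0:
            ct5_target.append("</c>")  # close the previous chunk
        ct5_target.append(tok)
    return ct5_target

t5_target = ["<extra_id_0>", "quick", "brown", "<extra_id_1>", "over", "<extra_id_2>"]
print(to_ct5_target(t5_target))
# ['<extra_id_0>', 'quick', 'brown', '</c>', '<extra_id_1>', 'over', '</c>', '<extra_id_2>']
```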
The generation may look like this for T5 and cT5:
```
T5: <extra_id_0>
T5: <extra_id_0> quick
T5: <extra_id_0> quick brown
T5: <extra_id_0> quick brown <extra_id_1>
T5: <extra_id_0> quick brown <extra_id_1> over
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2> </s>

cT5: <extra_id_0> <pad> <extra_id_1> <pad> <extra_id_2> </s>
cT5: <extra_id_0> quick <pad> <extra_id_1> over <pad> <extra_id_2> </s>
cT5: <extra_id_0> quick brown <pad> <extra_id_1> over </c> <extra_id_2> </s>
cT5: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2> </s>
```

In the original T5, the decoder is called \\(n_s + 1 + \sum_i |s_i|\\) times autoregressively,
where \\(n_s\\) is the number of sentinel tokens and \\(s_1, \ldots, s_{n_s}\\) are the predicted chunks.
In contrast, cT5's decoder is called only \\(\max_i |s_i| + 1\\) times.
Generation stops when every chunk is complete, i.e., once all `</c>` tokens have been generated.
Alternatively, you can set `max_chunk_size` to force the model to stop after generating a chunk of `max_chunk_size` tokens.
The overhead of calling the decoder with a longer input is less pronounced, since this computation can be parallelized on GPUs/TPUs.
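
To make the arithmetic concrete, here is a small illustrative computation (not from the repository) of the number of decoder calls for the running example above, using the two formulas just stated:

```python
# Running example: chunks s_1 = "quick brown" (2 tokens), s_2 = "over" (1 token),
# plus the empty chunk after the final sentinel <extra_id_2>.
n_s = 3                     # sentinel tokens: <extra_id_0>, <extra_id_1>, <extra_id_2>
chunk_lengths = [2, 1, 0]   # |s_1|, |s_2|, |s_3|

t5_calls = n_s + 1 + sum(chunk_lengths)  # 3 + 1 + 3 = 7 sequential decoder calls
ct5_calls = max(chunk_lengths) + 1       # 2 + 1 = 3 decoder calls
print(t5_calls, ct5_calls)               # 7 3
```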
## Training details

cT5 models used T5's weights as a starting point and were then finetuned on the
English [wikipedia](https://huggingface.co/datasets/wikipedia) dataset for 3 epochs,
achieving ~74% validation accuracy (ct5-small).
The training script is written in JAX + Flax and can be found in `pretrain_ct5.py`.

Flax checkpoints can be converted to PyTorch via `convert_flax_to_pytorch.py [flax_dirname]`.


## Checkpoints

- ct5-small: https://huggingface.co/mtreviso/ct5-small-en-wiki
- ct5-base: todo
- ct5-large: todo


## Usage
```python
from transformers import AutoTokenizer
from modeling_ct5 import CT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("mtreviso/ct5-small-en-wiki")
model = CT5ForConditionalGeneration.from_pretrained("mtreviso/ct5-small-en-wiki")
```

For training:

```python
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> man </c> <extra_id_1> the </c> <extra_id_2>", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
```

For generation:

```python
texts = [
    "The <extra_id_0> walks in <extra_id_1> park",
    "UN Chief says there is no way to <extra_id_0> in Syria",
]
input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
generated_ids = model.generate(
    input_ids,
    use_cache=False,  # important: caching is not available for parallel decoding
    eoc_token_id=tokenizer.vocab['</c>'],  # id of the end-of-chunk token </c>
    max_chunk_size=5,  # maximum tokens per chunk; the default (9999999) is effectively unbounded
)
```

This will produce the following tokens:
```python
>> ['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '</c>', '<extra_id_1>', '▁the', '</c>', '<extra_id_2>', '</s>']
>> ['<pad>', '<extra_id_0>', '▁treat', '▁Syria', '</c>', '<extra_id_1>', '</s>', '<pad>', '<pad>', '<pad>']
```
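
The output above is the raw token view. To get tokens or plain strings from `generated_ids`, the standard HuggingFace tokenizer helpers can be used as usual (just an illustration, not a cT5-specific API):

```python
# Token view (as printed above) and plain-text view of the generated ids.
tokens = [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in generated_ids]
decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
```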
You have to pass `use_cache=False` to `generate()` in order to avoid caching during generation, since caching is not available for parallel decoding.
Currently, parallel decoding is only supported for PyTorch (greedy search, greedy sampling, beam search, beam sampling) and JAX (greedy search and greedy sampling).

**Note on the beam search implementation**: my beam search implementation is slower than optimal.
This is because I use the structures provided by HuggingFace's implementation, namely `BeamScores` and `BeamHypotheses`, to store the beam search results for each chunk in the input.
In other words, my implementation computes independent "beams" for each chunk rather than for each input sequence.
It is possible to make it faster with custom `BeamScores` and `BeamHypotheses` classes, but I haven't done that yet.
## Evaluation

See the notebook `evaluate_ct5.ipynb` for an example of how to evaluate cT5 in terms of accuracy and perplexity.
The notebook `profile.ipynb` shows how to profile the model to get runtimes.
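
For intuition only, the two quality metrics in the table below could be computed roughly as follows; this is an assumption about the metric definitions, not necessarily what `evaluate_ct5.ipynb` implements (`difflib` is used here as a stand-in for the edit-distance computation):

```python
import difflib

def exact_match(pred: str, ref: str) -> float:
    # 1.0 if the prediction matches the reference exactly, else 0.0.
    return float(pred.strip() == ref.strip())

def edit_distance_ratio(pred: str, ref: str) -> float:
    # Similarity ratio in [0, 1]; higher is better.
    return difflib.SequenceMatcher(None, pred, ref).ratio()
```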
Here is a comparison between cT5-small and T5-small on a subset of the WikiText-103 dataset using deterministic greedy search:

| Model     | Exact match ↑ | Edit distance ratio ↑ | Perplexity ↓ | Time (seconds) ↓ |
|-----------|---------------|-----------------------|--------------|------------------|
| T5-small  | 0.11          | 0.60                  | 2.22         | 44.71            |
| cT5-small | 0.09          | 0.58                  | 1.48         | 10.63            |

On this toy dataset, cT5-small has a lower perplexity while being faster than T5-small. However, more experiments are needed for a rigorous evaluation.

If you are interested in applying cT5 to real data, please contact me.