Upload 12 files
- config.yaml +127 -0
- diffusion.py +1434 -0
- dit.py +388 -0
- ema.py +97 -0
- esm_utils.py +15 -0
- generate.py +60 -0
- main.py +250 -0
- mdlm_motif_benchmarking.py +96 -0
- mlm_generate_utils.py +108 -0
- noise_schedule.py +153 -0
- pl_data_loader.py +819 -0
- utils.py +230 -0
config.yaml
ADDED
@@ -0,0 +1,127 @@
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

mode: sample_eval  # train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: membrane_esm_finetune  # dit / dimamba / ar / vanilla_esm_pretrain / membrane_esm_finetune
parameterization: subs  # subs / d3pm / sedd
time_conditioning: False
T: 0  # 0 (continuous time) / 1000
subs_masking: False

seed: 42

data:
  train:
    vanilla_esm_train_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/train.csv
    membrane_esm_train_path: /workspace/sg666/MDpLM/data/membrane/train.csv
    wrap: null
  test:
    vanilla_esm_test_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/test.csv
    membrane_esm_test_path: /workspace/sg666/MDpLM/data/membrane/test.csv
    wrap: null
  valid:
    vanilla_esm_valid_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/val.csv
    membrane_esm_valid_path: /workspace/sg666/MDpLM/data/membrane/val.csv
    wrap: null
  wrapping: True

loader:
  global_batch_size: 8
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ddpm_cache  # analytic, ddpm, ddpm_cache
  steps: 128
  noise_removal: True
  # TODO(yair): @subham, why aren't these params under `eval`?
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  mlm_model_path: /workspace/sg666/MDpLM/benchmarks/MLM/model_ckpts_650M/best_model_epoch
  esm_model_path: facebook/esm2_t30_150M_UR50D
  focus_mask: False

eval:
  checkpoint_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/eos-wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/checkpoints/best.ckpt  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  generation_model: /workspace/sg666/MDpLM/checkpoints/membrane_automodel/epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/

optim:
  weight_decay: 0.075
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

Model:
  hidden_size: 1280
  cond_dim: 256
  n_heads: 20
  n_blocks: 4
  dropout: 0.5
  length: null  # 512
  scale_by_sigma: True

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: bf16
  num_sanity_val_steps: 2
  max_epochs: 60
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0  # validate on full dataset, can be used to toggle quick run
  val_check_interval: 955

wandb:
  project: MDpLM_finetune_membrane_200k-seqs
  notes: null
  group: programmablebio
  job_type: null
  name: dit_test  # dit_wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16
  id: ${.name}_${seed}

hydra:
  run:
    dir: /workspace/sg666/MDpLM/outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: false
  resume_ckpt_path: ${.save_dir}/epochs30_lr3e-4_bsz8_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params_no-compile/checkpoints/last.ckpt  # /checkpoints/last.ckpt
  pretrained_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/vanilla_esm_pretrained_automodel/epochs10_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
  finetuned_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
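The interpolations above rely on custom OmegaConf resolvers (div_up, device_count, eval) that must be registered before Hydra composes the config; in this repo that registration presumably lives in utils.py or main.py. The snippet below is a minimal sketch of such a registration under that assumption, not a copy of this upload's code.

import math
import torch
from omegaconf import OmegaConf

# `${div_up:a, b}` -> ceil(a / b), used to derive per-device batch sizes.
OmegaConf.register_new_resolver('div_up', lambda a, b: (int(a) + int(b) - 1) // int(b))
# `${device_count:}` -> number of visible CUDA devices for `trainer.devices`.
OmegaConf.register_new_resolver('device_count', lambda: torch.cuda.device_count())
# `${eval:expr}` -> evaluate a Python expression embedded in the config.
OmegaConf.register_new_resolver('eval', eval)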
diffusion.py
ADDED
@@ -0,0 +1,1434 @@
import itertools
import math
import os
import sys
import typing
from dataclasses import dataclass

import hydra.utils
import lightning as L
import numpy as np
import torch.nn as nn
import torch
# import dit
import ema
import time
import gc
import pl_data_loader as dataloader
import torch.nn.functional as F
import torchmetrics
import transformers
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer

import utils
import noise_schedule

LOG2 = math.log(2)

class CosineWarmup(_LRScheduler):
  def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
    self.warmup_steps = warmup_steps
    self.total_steps = total_steps
    self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
    super(CosineWarmup, self).__init__(optimizer, last_epoch)

  def get_lr(self):
    if self.last_epoch < self.warmup_steps:
      return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]

    progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
    decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio

    return [decayed_lr * base_lr for base_lr in self.base_lrs]

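A quick usage sketch of the scheduler above (toy optimizer and step counts, not values from this repo; configure_optimizers further below wires it to config.lr_scheduler.num_warmup_steps):

# Sketch: linear warmup for 10 steps, cosine decay to 10% of peak by step 100.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4)
sched = CosineWarmup(opt, warmup_steps=10, total_steps=100)
lrs = []
for _ in range(100):
  opt.step()
  sched.step()
  lrs.append(sched.get_last_lr()[0])
# lrs ramps linearly to 3e-4 over the first 10 steps, then follows a cosine
# toward eta_ratio * 3e-4 at step 100.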
def _sample_categorical(categorical_probs):
  gumbel_norm = (
    1e-10
    - (torch.rand_like(categorical_probs) + 1e-10).log())
  return (categorical_probs / gumbel_norm).argmax(dim=-1)


def _unsqueeze(x, reference):
  return x.view(
    *x.shape,
    *((1,) * (len(reference.shape) - len(x.shape))))

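_sample_categorical implements the Gumbel-max trick: dividing the probabilities by an Exp(1) sample and taking the argmax draws from the categorical distribution without torch.multinomial. A small sanity-check sketch (illustrative only, not part of the upload):

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7]).expand(100_000, 3)
samples = _sample_categorical(probs)
freqs = torch.bincount(samples, minlength=3).float() / samples.shape[0]
# freqs comes out close to tensor([0.1, 0.2, 0.7]), confirming that the
# exponential-noise argmax is an exact categorical sampler.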
@dataclass
class Loss:
  loss: torch.FloatTensor
  nlls: torch.FloatTensor
  token_mask: torch.FloatTensor


class NLL(torchmetrics.aggregation.MeanMetric):
  pass


class BPD(NLL):
  def compute(self) -> Tensor:
    """Computes the bits per dimension.

    Returns:
      bpd
    """
    return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
  def compute(self) -> Tensor:
    """Computes the Perplexity.

    Returns:
      Perplexity
    """
    return torch.exp(self.mean_value / self.weight)


class WrapVanillaESM(nn.Module):
  def __init__(self, bert_model_path):
    super(WrapVanillaESM, self).__init__()
    # self.bert_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path).to(self.bert_model_device)
    self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
    self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

  def __call__(self, *args, **kwargs):
    return self.model(*args, **kwargs)

  def unfreeze_attn_layers(self):
    model_layers = len(self.model.esm.encoder.layer)

    for i, layer in enumerate(self.model.esm.encoder.layer):
      if i >= model_layers - 5:  # fine-tune only last n layers
        for module in layer.attention.self.key.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.query.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.value.modules():
          for param in module.parameters():
            param.requires_grad = True

  def unfreeze_all_layers(self):
    for param in self.model.parameters():
      param.requires_grad = True

  def forward(self, inputs, sigma, attention_mask):
    logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
    return logits

  def save_model(self, save_dir):
    self.model.save_pretrained(save_dir)
    self.tokenizer.save_pretrained(save_dir)

  def load_model(self, load_dir):
    self.model = AutoModel.from_pretrained(load_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(load_dir)


class WrapMembraneESM(nn.Module):
  def __init__(self, bert_model_path):
    super(WrapMembraneESM, self).__init__()
    # self.bert_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path).to(self.bert_model_device)
    self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
    self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

  def __call__(self, *args, **kwargs):
    return self.model(*args, **kwargs)

  def freeze_model(self):
    for param in self.model.parameters():
      param.requires_grad = False

  def unfreeze_all_layers(self):
    for param in self.model.parameters():
      param.requires_grad = True

  def unfreeze_attn_layers(self):
    model_layers = len(self.model.esm.encoder.layer)

    for i, layer in enumerate(self.model.esm.encoder.layer):
      if i >= model_layers - 11:  # fine-tune only last n layers
        for module in layer.attention.self.key.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.query.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.value.modules():
          for param in module.parameters():
            param.requires_grad = True

  def forward(self, inputs, sigma, attention_mask):
    logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
    return logits

  def save_model(self, save_dir):
    self.model.save_pretrained(save_dir)
    self.tokenizer.save_pretrained(save_dir)

  def load_model(self, load_dir):
    self.model = AutoModel.from_pretrained(load_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(load_dir)

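A short usage sketch for the wrappers above. The checkpoint name comes from training.esm_model_path in config.yaml; the parameter count at the end is illustrative and not part of the upload.

# Sketch: load the pretraining backbone frozen, then re-enable only the last
# five layers' Q/K/V projections, which is what unfreeze_attn_layers does.
backbone = WrapVanillaESM(bert_model_path='facebook/esm2_t30_150M_UR50D')
for p in backbone.model.parameters():
  p.requires_grad = False
backbone.unfreeze_attn_layers()
trainable = sum(p.numel() for p in backbone.model.parameters() if p.requires_grad)
print(f'trainable parameters: {trainable:,}')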
class Diffusion(L.LightningModule):
  def __init__(
    self,
    config,
    tokenizer: transformers.PreTrainedTokenizer):
    super().__init__()
    self.save_hyperparameters()
    self.config = config

    self.tokenizer = tokenizer
    self.vocab_size = self.tokenizer.vocab_size
    self.sampler = self.config.sampling.predictor
    self.gen_ppl_eval_model_name_or_path = self.config.eval.\
      gen_ppl_eval_model_name_or_path
    self.antithetic_sampling = self.config.training.antithetic_sampling
    self.importance_sampling = self.config.training.importance_sampling
    self.change_of_variables = self.config.training.change_of_variables
    if (not hasattr(self.tokenizer, 'mask_token')
        or self.tokenizer.mask_token is None):
      self.mask_index = self.vocab_size
      self.vocab_size += 1
    else:
      self.mask_index = self.tokenizer.mask_token_id
    self.parameterization = self.config.parameterization

    # if self.config.backbone == 'dit':
    #   self.backbone = dit.DIT(
    #     self.config, vocab_size=self.vocab_size, mlm_model_path=config.training.mlm_model_path)
    if self.config.backbone == "vanilla_esm_pretrain":
      self.backbone = WrapVanillaESM(bert_model_path=self.config.training.esm_model_path)
      self.backbone.unfreeze_all_layers()
      self.backbone = torch.compile(self.backbone)
    elif self.config.backbone == 'membrane_esm_finetune':
      self.backbone = WrapMembraneESM(bert_model_path=self.config.checkpointing.pretrained_esm_mdlm_automodel_path)
      self.backbone.unfreeze_all_layers()
      # self.backbone = torch.compile(self.backbone)

    # elif self.config.backbone == 'dimamba':
    #   self.backbone = dimamba.DiMamba(
    #     self.config,
    #     vocab_size=self.vocab_size,
    #     pad_token_id=self.tokenizer.pad_token_id)
    # elif self.config.backbone == 'ar':
    #   self.backbone = autoregressive.AR(
    #     self.config,
    #     vocab_size=self.vocab_size,
    #     mask_index=self.mask_index)
    # elif self.config.backbone == 'hf_dit':
    #   self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
    #     config.eval.checkpoint_path, trust_remote_code=True)
    # else:
    #   raise ValueError(
    #     f'Unknown backbone: {self.config.backbone}')

    self.T = self.config.T
    self.subs_masking = self.config.subs_masking

    self.softplus = torch.nn.Softplus()
    # metrics are automatically reset at end of epoch
    metrics = torchmetrics.MetricCollection({
      'nll': NLL(),
      'bpd': BPD(),
      'ppl': Perplexity(),
    })
    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    self.test_metrics = metrics.clone(prefix='test/')

    # generative perplexity
    self.gen_ppl_metric = Perplexity()
    self.eval_model_tokenizer = transformers.AutoTokenizer.\
      from_pretrained(self.gen_ppl_eval_model_name_or_path)
    if self.eval_model_tokenizer.pad_token is None:
      self.eval_model_tokenizer.pad_token =\
        self.eval_model_tokenizer.eos_token
      self.eval_model_tokenizer.pad_token_id =\
        self.eval_model_tokenizer.eos_token_id

    self.noise = noise_schedule.get_noise(self.config,
                                          dtype=self.dtype)
    if self.config.training.ema > 0:
      self.ema = ema.ExponentialMovingAverage(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()),
        decay=self.config.training.ema)
    else:
      self.ema = None

    self.lr = self.config.optim.lr
    self.sampling_eps = self.config.training.sampling_eps
    self.time_conditioning = self.config.time_conditioning
    self.neg_infinity = -1000000.0
    self.fast_forward_epochs = None
    self.fast_forward_batches = None
    self._validate_configuration()

  def _validate_configuration(self):
    assert not (self.change_of_variables
                and self.importance_sampling)
    if self.parameterization == 'sedd':
      assert not self.importance_sampling
      assert not self.change_of_variables
    if self.parameterization == 'd3pm':
      assert self.T > 0
    if self.T > 0:
      assert self.parameterization in {'d3pm', 'subs'}
    if self.subs_masking:
      assert self.parameterization == 'd3pm'

  def on_load_checkpoint(self, checkpoint):
    if self.ema:
      self.ema.load_state_dict(checkpoint['ema'])
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
    self.fast_forward_epochs = checkpoint['loops'][
      'fit_loop']['epoch_progress']['current']['completed']
    self.fast_forward_batches = checkpoint['loops'][
      'fit_loop']['epoch_loop.batch_progress'][
        'current']['completed']

  def on_save_checkpoint(self, checkpoint):
    if self.ema:
      checkpoint['ema'] = self.ema.state_dict()
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
    # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
    # behind, so we're using the optimizer's progress.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['total'][
        'completed'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total'][
              'completed'] * self.trainer.accumulate_grad_batches
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['current'][
        'completed'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['current'][
              'completed'] * self.trainer.accumulate_grad_batches
    # _batches_that_stepped tracks the number of global steps, not the number
    # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.state_dict'][
        '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total']['completed']
    if 'sampler' not in checkpoint.keys():
      checkpoint['sampler'] = {}
    if hasattr(self.trainer.train_dataloader.sampler,
               'state_dict'):
      sampler_state_dict = self.trainer.\
        train_dataloader.sampler.state_dict()
      checkpoint['sampler'][
        'random_state'] = sampler_state_dict.get(
          'random_state', None)
    else:
      checkpoint['sampler']['random_state'] = None

    self.backbone.save_model(self.config.checkpointing.finetuned_esm_mdlm_automodel_path)

  def on_train_start(self):
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.move_shadow_params_to_device(self.device)

    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    distributed = (
      self.trainer._accelerator_connector.use_distributed_sampler
      and self.trainer._accelerator_connector.is_distributed)
    if distributed:
      sampler_cls = dataloader.FaultTolerantDistributedSampler
    else:
      sampler_cls = dataloader.RandomFaultTolerantSampler
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)
      if (distributed
          and self.fast_forward_epochs is not None
          and self.fast_forward_batches is not None):
        dl_sampler.load_state_dict({
          'epoch': self.fast_forward_epochs,
          'counter': (self.fast_forward_batches
                      * self.config.loader.batch_size)})

      from functools import partial
      from pl_data_loader import collate_fn
      collate_partial = partial(collate_fn, tokenizer=self.tokenizer)
      torch.cuda.empty_cache()

      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=self.config.loader.batch_size,
          num_workers=self.config.loader.num_workers,
          pin_memory=self.config.loader.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          persistent_workers=False,
          collate_fn=collate_partial))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls

  def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)

    gc.collect()
    torch.cuda.empty_cache()

    if self.ema:
      self.ema.update(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))

    # optimizer_closure = kwargs.get('optimizer_closure', None)

    # params_with_grad = [p for p in itertools.chain(
    #   self.backbone.parameters(),
    #   self.noise.parameters()
    # ) if p.requires_grad and p.grad_fn is not None]

    # # if params_with_grad:
    # #   super().optimizer_step(closure=optimizer_closure)

    # if self.ema:
    #   self.ema.update(params_with_grad)

    # super().optimizer_step(*args, **kwargs)

  def _subs_parameterization(self, logits, xt):
    # log prob at the mask index = - infinity
    logits = logits.logits
    logits[:, :, self.mask_index] += self.neg_infinity
    # logits[:, :, self.tokenizer.eos_token_id] += self.neg_infinity
    # logits[:, :, self.tokenizer.cls_token_id] += self.neg_infinity

    # Normalize the logits such that x.exp() is
    # a probability distribution over vocab_size.
    logits = logits - torch.logsumexp(logits, dim=-1,
                                      keepdim=True)

    # Apply updates directly in the logits matrix.
    # For the logits of the unmasked tokens, set all values
    # to -infinity except for the indices corresponding to
    # the unmasked tokens.
    unmasked_indices = (xt != self.mask_index)
    logits[unmasked_indices] = self.neg_infinity
    logits[unmasked_indices, xt[unmasked_indices]] = 0
    return logits

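A toy check of the two SUBS properties enforced above, zero probability on the mask token and carry-over of unmasked tokens. This is a standalone sketch with a hypothetical 5-token vocabulary (mask_index 4), not this module's real tokenizer:

B, L, V, mask_index = 1, 3, 5, 4
neg_infinity = -1000000.0
logits = torch.randn(B, L, V)
xt = torch.tensor([[2, mask_index, 0]])  # middle position is masked

logits[:, :, mask_index] += neg_infinity
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
unmasked = xt != mask_index
logits[unmasked] = neg_infinity
logits[unmasked, xt[unmasked]] = 0

probs = logits.exp()
# probs[0, 1, mask_index] is ~0: the model never predicts the mask token.
# probs[0, 0] is one-hot on token 2 and probs[0, 2] is one-hot on token 0:
# unmasked positions are copied through unchanged (carry-over unmasking).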
  def _d3pm_parameterization(self, logits):
    if self.subs_masking:
      logits[:, :, self.mask_index] += self.neg_infinity
    logits = logits - torch.logsumexp(logits, dim=-1,
                                      keepdim=True)
    return logits

  def _sedd_parameterization(self, logits, xt, sigma):
    esigm1_log = torch.where(
      sigma < 0.5,
      torch.expm1(sigma),
      sigma.exp() - 1).log().to(logits.dtype)
    # logits shape
    # (batch_size, diffusion_model_input_length, vocab_size)
    logits = logits - esigm1_log[:, None, None] - np.log(
      logits.shape[-1] - 1)
    # The below scatter operation sets the log score
    # for the input word to 0.
    logits = torch.scatter(logits, -1, xt[..., None],
                           torch.zeros_like(logits[..., :1]))
    return logits

  def _process_sigma(self, sigma):
    if sigma is None:
      assert self.parameterization == 'ar'
      return sigma
    if sigma.ndim > 1:
      sigma = sigma.squeeze(-1)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma

  def forward(self, x, sigma, attention_mask, print_logits=False):
    """Returns log score."""
    sigma = self._process_sigma(sigma)
    with torch.amp.autocast("cuda", dtype=torch.float32):
      logits = self.backbone(x, attention_mask)
    # if print_logits:
    #   torch.set_printoptions(profile="full")
    #   print(logits)
    #   torch.set_printoptions(profile="default")
    if self.parameterization == 'subs':
      return self._subs_parameterization(logits=logits, xt=x)
    return logits

  def _d3pm_loss(self, model_output, xt, x0, t, attention_mask):
    dt = 1 / self.T

    if torch.is_tensor(t):
      t = t[:, None]
      assert t.ndim == 2
      t = t.clamp(0., 1. - 1e-4)
    alpha_t = 1 - t + torch.zeros_like(xt)
    alpha_s = 1 - (t - dt) + torch.zeros_like(xt)

    log_x_theta_at_x0 = torch.gather(
      model_output, -1, x0[:, :, None]).squeeze(-1)
    log_x_theta_at_m = model_output[:, :, self.mask_index]
    x_theta_at_m = log_x_theta_at_m.exp()

    term_1_coef = dt / t
    term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
    term_1_log_dr = log_x_theta_at_x0

    term_2_coef = 1 - dt / t
    term_2_log_nr = term_1_log_nr
    term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)

    L_vb_masked = (
      term_1_coef * (term_1_log_nr - term_1_log_dr)
      + term_2_coef * (term_2_log_nr - term_2_log_dr))

    L_vb = L_vb_masked * (xt == self.mask_index)

    return self.T * L_vb

  def _compute_loss(self, batch, prefix):
    if 'attention_mask' in batch:
      attention_mask = batch['attention_mask']
    else:
      attention_mask = None
    if 'mask' in batch: mask = batch['mask']
    else: mask = None

    losses = self._loss(batch['input_ids'], attention_mask, mask)
    loss = losses.loss

    if prefix == 'train':
      self.train_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.train_metrics
    elif prefix == 'val':
      self.valid_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.valid_metrics
    elif prefix == 'test':
      self.test_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.test_metrics
    else:
      raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)
    return loss

  def on_train_epoch_start(self):
    self.backbone.train()
    self.noise.train()

  def training_step(self, batch, batch_idx):
    # Initialize throughput calculation
    start_time = time.time()

    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    # Calculate throughput
    elapsed_time = time.time() - start_time
    total_tokens = batch['input_ids'].numel()
    throughput = total_tokens / elapsed_time

    self.log(name='trainer/throughput',
             value=throughput,
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    return loss

  def on_validation_epoch_start(self):
    # params_with_grad = [p for p in itertools.chain(
    #   self.backbone.parameters(),
    #   self.noise.parameters()
    # ) if p.requires_grad]
    # if self.ema:
    #   self.ema.store(params_with_grad)
    #   self.ema.copy_to(params_with_grad)

    gc.collect()
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.store(
        itertools.chain(
          self.backbone.parameters(),
          self.noise.parameters()))
      self.ema.copy_to(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))
    self.backbone.eval()
    self.noise.eval()
    assert self.valid_metrics.nll.mean_value == 0
    assert self.valid_metrics.nll.weight == 0

  def validation_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='val')
    self.log(name='trainer/val_loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             sync_dist=True)
    return loss

  def on_validation_epoch_end(self):
    # params_with_grad = [p for p in itertools.chain(
    #   self.backbone.parameters(),
    #   self.noise.parameters()
    # ) if p.requires_grad]
    # if ((self.config.eval.compute_perplexity_on_sanity
    #      or not self.trainer.sanity_checking)
    #      and self.config.eval.generate_samples
    #      and not self.parameterization == 'ar'):
    #   # (justin): implement sampling and kv cache for AR
    #   samples, text_samples = None, None
    #   for _ in range(
    #       self.config.sampling.num_sample_batches):
    #     samples = self._sample()
    #     # Decode the samples to be re-tokenized by eval model
    #     text_samples = self.tokenizer.batch_decode(samples)
    #     if self.config.eval.compute_generative_perplexity:
    #       self.compute_generative_perplexity(text_samples)
    #   if self.trainer.global_rank == 0 and hasattr(
    #       self.trainer.logger, 'log_table'):
    #     # Log the last generated samples
    #     text_samples = text_samples[
    #       : self.config.sampling.num_sample_log]
    #     self.trainer.logger.log_table(
    #       key=f'samples@global_step{self.global_step}',
    #       columns=['Generated Samples'],
    #       data=[[s] for s in text_samples])
    #   if self.config.eval.compute_generative_perplexity:
    #     self.log('val/gen_ppl',
    #              self.gen_ppl_metric,
    #              on_epoch=True,
    #              on_step=False,
    #              sync_dist=True)

    gc.collect()
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.restore(
        itertools.chain(
          self.backbone.parameters(),
          self.noise.parameters()))

  def test_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='test')
    self.log('test/loss',
             value=loss.item(),
             on_step=False,
             on_epoch=True,
             sync_dist=True)

    if self.config.eval.compute_generative_perplexity:
      samples, text_samples = None, None
      for _ in range(
          self.config.sampling.num_sample_batches):
        samples = self._sample()
        # Decode the samples to be re-tokenized by eval model
        text_samples = self.tokenizer.batch_decode(samples)
        if self.config.eval.compute_generative_perplexity:
          self.compute_generative_perplexity(text_samples)
      if self.trainer.global_rank == 0 and hasattr(
          self.trainer.logger, 'log_table'):
        # Log the last generated samples
        text_samples = text_samples[
          : self.config.sampling.num_sample_log]
        self.trainer.logger.log_table(
          key=f'samples@global_step{self.global_step}',
          columns=['Generated Samples'],
          data=[[s] for s in text_samples])
      if self.config.eval.compute_generative_perplexity:
        self.log('test/gen_ppl',
                 self.gen_ppl_metric,
                 on_epoch=False,
                 on_step=True,
                 sync_dist=True)

  def on_test_epoch_start(self):
    # params_with_grad = [p for p in itertools.chain(
    #   self.backbone.parameters(),
    #   self.noise.parameters()
    # ) if p.requires_grad]

    if self.ema:
      self.ema.store(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))
      self.ema.copy_to(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))

    self.backbone.eval()
    self.noise.eval()
    self.test_metrics.reset()

  def on_test_epoch_end(self):
    # params_with_grad = [p for p in itertools.chain(
    #   self.backbone.parameters(),
    #   self.noise.parameters()
    # ) if p.requires_grad]

    if self.ema:
      self.ema.restore(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))

    for metric_name, metric_value in self.test_metrics.compute().items():
      self.log(metric_name, metric_value, sync_dist=True)

  def configure_optimizers(self):
    # (yair): Lightning currently giving this warning when using `fp16`:
    # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
    # Not clear if this is a problem or not.
    # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558

    # params_with_grad = [p for p in itertools.chain(
    #   self.backbone.parameters(),
    #   self.noise.parameters()
    # ) if p.requires_grad]

    optimizer = torch.optim.AdamW(
      itertools.chain(self.backbone.parameters(),
                      self.noise.parameters()),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1,
             self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay
    )

    # scheduler = hydra.utils.instantiate(
    #   self.config.lr_scheduler, optimizer=optimizer)
    # scheduler_dict = {
    #   'scheduler': scheduler,
    #   'interval': 'step',
    #   'monitor': 'val/loss',
    #   'name': 'trainer/lr',
    # }

    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(optimizer,
                             warmup_steps=self.config.lr_scheduler.num_warmup_steps,
                             total_steps=self.total_steps)

    scheduler_dict = {
      'scheduler': scheduler,
      'interval': 'step',
      'frequency': 1,
      'monitor': 'val/loss',
      'name': 'trainer/lr'
    }

    return [optimizer], [scheduler_dict]

  @torch.no_grad()
  def eval_retokenize(self, text_samples, max_length):
    """Retokenizes samples for the eval model.

    Args:
      text_samples: List of sentences generated by the model.
    Returns:
      samples: Samples re-tokenized for the eval model
      attn_mask: Attention mask for the eval model
      eval_context_size: Size of the context for the eval model
    """
    if 'llama2' in self.gen_ppl_eval_model_name_or_path:
      tokenizer_kwargs = {
        'text_samples': text_samples,
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
      }
      eval_context_size = 4096
    else:
      tokenizer_kwargs = {
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
      }
      eval_context_size = 1024
    samples = self.eval_model_tokenizer(
      text_samples, **tokenizer_kwargs)
    attn_mask = samples['attention_mask']
    samples = samples['input_ids']
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      attn_mask = attn_mask.to(self.device)
      samples = samples.to(self.device)
    return samples, attn_mask, eval_context_size

  # @torch.no_grad()
  # def compute_generative_perplexity(
  #   self,
  #   text_samples: typing.List[str],
  #   retokenize: bool = True,
  #   max_length: typing.Optional[int] = None) -> None:
  #   """Compute the generative perplexity of the model.
  #
  #   Args:
  #     text_samples: List of sentences generated by the model.
  #
  #   Returns:
  #     Perplexity of the generated text under a different
  #     pre-trained AR model (e.g., GPT2).
  #   """
  #   os.environ['TOKENIZERS_PARALLELISM'] = 'false'
  #   eval_model = transformers.AutoModelForCausalLM.from_pretrained(
  #     self.gen_ppl_eval_model_name_or_path).eval()
  #   if max_length is None:
  #     max_length = self.config.model.length
  #   if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
  #     eval_model = eval_model.to(self.device)
  #   # Re-tokenize using eval model's tokenizer
  #   if retokenize:
  #     (samples, attn_mask,
  #      eval_context_size) = self.eval_retokenize(
  #        text_samples, max_length=max_length)
  #   else:
  #     samples = text_samples
  #     attn_mask = torch.ones(samples.shape).to(self.device)
  #     eval_context_size = samples.shape[-1]
  #   batch_size = min(
  #     self.config.eval.perplexity_batch_size,
  #     samples.shape[0])
  #   num_batches = samples.shape[0] // batch_size
  #   for i in range(num_batches):
  #     _samples = torch.split(
  #       samples[i * batch_size: (i + 1) * batch_size],
  #       eval_context_size,
  #       dim=-1)
  #     _attn_mask = torch.split(
  #       attn_mask[i * batch_size: (i + 1) * batch_size],
  #       eval_context_size,
  #       dim=-1)
  #     for (sample_chunk, attn_mask_chunk) in zip(
  #         _samples, _attn_mask):
  #       logits = eval_model(
  #         sample_chunk, attention_mask=attn_mask_chunk)[0]
  #       logits = logits.transpose(-1, -2)
  #
  #       nlls = F.cross_entropy(logits[..., :-1],
  #                              sample_chunk[..., 1:],
  #                              reduction='none')
  #       first_eos = (sample_chunk == self.eval_model_tokenizer\
  #                    .eos_token_id).cumsum(-1) == 1
  #       token_mask = (
  #         sample_chunk
  #         != self.eval_model_tokenizer.eos_token_id)
  #       self.gen_ppl_metric.update(
  #         nlls, first_eos[..., 1:] + token_mask[..., 1:])

  @torch.no_grad()
  def compute_masked_perplexity(self, sequences, masked):
    """Compute the pseudo-perplexity of the generated protein sequences."""
    total_nll = 0
    total_tokens = 0

    for sequence in sequences:
      # Tokenize the sequence
      input_ids = self.tokenizer(masked, return_tensors="pt").input_ids.to(self.device)
      gt_ids = self.tokenizer(sequence.upper(), return_tensors="pt").input_ids.to(self.device)

      # print(input_ids.shape)
      # print(gt_ids.shape)

      # Forward pass through the ESM model
      attention_mask = torch.ones_like(input_ids)
      if self.config.mode in ['train', 'ppl_eval']:
        outputs = self.backbone.model.forward(input_ids=input_ids, attention_mask=attention_mask)
      elif self.config.mode == "sample_eval":
        outputs = self.backbone.model.forward(input_ids)
      logits = outputs[-1]  # B, L, V

      # Compute loss
      # shift_logits = logits[:, :-1, :].contiguous()  # remove eos
      # shift_labels = input_ids[:, 1:].contiguous()
      # print(masked)
      # print(gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1))
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                             gt_ids.where(input_ids == 32, torch.full_like(input_ids, -100)).view(-1),
                             reduction='sum')

      total_nll += loss.item()
      # total_tokens += (input_ids != self.tokenizer.pad_token_id).sum().item() - 1  # -1 for the first token
      total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item()  # count in bos and eos
    # Compute pseudo-perplexity
    # print(total_nll, ",;,", total_tokens)
    pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
    self.gen_ppl_metric.update(pseudo_perplexity)

    return pseudo_perplexity.item()

  @torch.no_grad()
  def compute_generative_perplexity(
    self,
    text_samples: typing.List[str],
    retokenize: bool = True,
    max_length: typing.Optional[int] = None) -> None:
    """Compute the generative perplexity of the model.

    Args:
      text_samples: List of sentences generated by the model.

    Returns:
      Perplexity of the generated text under a different
      pre-trained AR model (e.g., GPT2).
    """
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    eval_model = transformers.AutoModelForCausalLM.from_pretrained(
      self.gen_ppl_eval_model_name_or_path).eval()
    if max_length is None:
      max_length = self.config.model.length
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      eval_model = eval_model.to(self.device)
    # Re-tokenize using eval model's tokenizer
    if retokenize:
      (samples, attn_mask,
       eval_context_size) = self.eval_retokenize(
         text_samples, max_length=max_length)
    else:
      samples = text_samples
      attn_mask = torch.ones(samples.shape).to(self.device)
      eval_context_size = samples.shape[-1]
    batch_size = min(
      self.config.eval.perplexity_batch_size,
      samples.shape[0])
    num_batches = samples.shape[0] // batch_size
    for i in range(num_batches):
      _samples = torch.split(
        samples[i * batch_size: (i + 1) * batch_size],
        eval_context_size,
        dim=-1)
      _attn_mask = torch.split(
        attn_mask[i * batch_size: (i + 1) * batch_size],
        eval_context_size,
        dim=-1)
      for (sample_chunk, attn_mask_chunk) in zip(
          _samples, _attn_mask):
        logits = eval_model(
          sample_chunk, attention_mask=attn_mask_chunk)[0]
        logits = logits.transpose(-1, -2)

        nlls = F.cross_entropy(logits[..., :-1],
                               sample_chunk[..., 1:],
                               reduction='none')
        first_eos = (sample_chunk == self.eval_model_tokenizer\
                     .eos_token_id).cumsum(-1) == 1
        token_mask = (
          sample_chunk
          != self.eval_model_tokenizer.eos_token_id)
        self.gen_ppl_metric.update(
          nlls, first_eos[..., 1:] + token_mask[..., 1:])

  def q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Args:
      x: int torch.Tensor with shape (batch_size,
        diffusion_model_input_length), input.
      move_chance: float torch.Tensor with shape (batch_size, 1).
    """

    actual_seq_length = (x != 1).sum(dim=1, keepdim=True)

    max_mask_length = (actual_seq_length * 0.75).long()

    move_indices = torch.rand(*x.shape, device=x.device) < move_chance

    restricted_move_indices = torch.zeros_like(move_indices, dtype=torch.bool)

    for i in range(x.shape[0]):
      true_positions = torch.where(move_indices[i])[0]
      if len(true_positions) > max_mask_length[i]:
        selected_positions = true_positions[:max_mask_length[i].item()]
        restricted_move_indices[i, selected_positions] = True
      else:
        restricted_move_indices[i] = move_indices[i]
    xt = torch.where(restricted_move_indices, self.mask_index, x)

    return xt

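A small sketch of the forward masking step above on a toy batch. Token ids are hypothetical; 1 stands in for the pad id used by the (x != 1) length computation, and 32 stands in for the ESM mask id:

torch.manual_seed(0)
mask_index = 32                              # assumed <mask> id
x = torch.tensor([[0, 5, 6, 7, 8, 2, 1, 1]])  # one padded toy sequence
move_chance = torch.tensor([[0.9]])           # masking probability near t=1

move_indices = torch.rand(*x.shape) < move_chance
# q_xt additionally caps the number of masked positions at 75% of the
# non-pad length; with 6 real tokens at most 4 positions would stay masked.
xt = torch.where(move_indices, mask_index, x)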
  def _sample_prior(self, *batch_dims):
    return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)

  def _ddpm_caching_update(self, x, t, dt, p_x0=None, attention_mask=None):
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
      t = t.squeeze(-1)
    assert t.ndim == 1
    move_chance_t = t[:, None, None]
    move_chance_s = (t - dt)[:, None, None]
    assert move_chance_t.ndim == 3, move_chance_t.shape
    if p_x0 is None:
      p_x0 = self.forward(x, sigma_t, attention_mask).exp()

    assert move_chance_t.ndim == p_x0.ndim
    q_xs = p_x0 * (move_chance_t - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)

    copy_flag = (x != self.mask_index).to(x.dtype)
    return p_x0, copy_flag * x + (1 - copy_flag) * _x

  def _ddpm_update(self, x, t, dt, attention_mask):
    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t)
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning, attention_mask)
    assert move_chance_t.ndim == log_p_x0.ndim
    # Technically, this isn't q_xs since there's a division
    # term that is missing. This division term doesn't affect
    # the samples.
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)

    copy_flag = (x != self.mask_index).to(x.dtype)
    return copy_flag * x + (1 - copy_flag) * _x

  def _ar_sampler(self, bsz):
    # precompute token buffer
    num_pred_tokens = self.config.model.length - 1
    x = torch.zeros(
      (bsz, num_pred_tokens + 1),
      dtype=torch.long,
      device=self.device)
    x[:, 0] = self.tokenizer.bos_token_id
    # precompute noise
    noise = (torch.distributions.Gumbel(0, 1)
             .sample((bsz, num_pred_tokens, self.vocab_size))
             .to(self.device))
    for i in range(num_pred_tokens):
      next_logits = self.forward(x[:, :i + 1], None)[:, -1]
      y = (next_logits + noise[:, i]).argmax(-1)
      x[:, i + 1] = y
    return x

  @torch.no_grad()
  def _sample(self, num_steps=None, eps=1e-5, x_input=None):
    """Generate samples from the model."""
    batch_size_per_gpu = self.config.eval.perplexity_batch_size
    if self.parameterization == 'ar':
      return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason
    if num_steps is None:
      num_steps = self.config.sampling.steps
    if x_input is not None:
      x = x_input.input_ids
      attention_mask = x_input.attention_mask
    else:
      x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
      attention_mask = torch.ones_like(x)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        x = self._ddpm_update(x, t, dt, attention_mask)
      elif self.sampler == 'ddpm_cache':
        p_x0_cache, x_next = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache, attention_mask=attention_mask)
        if (not torch.allclose(x_next, x) or self.time_conditioning):
          # Disable caching
          p_x0_cache = None
        x = x_next
        # print(self.tokenizer.decode(x.squeeze()))
      else:
        x = self._analytic_update(x, t, dt, attention_mask)

    if self.config.sampling.noise_removal:
      t = timesteps[-1] * torch.ones(x.shape[0], 1,
                                     device=self.device)
      if self.sampler == 'analytic':
        x = self._denoiser_update(x, t)
      else:
        unet_conditioning = self.noise(t)[0]
        x = self.forward(x, unet_conditioning, attention_mask, print_logits=True).argmax(dim=-1)
        # print(self.tokenizer.decode(x.squeeze()))
    return x

1102 |
+
def restore_model_and_sample(self, num_steps, eps=1e-5):
|
1103 |
+
"""Generate samples from the model."""
|
1104 |
+
# Lightning auto-casting is not working in this method for some reason
|
1105 |
+
# params_with_grad = [p for p in itertools.chain(
|
1106 |
+
# self.backbone.parameters(),
|
1107 |
+
# self.noise.parameters()
|
1108 |
+
# ) if p.requires_grad]
|
1109 |
+
|
1110 |
+
if self.ema:
|
1111 |
+
self.ema.store(itertools.chain(self.backbone.parameters(),
|
1112 |
+
self.noise.parameters()))
|
1113 |
+
self.ema.copy_to(itertools.chain(self.backbone.parameters(),
|
1114 |
+
self.noise.parameters()))
|
1115 |
+
self.backbone.eval()
|
1116 |
+
self.noise.eval()
|
1117 |
+
samples = self._sample(num_steps=num_steps, eps=eps)
|
1118 |
+
if self.ema:
|
1119 |
+
self.ema.restore(itertools.chain(self.backbone.parameters(),
|
1120 |
+
self.noise.parameters()))
|
1121 |
+
self.backbone.train()
|
1122 |
+
self.noise.train()
|
1123 |
+
return samples
|
1124 |
+
|
1125 |
+
def get_score(self, x, sigma, attention_mask=None):
|
1126 |
+
model_output = self.forward(x, sigma, attention_mask)
|
1127 |
+
if self.parameterization == 'subs':
|
1128 |
+
# score(x, t) = p_t(y) / p_t(x)
|
1129 |
+
# => log score(x, t) = log p_t(y) - log p_t(x)
|
1130 |
+
|
1131 |
+
# case 1: x = masked
|
1132 |
+
# (i) y = unmasked
|
1133 |
+
# log score(x, t) = log p_\theta(x)|_y + log k
|
1134 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
1135 |
+
# (ii) y = masked
|
1136 |
+
# log score(x, t) = 0
|
1137 |
+
|
1138 |
+
# case 2: x = unmasked
|
1139 |
+
# (i) y != masked, y != x
|
1140 |
+
# log score(x_i, t) = - inf
|
1141 |
+
# (ii) y = x
|
1142 |
+
# log score(x_i, t) = 0
|
1143 |
+
# (iii) y = masked token
|
1144 |
+
# log score(x_i, t) = - log k
|
1145 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
1146 |
+
|
1147 |
+
log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
|
1148 |
+
assert log_k.ndim == 1
|
1149 |
+
|
1150 |
+
masked_score = model_output + log_k[:, None, None]
|
1151 |
+
masked_score[:, :, self.mask_index] = 0
|
1152 |
+
|
1153 |
+
unmasked_score = self.neg_infinity * torch.ones_like(
|
1154 |
+
model_output)
|
1155 |
+
unmasked_score = torch.scatter(
|
1156 |
+
unmasked_score,
|
1157 |
+
-1,
|
1158 |
+
x[..., None],
|
1159 |
+
torch.zeros_like(unmasked_score[..., :1]))
|
1160 |
+
unmasked_score[:, :, self.mask_index] = - (
|
1161 |
+
log_k[:, None] * torch.ones_like(x))
|
1162 |
+
|
1163 |
+
masked_indices = (x == self.mask_index).to(
|
1164 |
+
model_output.dtype)[:, :, None]
|
1165 |
+
model_output = (
|
1166 |
+
masked_score * masked_indices
|
1167 |
+
+ unmasked_score * (1 - masked_indices))
|
1168 |
+
return model_output.exp()
|
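The case analysis in the comments above reduces, for masked positions, to adding log k = log(exp(-sigma) / (1 - exp(-sigma))) to the model's log-probabilities; the code writes this as -log(expm1(sigma)). A tiny standalone check (toy sigma values, not part of the class) that the two forms agree:

```python
import torch

sigma = torch.tensor([0.5, 1.0, 2.0])
k = torch.exp(-sigma) / (1 - torch.exp(-sigma))   # k as written in the comment
log_k = -torch.log(torch.expm1(sigma))            # the form used in get_score
print(torch.allclose(torch.log(k), log_k))        # True
```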
1169 |
+
|
1170 |
+
def _staggered_score(self, score, dsigma):
|
1171 |
+
score = score.clone()
|
1172 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
1173 |
+
score *= dsigma.exp()[:, None]
|
1174 |
+
score[..., self.mask_index] += extra_const
|
1175 |
+
return score
|
1176 |
+
|
1177 |
+
def _analytic_update(self, x, t, step_size, attention_mask=None):
|
1178 |
+
curr_sigma, _ = self.noise(t)
|
1179 |
+
next_sigma, _ = self.noise(t - step_size)
|
1180 |
+
dsigma = curr_sigma - next_sigma
|
1181 |
+
score = self.get_score(x, curr_sigma, attention_mask)
|
1182 |
+
stag_score = self._staggered_score(score, dsigma)
|
1183 |
+
probs = stag_score * self._transp_transition(x, dsigma)
|
1184 |
+
return _sample_categorical(probs)
|
1185 |
+
|
1186 |
+
def _denoiser_update(self, x, t):
|
1187 |
+
sigma, _ = self.noise(t)
|
1188 |
+
score = self.get_score(x, sigma)
|
1189 |
+
stag_score = self._staggered_score(score, sigma)
|
1190 |
+
probs = stag_score * self._transp_transition(x, sigma)
|
1191 |
+
probs[..., self.mask_index] = 0
|
1192 |
+
samples = _sample_categorical(probs)
|
1193 |
+
return samples
|
1194 |
+
|
1195 |
+
def _transp_transition(self, i, sigma):
|
1196 |
+
sigma = _unsqueeze(sigma, reference=i[..., None])
|
1197 |
+
edge = torch.exp(-sigma) * F.one_hot(
|
1198 |
+
i, num_classes=self.vocab_size)
|
1199 |
+
edge += torch.where(i == self.mask_index,
|
1200 |
+
1 - torch.exp(-sigma).squeeze(-1),
|
1201 |
+
0)[..., None]
|
1202 |
+
return edge
|
1203 |
+
|
1204 |
+
def _sample_t(self, n, device):
|
1205 |
+
_eps_t = torch.rand(n, device=device)
|
1206 |
+
if self.antithetic_sampling:
|
1207 |
+
offset = torch.arange(n, device=device) / n
|
1208 |
+
_eps_t = (_eps_t / n + offset) % 1
|
1209 |
+
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
|
1210 |
+
if self.importance_sampling:
|
1211 |
+
return self.noise.importance_sampling_transformation(t)
|
1212 |
+
return t
|
1213 |
+
|
1214 |
+
def _maybe_sub_sample(self, x0, attention_mask):
|
1215 |
+
# seqlen = x0.shape[1]
|
1216 |
+
# if seqlen > self.config.model.length:
|
1217 |
+
# assert seqlen == 2 * self.config.model.length
|
1218 |
+
# # cropping is needed for text8-crop dataset
|
1219 |
+
# # try the same starting point for now
|
1220 |
+
# start = np.random.choice(self.config.model.length)
|
1221 |
+
# end = start + self.config.model.length
|
1222 |
+
# input_tokens = x0[:, start: end]
|
1223 |
+
# output_tokens = x0[:, start + 1: end + 1]
|
1224 |
+
# new_attention_mask = attention_mask[:, start: end]
|
1225 |
+
|
1226 |
+
# # Helps with validation PPL, since the val
|
1227 |
+
# # examples will all start and end with BOS/EOS
|
1228 |
+
# input_tokens[:, 0] = self.tokenizer.bos_token_id
|
1229 |
+
# output_tokens[:, -1] = self.tokenizer.eos_token_id
|
1230 |
+
# elif self.parameterization == 'ar':
|
1231 |
+
# input_tokens = x0[:, :-1]
|
1232 |
+
# output_tokens = x0[:, 1:]
|
1233 |
+
# new_attention_mask = attention_mask[:, 1:]
|
1234 |
+
# else:
|
1235 |
+
input_tokens = x0
|
1236 |
+
output_tokens = None
|
1237 |
+
new_attention_mask = attention_mask
|
1238 |
+
return input_tokens, output_tokens, new_attention_mask
|
1239 |
+
|
1240 |
+
def _reconstruction_loss(self, x0, attention_mask):
|
1241 |
+
t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
|
1242 |
+
device=self.device)
|
1243 |
+
assert self.config.noise.type == 'loglinear'
|
1244 |
+
# The above assert is for d3pm parameterization
|
1245 |
+
unet_conditioning = self.noise(t0)[0][:, None]
|
1246 |
+
model_output_t0 = self.forward(x0, unet_conditioning, attention_mask)
|
1247 |
+
return - torch.gather(input=model_output_t0,
|
1248 |
+
dim=-1,
|
1249 |
+
index=x0[:, :, None]).squeeze(-1)
|
1250 |
+
|
1251 |
+
def _forward_pass_diffusion(self, x0, attention_mask, mask=None):
|
1252 |
+
t = self._sample_t(x0.shape[0], x0.device)
|
1253 |
+
if self.T > 0:
|
1254 |
+
t = (t * self.T).to(torch.int)
|
1255 |
+
t = t / self.T
|
1256 |
+
# t \in {1/T, 2/T, ..., 1}
|
1257 |
+
t += (1 / self.T)
|
1258 |
+
|
1259 |
+
if self.change_of_variables:
|
1260 |
+
unet_conditioning = t[:, None]
|
1261 |
+
f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
|
1262 |
+
f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
|
1263 |
+
move_chance = torch.exp(f_0 + t * (f_T - f_0))
|
1264 |
+
move_chance = move_chance[:, None]
|
1265 |
+
else:
|
1266 |
+
sigma, dsigma = self.noise(t)
|
1267 |
+
unet_conditioning = sigma[:, None]
|
1268 |
+
move_chance = 1 - torch.exp(-sigma[:, None])
|
1269 |
+
|
1270 |
+
if mask is None: xt = self.q_xt(x0, move_chance)
|
1271 |
+
else: xt = x0.where(mask==1, torch.full_like(x0, self.tokenizer.mask_token_id))
|
1272 |
+
model_output = self.forward(xt, unet_conditioning, attention_mask)
|
1273 |
+
# print(self.tokenizer.decode(torch.argmax(model_output[0], dim=-1)))
|
1274 |
+
|
1275 |
+
utils.print_nans(model_output, 'model_output')
|
1276 |
+
|
1277 |
+
if self.parameterization == 'sedd':
|
1278 |
+
return dsigma[:, None] * self._score_entropy(
|
1279 |
+
model_output, sigma[:, None], xt, x0)
|
1280 |
+
|
1281 |
+
if self.T > 0:
|
1282 |
+
diffusion_loss = self._d3pm_loss(
|
1283 |
+
model_output=model_output, xt=xt, x0=x0, t=t)
|
1284 |
+
if self.parameterization == 'd3pm':
|
1285 |
+
reconstruction_loss = self._reconstruction_loss(x0, attention_mask)
|
1286 |
+
elif self.parameterization == 'subs':
|
1287 |
+
reconstruction_loss = 0
|
1288 |
+
return reconstruction_loss + diffusion_loss
|
1289 |
+
|
1290 |
+
# SUBS parameterization, continuous time.
|
1291 |
+
log_p_theta = torch.gather(
|
1292 |
+
input=model_output,
|
1293 |
+
dim=-1,
|
1294 |
+
index=x0[:, :, None]).squeeze(-1)
|
1295 |
+
|
1296 |
+
if self.change_of_variables or self.importance_sampling:
|
1297 |
+
return log_p_theta * torch.log1p(
|
1298 |
+
- torch.exp(- self.noise.sigma_min))
|
1299 |
+
|
1300 |
+
return - log_p_theta * (
|
1301 |
+
dsigma / torch.expm1(sigma))[:, None]
|
1302 |
+
|
1303 |
+
def _loss(self, x0, attention_mask, mask=None):
|
1304 |
+
(input_tokens, output_tokens,
|
1305 |
+
attention_mask) = self._maybe_sub_sample(
|
1306 |
+
x0, attention_mask)
|
1307 |
+
|
1308 |
+
if self.parameterization == 'ar':
|
1309 |
+
logprobs = self.backbone(input_tokens, None, attention_mask)
|
1310 |
+
loss = - logprobs.gather(
|
1311 |
+
-1, output_tokens[:, :, None])[:, :, 0]
|
1312 |
+
else:
|
1313 |
+
loss = self._forward_pass_diffusion(input_tokens, attention_mask, mask)
|
1314 |
+
|
1315 |
+
nlls = loss * attention_mask
|
1316 |
+
count = attention_mask.sum()
|
1317 |
+
|
1318 |
+
batch_nll = nlls.sum()
|
1319 |
+
token_nll = batch_nll / count
|
1320 |
+
|
1321 |
+
return Loss(loss=token_nll,
|
1322 |
+
nlls=nlls,
|
1323 |
+
token_mask=attention_mask)
|
1324 |
+
|
1325 |
+
def _score_entropy(self, log_score, sigma, xt, x0):
|
1326 |
+
"""Computes the SEDD loss.
|
1327 |
+
|
1328 |
+
Args:
|
1329 |
+
log_score: float torch.Tensor with shape (batch_size,
|
1330 |
+
diffusion_model_input_length, vocab_size),
|
1331 |
+
log score, output of the denoising network.
|
1332 |
+
xt: int torch.Tensor with shape (batch_size,
|
1333 |
+
diffusion_model_input_length), input.
|
1334 |
+
x0: int torch.Tensor with shape (batch_size,
|
1335 |
+
diffusion_model_input_length), input.
|
1336 |
+
sigma: float torch.Tensor with shape (batch_size, 1).
|
1337 |
+
|
1338 |
+
Returns:
|
1339 |
+
loss with shape (batch_size, diffusion_model_input_length)
|
1340 |
+
"""
|
1341 |
+
masked_indices = xt == self.mask_index
|
1342 |
+
|
1343 |
+
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
|
1344 |
+
q_ratio = 1 / expsig_minus_1[masked_indices]
|
1345 |
+
|
1346 |
+
words_that_were_masked = x0[masked_indices]
|
1347 |
+
|
1348 |
+
neg_term = q_ratio * torch.gather(
|
1349 |
+
log_score[masked_indices],
|
1350 |
+
-1,
|
1351 |
+
words_that_were_masked[..., None]).squeeze(-1)
|
1352 |
+
score = log_score[masked_indices].exp()
|
1353 |
+
if self.mask_index == self.vocab_size - 1:
|
1354 |
+
pos_term = score[:, :-1].sum(dim=-1)
|
1355 |
+
else:
|
1356 |
+
pos_term = score[:, : self.mask_index].sum(
|
1357 |
+
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
|
1358 |
+
const = q_ratio * (q_ratio.log() - 1)
|
1359 |
+
|
1360 |
+
entropy = torch.zeros(* xt.shape, device=xt.device)
|
1361 |
+
entropy[masked_indices] += pos_term - neg_term + const
|
1362 |
+
return entropy
|
1363 |
+
|
1364 |
+
@torch.no_grad
|
1365 |
+
def sample_subs_guidance(
|
1366 |
+
self, n_samples, stride_length, num_strides, dt=0.001):
|
1367 |
+
ones = torch.ones(n_samples, dtype=self.dtype,
|
1368 |
+
device=self.device)
|
1369 |
+
|
1370 |
+
num_steps = int(1 / dt)
|
1371 |
+
sampling_steps = 0
|
1372 |
+
intermediate_tokens = []
|
1373 |
+
target = None
|
1374 |
+
for _ in range(num_strides + 1):
|
1375 |
+
p_x0_cache = None
|
1376 |
+
x = self._sample_prior(
|
1377 |
+
n_samples,
|
1378 |
+
self.config.model.length).to(self.device)
|
1379 |
+
if target is not None:
|
1380 |
+
x[:, : -stride_length] = target
|
1381 |
+
for i in range(num_steps + 1):
|
1382 |
+
p_x0_cache, x_next = self._ddpm_caching_update(
|
1383 |
+
x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
|
1384 |
+
if (not torch.allclose(x_next, x)
|
1385 |
+
or self.time_conditioning):
|
1386 |
+
p_x0_cache = None
|
1387 |
+
sampling_steps += 1
|
1388 |
+
x = x_next
|
1389 |
+
x = self.forward(x, 0 * ones).argmax(dim=-1)
|
1390 |
+
intermediate_tokens.append(
|
1391 |
+
x[:, :stride_length].cpu().numpy())
|
1392 |
+
target = x[:, stride_length:]
|
1393 |
+
|
1394 |
+
intermediate_tokens.append(target.cpu().numpy())
|
1395 |
+
intermediate_text_samples = []
|
1396 |
+
sequence_lengths = ((
|
1397 |
+
np.concatenate(intermediate_tokens, axis=1)[:, 1:]
|
1398 |
+
== self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
|
1399 |
+
for i in range(2, len(intermediate_tokens) + 1):
|
1400 |
+
intermediate_text_samples.append(
|
1401 |
+
self.tokenizer.batch_decode(
|
1402 |
+
np.concatenate(intermediate_tokens[:i], axis=1)))
|
1403 |
+
return (sampling_steps, intermediate_text_samples,
|
1404 |
+
sequence_lengths)
|
1405 |
+
|
1406 |
+
def restore_model_and_semi_ar_sample(
|
1407 |
+
self, stride_length, num_strides, dt=0.001):
|
1408 |
+
"""Generate samples from the model."""
|
1409 |
+
# Lightning auto-casting is not working in this method for some reason
|
1410 |
+
|
1411 |
+
# params_with_grad = [p for p in itertools.chain(
|
1412 |
+
# self.backbone.parameters(),
|
1413 |
+
# self.noise.parameters()
|
1414 |
+
# ) if p]
|
1415 |
+
|
1416 |
+
if self.ema:
|
1417 |
+
self.ema.store(itertools.chain(self.backbone.parameters(),
|
1418 |
+
self.noise.parameters()))
|
1419 |
+
self.ema.copy_to(itertools.chain(self.backbone.parameters(),
|
1420 |
+
self.noise.parameters()))
|
1421 |
+
self.backbone.eval()
|
1422 |
+
self.noise.eval()
|
1423 |
+
(sampling_steps, samples,
|
1424 |
+
sequence_lengths) = self.sample_subs_guidance(
|
1425 |
+
n_samples=self.config.loader.eval_batch_size,
|
1426 |
+
stride_length=stride_length,
|
1427 |
+
num_strides=num_strides,
|
1428 |
+
dt=dt)
|
1429 |
+
if self.ema:
|
1430 |
+
self.ema.restore(itertools.chain(self.backbone.parameters(),
|
1431 |
+
self.noise.parameters()))
|
1432 |
+
self.backbone.train()
|
1433 |
+
self.noise.train()
|
1434 |
+
return sampling_steps, samples, sequence_lengths
|
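For reference, the reverse step used throughout this file (`_ddpm_update` and the cached variant) forms an unnormalised posterior in which every ordinary token receives mass proportional to the denoiser probability times (move_chance_t - move_chance_s), staying masked receives move_chance_s, and positions that are already unmasked are copied through. A minimal self-contained sketch of that step on a toy vocabulary (toy tensors and a hypothetical mask index, not the class above):

```python
import torch

mask_index = 3                                   # toy vocab of 4, index 3 = <mask>
x = torch.tensor([[0, 3, 2]])                    # current tokens; middle one is masked
log_p_x0 = torch.log_softmax(torch.randn(1, 3, 4), dim=-1)  # stand-in denoiser output

move_chance_t = torch.full((1, 1, 1), 0.6)       # masking probability at time t
move_chance_s = torch.full((1, 1, 1), 0.3)       # masking probability at time s < t

# Unnormalised reverse posterior (the missing normaliser cancels when sampling).
q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
q_xs[:, :, mask_index] = move_chance_s[:, :, 0]
_x = torch.distributions.Categorical(probs=q_xs).sample()

# Positions that are already unmasked are carried through unchanged.
copy_flag = (x != mask_index).to(x.dtype)
x_next = copy_flag * x + (1 - copy_flag) * _x
print(x_next)
```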
dit.py
ADDED
@@ -0,0 +1,388 @@
1 |
+
import math
|
2 |
+
import typing
|
3 |
+
|
4 |
+
import flash_attn
|
5 |
+
import flash_attn.layers.rotary
|
6 |
+
import huggingface_hub
|
7 |
+
import omegaconf
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
from transformers import AutoModel
|
14 |
+
|
15 |
+
# Flags required to enable jit fusion kernels
|
16 |
+
torch._C._jit_set_profiling_mode(False)
|
17 |
+
torch._C._jit_set_profiling_executor(False)
|
18 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
19 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
20 |
+
|
21 |
+
|
22 |
+
def bias_dropout_add_scale(
|
23 |
+
x: torch.Tensor,
|
24 |
+
bias: typing.Optional[torch.Tensor],
|
25 |
+
scale: torch.Tensor,
|
26 |
+
residual: typing.Optional[torch.Tensor],
|
27 |
+
prob: float,
|
28 |
+
training: bool) -> torch.Tensor:
|
29 |
+
if bias is not None:
|
30 |
+
out = scale * F.dropout(x + bias, p=prob, training=training)
|
31 |
+
else:
|
32 |
+
out = scale * F.dropout(x, p=prob, training=training)
|
33 |
+
|
34 |
+
if residual is not None:
|
35 |
+
out = residual + out
|
36 |
+
return out
|
37 |
+
|
38 |
+
|
39 |
+
def get_bias_dropout_add_scale(training):
|
40 |
+
def _bias_dropout_add(x, bias, scale, residual, prob):
|
41 |
+
return bias_dropout_add_scale(
|
42 |
+
x, bias, scale, residual, prob, training)
|
43 |
+
|
44 |
+
return _bias_dropout_add
|
45 |
+
|
46 |
+
|
47 |
+
# function overload
|
48 |
+
def modulate(x: torch.Tensor,
|
49 |
+
shift: torch.Tensor,
|
50 |
+
scale: torch.Tensor) -> torch.Tensor:
|
51 |
+
return x * (1 + scale) + shift
|
52 |
+
|
53 |
+
|
54 |
+
@torch.jit.script
|
55 |
+
def bias_dropout_add_scale_fused_train(
|
56 |
+
x: torch.Tensor,
|
57 |
+
bias: typing.Optional[torch.Tensor],
|
58 |
+
scale: torch.Tensor,
|
59 |
+
residual: typing.Optional[torch.Tensor],
|
60 |
+
prob: float) -> torch.Tensor:
|
61 |
+
return bias_dropout_add_scale(
|
62 |
+
x, bias, scale, residual, prob, True)
|
63 |
+
|
64 |
+
|
65 |
+
@torch.jit.script
|
66 |
+
def bias_dropout_add_scale_fused_inference(
|
67 |
+
x: torch.Tensor,
|
68 |
+
bias: typing.Optional[torch.Tensor],
|
69 |
+
scale: torch.Tensor,
|
70 |
+
residual: typing.Optional[torch.Tensor],
|
71 |
+
prob: float) -> torch.Tensor:
|
72 |
+
return bias_dropout_add_scale(
|
73 |
+
x, bias, scale, residual, prob, False)
|
74 |
+
|
75 |
+
|
76 |
+
@torch.jit.script
|
77 |
+
def modulate_fused(x: torch.Tensor,
|
78 |
+
shift: torch.Tensor,
|
79 |
+
scale: torch.Tensor) -> torch.Tensor:
|
80 |
+
return modulate(x, shift, scale)
|
81 |
+
|
82 |
+
|
83 |
+
class Rotary(torch.nn.Module):
|
84 |
+
def __init__(self, dim, base=10_000):
|
85 |
+
super().__init__()
|
86 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
87 |
+
self.register_buffer('inv_freq', inv_freq)
|
88 |
+
self.seq_len_cached = None
|
89 |
+
self.cos_cached = None
|
90 |
+
self.sin_cached = None
|
91 |
+
|
92 |
+
def forward(self, x, seq_dim=1):
|
93 |
+
seq_len = x.shape[seq_dim]
|
94 |
+
if seq_len != self.seq_len_cached:
|
95 |
+
self.seq_len_cached = seq_len
|
96 |
+
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
|
97 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
|
98 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
99 |
+
# dims are: batch, seq_len, qkv, head, dim
|
100 |
+
self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
|
101 |
+
self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
|
102 |
+
# This makes the transformation on v an identity.
|
103 |
+
self.cos_cached[:,:,2,:,:].fill_(1.)
|
104 |
+
self.sin_cached[:,:,2,:,:].fill_(0.)
|
105 |
+
|
106 |
+
return self.cos_cached, self.sin_cached
|
107 |
+
|
108 |
+
|
109 |
+
def rotate_half(x):
|
110 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
111 |
+
return torch.cat((-x2, x1), dim=-1)
|
112 |
+
|
113 |
+
|
114 |
+
def apply_rotary_pos_emb(qkv, cos, sin):
|
115 |
+
cos = cos[0,:,0,0,:cos.shape[-1]//2]
|
116 |
+
sin = sin[0,:,0,0,:sin.shape[-1]//2]
|
117 |
+
return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
|
118 |
+
|
119 |
+
|
120 |
+
# function overload
|
121 |
+
def modulate(x, shift, scale):
|
122 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
123 |
+
|
124 |
+
|
125 |
+
#################################################################################
|
126 |
+
# Layers #
|
127 |
+
#################################################################################
|
128 |
+
class LayerNorm(nn.Module):
|
129 |
+
def __init__(self, dim):
|
130 |
+
super().__init__()
|
131 |
+
self.weight = nn.Parameter(torch.ones([dim]))
|
132 |
+
self.dim = dim
|
133 |
+
def forward(self, x):
|
134 |
+
with torch.cuda.amp.autocast(enabled=False):
|
135 |
+
x = F.layer_norm(x.float(), [self.dim])
|
136 |
+
return x * self.weight[None,None,:]
|
137 |
+
|
138 |
+
|
139 |
+
def residual_linear(x, W, x_skip, residual_scale):
|
140 |
+
"""x_skip + residual_scale * W @ x"""
|
141 |
+
dim_out, dim_in = W.shape[0], W.shape[1]
|
142 |
+
return torch.addmm(
|
143 |
+
x_skip.view(-1, dim_out),
|
144 |
+
x.view(-1, dim_in),
|
145 |
+
W.T,
|
146 |
+
alpha=residual_scale).view(*x.shape[:-1], dim_out)
|
147 |
+
|
148 |
+
|
149 |
+
#################################################################################
|
150 |
+
# Embedding Layers for Timesteps and Class Labels #
|
151 |
+
#################################################################################
|
152 |
+
class TimestepEmbedder(nn.Module):
|
153 |
+
"""
|
154 |
+
Embeds scalar timesteps into vector representations.
|
155 |
+
"""
|
156 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
157 |
+
super().__init__()
|
158 |
+
self.mlp = nn.Sequential(
|
159 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
160 |
+
nn.SiLU(),
|
161 |
+
nn.Linear(hidden_size, hidden_size, bias=True))
|
162 |
+
self.frequency_embedding_size = frequency_embedding_size
|
163 |
+
|
164 |
+
@staticmethod
|
165 |
+
def timestep_embedding(t, dim, max_period=10000):
|
166 |
+
"""
|
167 |
+
Create sinusoidal timestep embeddings.
|
168 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
169 |
+
These may be fractional.
|
170 |
+
:param dim: the dimension of the output.
|
171 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
172 |
+
:return: an (N, D) Tensor of positional embeddings.
|
173 |
+
"""
|
174 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
175 |
+
half = dim // 2
|
176 |
+
freqs = torch.exp(- math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
177 |
+
|
178 |
+
if t.ndim == 1:
|
179 |
+
t = t.unsqueeze(1)
|
180 |
+
|
181 |
+
args = t.float() * freqs[None, :]
|
182 |
+
#args = t[:, None].float() * freqs[None]
|
183 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
184 |
+
if dim % 2:
|
185 |
+
embedding = torch.cat(
|
186 |
+
[embedding,
|
187 |
+
torch.zeros_like(embedding[:, :1])], dim=-1)
|
188 |
+
return embedding
|
189 |
+
|
190 |
+
def forward(self, t):
|
191 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
192 |
+
t_emb = self.mlp(t_freq)
|
193 |
+
return t_emb
|
194 |
+
|
195 |
+
|
196 |
+
class LabelEmbedder(nn.Module):
|
197 |
+
"""Embeds class labels into vector representations.
|
198 |
+
|
199 |
+
Also handles label dropout for classifier-free guidance.
|
200 |
+
"""
|
201 |
+
def __init__(self, num_classes, cond_size):
|
202 |
+
super().__init__()
|
203 |
+
self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
|
204 |
+
self.num_classes = num_classes
|
205 |
+
|
206 |
+
# TODO think of initializing with 0.02 std deviation like in original DiT paper
|
207 |
+
|
208 |
+
def forward(self, labels):
|
209 |
+
embeddings = self.embedding_table(labels)
|
210 |
+
return embeddings
|
211 |
+
|
212 |
+
|
213 |
+
#################################################################################
|
214 |
+
# Core Model #
|
215 |
+
#################################################################################
|
216 |
+
|
217 |
+
|
218 |
+
class DDiTBlock(nn.Module):
|
219 |
+
def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
|
220 |
+
super().__init__()
|
221 |
+
self.n_heads = n_heads
|
222 |
+
|
223 |
+
self.norm1 = LayerNorm(dim)
|
224 |
+
self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
|
225 |
+
self.attn_out = nn.Linear(dim, dim, bias=False)
|
226 |
+
self.dropout1 = nn.Dropout(dropout)
|
227 |
+
|
228 |
+
self.norm2 = LayerNorm(dim)
|
229 |
+
self.mlp = nn.Sequential(
|
230 |
+
nn.Linear(dim, mlp_ratio * dim, bias=True),
|
231 |
+
nn.GELU(approximate='tanh'),
|
232 |
+
nn.Linear(mlp_ratio * dim, dim, bias=True))
|
233 |
+
self.dropout2 = nn.Dropout(dropout)
|
234 |
+
self.dropout = dropout
|
235 |
+
|
236 |
+
self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
|
237 |
+
self.adaLN_modulation.weight.data.zero_()
|
238 |
+
self.adaLN_modulation.bias.data.zero_()
|
239 |
+
|
240 |
+
|
241 |
+
def _get_bias_dropout_scale(self):
|
242 |
+
if self.training:
|
243 |
+
return bias_dropout_add_scale_fused_train
|
244 |
+
else:
|
245 |
+
return bias_dropout_add_scale_fused_inference
|
246 |
+
|
247 |
+
|
248 |
+
def forward(self, x, rotary_cos_sin, c, seqlens=None):
|
249 |
+
batch_size, seq_len = x.shape[0], x.shape[1]
|
250 |
+
|
251 |
+
bias_dropout_scale_fn = self._get_bias_dropout_scale()
|
252 |
+
|
253 |
+
(shift_msa, scale_msa, gate_msa, shift_mlp,
|
254 |
+
scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
|
255 |
+
|
256 |
+
# attention operation
|
257 |
+
x_skip = x
|
258 |
+
x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
|
259 |
+
|
260 |
+
qkv = self.attn_qkv(x)
|
261 |
+
qkv = rearrange(qkv,
|
262 |
+
'b s (three h d) -> b s three h d',
|
263 |
+
three=3,
|
264 |
+
h=self.n_heads)
|
265 |
+
with torch.cuda.amp.autocast(enabled=False):
|
266 |
+
cos, sin = rotary_cos_sin
|
267 |
+
qkv = apply_rotary_pos_emb(
|
268 |
+
qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
|
269 |
+
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
270 |
+
if seqlens is None:
|
271 |
+
cu_seqlens = torch.arange(
|
272 |
+
0, (batch_size + 1) * seq_len, step=seq_len,
|
273 |
+
dtype=torch.int32, device=qkv.device)
|
274 |
+
else:
|
275 |
+
cu_seqlens = seqlens.cumsum(-1)
|
276 |
+
x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
|
277 |
+
qkv, cu_seqlens, seq_len, 0., causal=False)
|
278 |
+
|
279 |
+
x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)
|
280 |
+
|
281 |
+
x = bias_dropout_scale_fn(self.attn_out(x),
|
282 |
+
None,
|
283 |
+
gate_msa,
|
284 |
+
x_skip,
|
285 |
+
self.dropout)
|
286 |
+
|
287 |
+
# mlp operation
|
288 |
+
x = bias_dropout_scale_fn(
|
289 |
+
self.mlp(modulate_fused(
|
290 |
+
self.norm2(x), shift_mlp, scale_mlp)),
|
291 |
+
None, gate_mlp, x, self.dropout)
|
292 |
+
return x
|
293 |
+
|
294 |
+
|
295 |
+
|
296 |
+
class EmbeddingLayer(nn.Module):
|
297 |
+
def __init__(self, dim, vocab_dim):
|
298 |
+
super().__init__()
|
299 |
+
self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
|
300 |
+
torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
|
301 |
+
|
302 |
+
def forward(self, x):
|
303 |
+
return self.embedding[x]
|
304 |
+
|
305 |
+
|
306 |
+
class DDitFinalLayer(nn.Module):
|
307 |
+
def __init__(self, hidden_size, out_channels, cond_dim):
|
308 |
+
super().__init__()
|
309 |
+
self.norm_final = LayerNorm(hidden_size)
|
310 |
+
self.linear = nn.Linear(hidden_size, out_channels)
|
311 |
+
self.linear.weight.data.zero_()
|
312 |
+
self.linear.bias.data.zero_()
|
313 |
+
|
314 |
+
self.adaLN_modulation = nn.Linear(cond_dim,
|
315 |
+
2 * hidden_size,
|
316 |
+
bias=True)
|
317 |
+
self.adaLN_modulation.weight.data.zero_()
|
318 |
+
self.adaLN_modulation.bias.data.zero_()
|
319 |
+
|
320 |
+
|
321 |
+
def forward(self, x, c):
|
322 |
+
shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
|
323 |
+
x = modulate_fused(self.norm_final(x), shift, scale)
|
324 |
+
x = self.linear(x)
|
325 |
+
return x
|
326 |
+
|
327 |
+
|
328 |
+
class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
|
329 |
+
def __init__(self, config, vocab_size: int, mlm_model_path):
|
330 |
+
super().__init__()
|
331 |
+
if type(config) == dict:
|
332 |
+
config = omegaconf.OmegaConf.create(config)
|
333 |
+
|
334 |
+
self.config = config
|
335 |
+
self.vocab_size = vocab_size
|
336 |
+
|
337 |
+
self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
|
338 |
+
vocab_size)
|
339 |
+
self.sigma_map = TimestepEmbedder(config.model.cond_dim)
|
340 |
+
self.rotary_emb = Rotary(
|
341 |
+
config.model.hidden_size // config.model.n_heads)
|
342 |
+
|
343 |
+
blocks = []
|
344 |
+
for _ in range(config.model.n_blocks):
|
345 |
+
blocks.append(DDiTBlock(config.model.hidden_size,
|
346 |
+
config.model.n_heads,
|
347 |
+
config.model.cond_dim,
|
348 |
+
dropout=config.model.dropout))
|
349 |
+
self.blocks = nn.ModuleList(blocks)
|
350 |
+
|
351 |
+
self.output_layer = DDitFinalLayer(
|
352 |
+
config.model.hidden_size,
|
353 |
+
vocab_size,
|
354 |
+
config.model.cond_dim)
|
355 |
+
self.scale_by_sigma = config.model.scale_by_sigma
|
356 |
+
|
357 |
+
self.mlm_model = AutoModel.from_pretrained(mlm_model_path, device_map='cpu')
|
358 |
+
|
359 |
+
def _get_bias_dropout_scale(self):
|
360 |
+
if self.training:
|
361 |
+
return bias_dropout_add_scale_fused_train
|
362 |
+
else:
|
363 |
+
return bias_dropout_add_scale_fused_inference
|
364 |
+
|
365 |
+
def forward(self, indices, sigma):
|
366 |
+
x = self.vocab_embed(indices)
|
367 |
+
c_sigma = F.silu(self.sigma_map(sigma))
|
368 |
+
|
369 |
+
rotary_cos_sin = self.rotary_emb(x)
|
370 |
+
|
371 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
372 |
+
for i in range(len(self.blocks)):
|
373 |
+
x = self.blocks[i](x, rotary_cos_sin, c_sigma, seqlens=None)
|
374 |
+
x = self.output_layer(x, c_sigma)
|
375 |
+
|
376 |
+
# Extract membrane-specific embeddings from final encoder layer
|
377 |
+
# of fine-tuned ESM model
|
378 |
+
# with torch.no_grad():
|
379 |
+
# membrane_embedding = self.mlm_model(input_ids=, attention_mask=).last_hidden_state.squeeze(0)
|
380 |
+
|
381 |
+
# Fuse MLM embeddings with conditioning vector
|
382 |
+
# c = torch.cat([c_sigma, membrane_embedding], dim=-1)
|
383 |
+
|
384 |
+
# print(membrane_embedding.size())
|
385 |
+
# print(c_sigma.size())
|
386 |
+
|
387 |
+
return x
|
388 |
+
|
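The DDiT blocks condition on the noise level through adaLN-style modulation, `modulate(x, shift, scale) = x * (1 + scale) + shift`, with the modulation projections zero-initialised so each block starts out as a plain residual block. A small sketch illustrating that zero-init property (standalone toy shapes, not the classes above):

```python
import torch
import torch.nn as nn

hidden, cond = 8, 16
x = torch.randn(2, 5, hidden)          # (batch, seq_len, hidden)
c = torch.randn(2, cond)               # conditioning vector, e.g. the sigma embedding

adaLN = nn.Linear(cond, 2 * hidden)
adaLN.weight.data.zero_()              # same zero-init as in DDiTBlock / DDitFinalLayer
adaLN.bias.data.zero_()

shift, scale = adaLN(c)[:, None].chunk(2, dim=2)   # each (batch, 1, hidden)
out = x * (1 + scale) + shift                      # modulate()

print(torch.allclose(out, x))          # True: at init the modulation is an identity
```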
ema.py
ADDED
@@ -0,0 +1,97 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class ExponentialMovingAverage:
|
5 |
+
"""
|
6 |
+
Maintains (exponential) moving average of a set of parameters.
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(self, parameters, decay, use_num_updates=True):
|
10 |
+
"""
|
11 |
+
Args:
|
12 |
+
parameters: Iterable of `torch.nn.Parameter`; usually the result of
|
13 |
+
`model.parameters()`.
|
14 |
+
decay: The exponential decay.
|
15 |
+
use_num_updates: Whether to use number of updates when computing
|
16 |
+
averages.
|
17 |
+
"""
|
18 |
+
if decay < 0.0 or decay > 1.0:
|
19 |
+
raise ValueError('Decay must be between 0 and 1')
|
20 |
+
self.decay = decay
|
21 |
+
self.num_updates = 0 if use_num_updates else None
|
22 |
+
self.shadow_params = [p.clone().detach()
|
23 |
+
for p in parameters if p.requires_grad]
|
24 |
+
self.collected_params = []
|
25 |
+
|
26 |
+
def move_shadow_params_to_device(self, device):
|
27 |
+
self.shadow_params = [i.to(device) for i in self.shadow_params]
|
28 |
+
|
29 |
+
def update(self, parameters):
|
30 |
+
"""
|
31 |
+
Update currently maintained parameters.
|
32 |
+
|
33 |
+
Call this every time the parameters are updated, such as the result of
|
34 |
+
the `optimizer.step()` call.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
|
38 |
+
parameters used to initialize this object.
|
39 |
+
"""
|
40 |
+
decay = self.decay
|
41 |
+
if self.num_updates is not None:
|
42 |
+
self.num_updates += 1
|
43 |
+
decay = min(decay, (1 + self.num_updates) /
|
44 |
+
(10 + self.num_updates))
|
45 |
+
one_minus_decay = 1.0 - decay
|
46 |
+
with torch.no_grad():
|
47 |
+
parameters = [p for p in parameters if p.requires_grad]
|
48 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
49 |
+
s_param.sub_(one_minus_decay * (s_param - param))
|
50 |
+
|
51 |
+
def copy_to(self, parameters):
|
52 |
+
"""
|
53 |
+
Copy current parameters into given collection of parameters.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
57 |
+
updated with the stored moving averages.
|
58 |
+
"""
|
59 |
+
parameters = [p for p in parameters if p.requires_grad]
|
60 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
61 |
+
if param.requires_grad:
|
62 |
+
param.data.copy_(s_param.data)
|
63 |
+
|
64 |
+
def store(self, parameters):
|
65 |
+
"""
|
66 |
+
Save the current parameters for restoring later.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
70 |
+
temporarily stored.
|
71 |
+
"""
|
72 |
+
self.collected_params = [param.clone() for param in parameters]
|
73 |
+
|
74 |
+
def restore(self, parameters):
|
75 |
+
"""
|
76 |
+
Restore the parameters stored with the `store` method.
|
77 |
+
Useful to validate the model with EMA parameters without affecting the
|
78 |
+
original optimization process. Store the parameters before the
|
79 |
+
`copy_to` method. After validation (or model saving), use this to
|
80 |
+
restore the former parameters.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
84 |
+
updated with the stored parameters.
|
85 |
+
"""
|
86 |
+
for c_param, param in zip(self.collected_params, parameters):
|
87 |
+
param.data.copy_(c_param.data)
|
88 |
+
|
89 |
+
def state_dict(self):
|
90 |
+
return dict(decay=self.decay,
|
91 |
+
num_updates=self.num_updates,
|
92 |
+
shadow_params=self.shadow_params)
|
93 |
+
|
94 |
+
def load_state_dict(self, state_dict):
|
95 |
+
self.decay = state_dict['decay']
|
96 |
+
self.num_updates = state_dict['num_updates']
|
97 |
+
self.shadow_params = state_dict['shadow_params']
|
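A minimal sketch of how this EMA helper is typically wired into a training and evaluation loop (hypothetical toy model and optimizer; it mirrors the store/copy_to/restore pattern used in diffusion.py):

```python
import torch
from ema import ExponentialMovingAverage

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)

# Training step: refresh the shadow weights after every optimizer step.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
ema.update(model.parameters())

# Evaluation: temporarily swap in the EMA weights, then restore the raw ones.
ema.store(model.parameters())
ema.copy_to(model.parameters())
# ... run validation / sampling with the averaged weights ...
ema.restore(model.parameters())
```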
esm_utils.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
import torch
|
2 |
+
import config
|
3 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
|
4 |
+
|
5 |
+
def load_esm2_model(model_name):
|
6 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
7 |
+
masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
|
8 |
+
embedding_model = AutoModel.from_pretrained(model_name)
|
9 |
+
return tokenizer, masked_model, embedding_model
|
10 |
+
|
11 |
+
def get_latents(model, tokenizer, sequence, device):
|
12 |
+
inputs = tokenizer(sequence, return_tensors="pt").to(device)
|
13 |
+
with torch.no_grad():
|
14 |
+
outputs = model(**inputs).last_hidden_state.squeeze(0)
|
15 |
+
return outputs
|
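A short usage sketch for these helpers; the checkpoint name is the ESM-2 model referenced in generate.py, and the sequence is a made-up example:

```python
import torch
from esm_utils import load_esm2_model, get_latents

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer, masked_model, embedding_model = load_esm2_model("facebook/esm2_t30_150M_UR50D")
embedding_model = embedding_model.to(device).eval()

# Per-residue embeddings, shape (sequence length + special tokens, hidden_size).
latents = get_latents(embedding_model, tokenizer, "MKTAYIAKQR", device)
print(latents.shape)
```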
generate.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
import pandas as pd
|
7 |
+
from mlm_generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
|
8 |
+
from diffusion import Diffusion
|
9 |
+
import hydra
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
12 |
+
|
13 |
+
|
14 |
+
@torch.no_grad()
|
15 |
+
def generate_sequence(sequence_length: int, tokenizer, mdlm: Diffusion):
|
16 |
+
global masked_sequence
|
17 |
+
masked_sequence = mask_for_de_novo(sequence_length)
|
18 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt").to(mdlm.device)
|
19 |
+
logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
|
20 |
+
generated_sequence = tokenizer.decode(logits.squeeze())
|
21 |
+
|
22 |
+
return generated_sequence
|
23 |
+
|
24 |
+
|
25 |
+
@hydra.main(version_base=None, config_path='configs', config_name='config')
|
26 |
+
def mdlm_motif_benchmark(config):
|
27 |
+
path = "/workspace/sg666/MDpLM"
|
28 |
+
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
|
30 |
+
mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
|
31 |
+
|
32 |
+
mdlm_model.eval()
|
33 |
+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
34 |
+
mdlm_model.to(device)
|
35 |
+
|
36 |
+
print("loaded models...")
|
37 |
+
|
38 |
+
# Get 100 random sequence lengths to generate
|
39 |
+
sequence_lengths = [random.randint(50, 1000) for _ in range(100)]
|
40 |
+
|
41 |
+
generation_results = []
|
42 |
+
for seq_length in tqdm(sequence_lengths, desc=f"Generating sequences: "):
|
43 |
+
generated_sequence = generate_sequence(seq_length, tokenizer, mdlm_model)
|
44 |
+
generated_sequence = generated_sequence[5:-5].replace(" ", "") # Remove bos/eos tokens
|
45 |
+
|
46 |
+
perplexity = mdlm_model.compute_masked_perplexity([generated_sequence], masked_sequence)
|
47 |
+
perplexity = round(perplexity, 4)
|
48 |
+
|
49 |
+
generation_results.append([generated_sequence, perplexity])
|
50 |
+
|
51 |
+
print(f"perplexity: {perplexity} | length: {seq_length} | generated sequence: {generated_sequence}")
|
52 |
+
sys.stdout.flush()
|
53 |
+
|
54 |
+
df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'Perplexity'])
|
55 |
+
df.to_csv(path + f'/benchmarks/mdlm_de-novo_generation_results.csv', index=False)
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
mdlm_motif_benchmark()
|
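The `[5:-5].replace(" ", "")` cleanup above relies on the ESM tokenizer decoding to space-separated residues wrapped in `<cls>` and `<eos>`, which are both five characters long. A tiny string-level illustration of that post-processing (the decoded string is a made-up example):

```python
# Roughly what tokenizer.decode() returns for a short ESM-2 sequence:
decoded = "<cls> M K T A Y <eos>"

# Drop the 5-character <cls>/<eos> wrappers, then remove the separating spaces.
cleaned = decoded[5:-5].replace(" ", "")
print(cleaned)  # "MKTAY"
```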
main.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import wandb
|
4 |
+
import fsspec
|
5 |
+
import hydra
|
6 |
+
import lightning as L
|
7 |
+
import omegaconf
|
8 |
+
import rich.syntax
|
9 |
+
import rich.tree
|
10 |
+
import torch
|
11 |
+
|
12 |
+
import pl_data_loader as dataloader
|
13 |
+
from diffusion import Diffusion
|
14 |
+
import utils
|
15 |
+
|
16 |
+
from lightning.pytorch.strategies import DDPStrategy
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
from datasets import load_from_disk, load_dataset
|
19 |
+
|
20 |
+
#wandb.login(key="2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f")
|
21 |
+
omegaconf.OmegaConf.register_new_resolver(
|
22 |
+
'cwd', os.getcwd)
|
23 |
+
omegaconf.OmegaConf.register_new_resolver(
|
24 |
+
'device_count', torch.cuda.device_count)
|
25 |
+
omegaconf.OmegaConf.register_new_resolver(
|
26 |
+
'eval', eval)
|
27 |
+
omegaconf.OmegaConf.register_new_resolver(
|
28 |
+
'div_up', lambda x, y: (x + y - 1) // y)
|
29 |
+
|
30 |
+
|
31 |
+
def _load_from_checkpoint(config, tokenizer):
|
32 |
+
if 'hf' in config.backbone:
|
33 |
+
return Diffusion(
|
34 |
+
config, tokenizer=tokenizer).to('cuda')
|
35 |
+
else:
|
36 |
+
model= Diffusion.load_from_checkpoint(
|
37 |
+
config.eval.checkpoint_path,
|
38 |
+
tokenizer=tokenizer,
|
39 |
+
config=config)
|
40 |
+
|
41 |
+
return model
|
42 |
+
|
43 |
+
@L.pytorch.utilities.rank_zero_only
|
44 |
+
def _print_config(
|
45 |
+
config: omegaconf.DictConfig,
|
46 |
+
resolve: bool = True,
|
47 |
+
save_cfg: bool = True) -> None:
|
48 |
+
"""Prints content of DictConfig using Rich library and its tree structure.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
config (DictConfig): Configuration composed by Hydra.
|
52 |
+
resolve (bool): Whether to resolve reference fields of DictConfig.
|
53 |
+
save_cfg (bool): Whether to save the configuration tree to a file.
|
54 |
+
"""
|
55 |
+
|
56 |
+
style = 'dim'
|
57 |
+
tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
|
58 |
+
|
59 |
+
fields = config.keys()
|
60 |
+
for field in fields:
|
61 |
+
branch = tree.add(field, style=style, guide_style=style)
|
62 |
+
|
63 |
+
config_section = config.get(field)
|
64 |
+
branch_content = str(config_section)
|
65 |
+
if isinstance(config_section, omegaconf.DictConfig):
|
66 |
+
branch_content = omegaconf.OmegaConf.to_yaml(
|
67 |
+
config_section, resolve=resolve)
|
68 |
+
|
69 |
+
branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
|
70 |
+
rich.print(tree)
|
71 |
+
if save_cfg:
|
72 |
+
with fsspec.open(
|
73 |
+
'{}/config_tree.txt'.format(
|
74 |
+
config.checkpointing.save_dir), 'w') as fp:
|
75 |
+
rich.print(tree, file=fp)
|
76 |
+
|
77 |
+
|
78 |
+
@L.pytorch.utilities.rank_zero_only
|
79 |
+
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
|
80 |
+
#for dl_type, dl in [
|
81 |
+
#('train', train_ds), ('valid', valid_ds)]:
|
82 |
+
for dl_type, dl in [
|
83 |
+
('train', train_ds)]:
|
84 |
+
print(f'Printing {dl_type} dataloader batch.')
|
85 |
+
batch = next(iter(dl))
|
86 |
+
print('Batch input_ids.shape', batch['input_ids'].shape)
|
87 |
+
first = batch['input_ids'][0, :k]
|
88 |
+
last = batch['input_ids'][0, -k:]
|
89 |
+
print(f'First {k} tokens:', tokenizer.decode(first))
|
90 |
+
print('ids:', first)
|
91 |
+
print(f'Last {k} tokens:', tokenizer.decode(last))
|
92 |
+
print('ids:', last)
|
93 |
+
|
94 |
+
|
95 |
+
def generate_samples(config, logger, tokenizer):
|
96 |
+
logger.info('Generating samples.')
|
97 |
+
model = _load_from_checkpoint(config=config,
|
98 |
+
tokenizer=tokenizer)
|
99 |
+
model.gen_ppl_metric.reset()
|
100 |
+
if config.eval.disable_ema:
|
101 |
+
logger.info('Disabling EMA.')
|
102 |
+
model.ema = None
|
103 |
+
stride_length = config.sampling.stride_length
|
104 |
+
num_strides = config.sampling.num_strides
|
105 |
+
for _ in range(config.sampling.num_sample_batches):
|
106 |
+
if config.sampling.semi_ar:
|
107 |
+
_, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
|
108 |
+
stride_length=stride_length,
|
109 |
+
num_strides=num_strides,
|
110 |
+
dt=1 / config.sampling.steps)
|
111 |
+
text_samples = intermediate_samples[-1]
|
112 |
+
# Note: Samples generated using semi-ar method
|
113 |
+
# need to to be processed before computing generative perplexity
|
114 |
+
# since these samples contain numerous <|endoftext|> tokens
|
115 |
+
# and diffusion.compute_generative_perplexity() discards
|
116 |
+
# any text after the first EOS token.
|
117 |
+
else:
|
118 |
+
samples = model.restore_model_and_sample(
|
119 |
+
num_steps=config.sampling.steps)
|
120 |
+
text_samples = model.tokenizer.batch_decode(samples)
|
121 |
+
model.compute_generative_perplexity(text_samples)
|
122 |
+
print('Text samples:', text_samples)
|
123 |
+
if not config.sampling.semi_ar:
|
124 |
+
print('Generative perplexity:',
|
125 |
+
model.gen_ppl_metric.compute())
|
126 |
+
return text_samples
|
127 |
+
|
128 |
+
def _ppl_eval(config, logger, tokenizer, data_module):
|
129 |
+
logger.info('Starting Zero Shot Eval.')
|
130 |
+
|
131 |
+
model = _load_from_checkpoint(config=config,
|
132 |
+
tokenizer=tokenizer)
|
133 |
+
if config.eval.disable_ema:
|
134 |
+
logger.info('Disabling EMA.')
|
135 |
+
model.ema = None
|
136 |
+
|
137 |
+
wandb_logger = None
|
138 |
+
if config.get('wandb', None) is not None:
|
139 |
+
wandb_logger = L.pytorch.loggers.WandbLogger(
|
140 |
+
config=omegaconf.OmegaConf.to_object(config),
|
141 |
+
** config.wandb)
|
142 |
+
callbacks = []
|
143 |
+
if 'callbacks' in config:
|
144 |
+
for _, callback in config.callbacks.items():
|
145 |
+
callbacks.append(hydra.utils.instantiate(callback))
|
146 |
+
trainer = hydra.utils.instantiate(
|
147 |
+
config.trainer,
|
148 |
+
default_root_dir=os.getcwd(),
|
149 |
+
callbacks=callbacks,
|
150 |
+
strategy=DDPStrategy(find_unused_parameters=True),
|
151 |
+
logger=wandb_logger)
|
152 |
+
# _, valid_ds = dataloader.get_dataloaders(
|
153 |
+
# config, tokenizer, skip_train=True, valid_seed=config.seed)
|
154 |
+
trainer.test(model, data_module)
|
155 |
+
|
156 |
+
|
157 |
+
def _train(config, logger, tokenizer, data_module):
|
158 |
+
logger.info('Starting Training.')
|
159 |
+
wandb_logger = None
|
160 |
+
if config.get('wandb', None) is not None:
|
161 |
+
wandb_logger = L.pytorch.loggers.WandbLogger(
|
162 |
+
config=omegaconf.OmegaConf.to_object(config),
|
163 |
+
** config.wandb)
|
164 |
+
|
165 |
+
if (config.checkpointing.resume_from_ckpt
|
166 |
+
and config.checkpointing.resume_ckpt_path is not None
|
167 |
+
and utils.fsspec_exists(
|
168 |
+
config.checkpointing.resume_ckpt_path)):
|
169 |
+
ckpt_path = config.checkpointing.resume_ckpt_path
|
170 |
+
else:
|
171 |
+
ckpt_path = None
|
172 |
+
|
173 |
+
# Lightning callbacks
|
174 |
+
callbacks = []
|
175 |
+
if 'callbacks' in config:
|
176 |
+
for _, callback in config.callbacks.items():
|
177 |
+
callbacks.append(hydra.utils.instantiate(callback))
|
178 |
+
'''
|
179 |
+
train_ds, valid_ds = dataloader.get_dataloaders(
|
180 |
+
config, tokenizer)
|
181 |
+
_print_batch(train_ds, valid_ds, tokenizer)
|
182 |
+
|
183 |
+
model = diffusion.Diffusion(
|
184 |
+
config, tokenizer=valid_ds.tokenizer)
|
185 |
+
'''
|
186 |
+
trainer = hydra.utils.instantiate(
|
187 |
+
config.trainer,
|
188 |
+
default_root_dir=os.getcwd(),
|
189 |
+
callbacks=callbacks,
|
190 |
+
accelerator='cuda',
|
191 |
+
strategy=DDPStrategy(find_unused_parameters=True),
|
192 |
+
logger=wandb_logger)
|
193 |
+
|
194 |
+
model = Diffusion(
|
195 |
+
config, tokenizer=tokenizer)
|
196 |
+
|
197 |
+
trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
|
198 |
+
|
199 |
+
'''
|
200 |
+
trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
|
201 |
+
'''
|
202 |
+
|
203 |
+
@hydra.main(version_base=None, config_path='configs', config_name='config')
|
204 |
+
def main(config):
|
205 |
+
"""Main entry point for training."""
|
206 |
+
L.seed_everything(config.seed)
|
207 |
+
_print_config(config, resolve=True, save_cfg=True)
|
208 |
+
|
209 |
+
logger = utils.get_logger(__name__)
|
210 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
211 |
+
|
212 |
+
if config.backbone == "vanilla_esm_pretrain":
|
213 |
+
train_dataset = load_dataset('csv', data_files=config.data.train.vanilla_esm_train_path)
|
214 |
+
val_dataset = load_dataset('csv', data_files=config.data.valid.vanilla_esm_valid_path)
|
215 |
+
test_dataset = load_dataset('csv', data_files=config.data.test.vanilla_esm_test_path)
|
216 |
+
elif config.backbone == "membrane_esm_finetune" or config.backbone == "dit":
|
217 |
+
train_dataset = load_dataset('csv', data_files=config.data.train.membrane_esm_train_path)
|
218 |
+
val_dataset = load_dataset('csv', data_files=config.data.valid.membrane_esm_valid_path)
|
219 |
+
test_dataset = load_dataset('csv', data_files=config.data.test.membrane_esm_test_path)
|
220 |
+
|
221 |
+
lst = [i for i in range(1, 200)]
|
222 |
+
|
223 |
+
train_dataset = train_dataset['train']#.select(lst)
|
224 |
+
val_dataset = val_dataset['train']#.select(lst)
|
225 |
+
test_dataset = test_dataset['train']#.select(lst)
|
226 |
+
|
227 |
+
if config.training.focus_mask :
|
228 |
+
collator = dataloader.membrane_collate_fn
|
229 |
+
elif config.data.wrapping:
|
230 |
+
collator = dataloader.wrap_collate_fn
|
231 |
+
else:
|
232 |
+
collator = dataloader.collate_fn  # assumes pl_data_loader also provides a default collate_fn
|
233 |
+
|
234 |
+
data_module = dataloader.CustomDataModule(
|
235 |
+
train_dataset, val_dataset, test_dataset,
|
236 |
+
tokenizer,
|
237 |
+
batch_size=config.loader.batch_size,
|
238 |
+
collate_fn=collator
|
239 |
+
)
|
240 |
+
|
241 |
+
if config.mode == 'sample_eval':
|
242 |
+
generate_samples(config, logger, tokenizer)
|
243 |
+
elif config.mode == 'ppl_eval':
|
244 |
+
_ppl_eval(config, logger, tokenizer, data_module)
|
245 |
+
else:
|
246 |
+
_train(config, logger, tokenizer, data_module)
|
247 |
+
|
248 |
+
|
249 |
+
if __name__ == '__main__':
|
250 |
+
main()
|
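The resolvers registered at the top of main.py let the YAML config compute derived values, e.g. the per-device batch size via `div_up` (ceiling division). A standalone sketch with toy keys, not the project's full config:

```python
import omegaconf

omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)

cfg = omegaconf.OmegaConf.create({
    'global_batch_size': 8,
    'devices': 3,
    'batch_size': '${div_up:${global_batch_size},${devices}}',
})
print(cfg.batch_size)  # 3, i.e. ceil(8 / 3)
```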
mdlm_motif_benchmarking.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
import pandas as pd
|
7 |
+
from mlm_generate_utils import mask_for_scaffold, calculate_cosine_sim, calculate_hamming_dist
|
8 |
+
from diffusion import Diffusion
|
9 |
+
import hydra
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
12 |
+
|
13 |
+
def masking_test(sequence: str, generate_case: str, tokenizer, mask_prob: float = 0.50):
|
14 |
+
"""
|
15 |
+
Masks 50% of the tokens in the sequence.
|
16 |
+
"""
|
17 |
+
tokens = list(sequence.upper())
|
18 |
+
num_tokens_to_mask = int(mask_prob * len(tokens)) # Select some fraction of the tokens
|
19 |
+
print(num_tokens_to_mask,len(tokens))
|
20 |
+
|
21 |
+
# Get random indices to mask
|
22 |
+
mask_indices = random.sample(range(len(tokens)), num_tokens_to_mask)
|
23 |
+
|
24 |
+
for idx in mask_indices:
|
25 |
+
tokens[idx] = tokenizer.mask_token # Replace with mask token
|
26 |
+
|
27 |
+
return ''.join(tokens)
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def generate_scaffold_mdlm(sequence: str, generate_case: str, tokenizer, mdlm: Diffusion):
|
33 |
+
# # Mask soluble or transmembrane domains
|
34 |
+
# masked_sequence = mask_for_scaffold(sequence, generate_case)
|
35 |
+
|
36 |
+
# # Test out different masking rates
|
37 |
+
# masked_sequence = masking_test(sequence, generate_case, tokenizer)
|
38 |
+
|
39 |
+
# 100% masking rate, de novo generation
|
40 |
+
masked_sequence = len(sequence) * "<mask>"
|
41 |
+
|
42 |
+
print(masked_sequence)
|
43 |
+
|
44 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt").to(mdlm.device)
|
45 |
+
|
46 |
+
logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
|
47 |
+
# logits = mdlm.forward(inputs)
|
48 |
+
# print(tokenizer.decode(logits.squeeze(), skip_special_tokens=True))
|
49 |
+
|
50 |
+
return tokenizer.decode(logits.squeeze()), masked_sequence
|
51 |
+
|
52 |
+
|
53 |
+
@hydra.main(version_base=None, config_path='configs', config_name='config')
|
54 |
+
def mdlm_motif_benchmark(config):
|
55 |
+
path = "/workspace/sg666/MDpLM"
|
56 |
+
|
57 |
+
test_sequences = pd.read_csv(path + "/data/membrane/test.csv")['Sequence'].tolist()
|
58 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
59 |
+
|
60 |
+
mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
|
61 |
+
esm_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D") # model used for functionality testing
|
62 |
+
|
63 |
+
mdlm_model.eval()
|
64 |
+
esm_model.eval()
|
65 |
+
|
66 |
+
print("loaded models...")
|
67 |
+
|
68 |
+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
69 |
+
mdlm_model.to(device)
|
70 |
+
esm_model.to(device)
|
71 |
+
|
72 |
+
for generate_case in ['uppercase', 'lowercase']:
|
73 |
+
case_results = []
|
74 |
+
for original_sequence in tqdm(test_sequences, desc=f"scaffolding ({generate_case}): "):
|
75 |
+
|
76 |
+
generated_sequence, masked_input = generate_scaffold_mdlm(original_sequence, generate_case, tokenizer, mdlm_model)
|
77 |
+
generated_sequence = generated_sequence[5:-5].replace(" ", "") # Remove bos/eos tokens
|
78 |
+
|
79 |
+
perplexity = mdlm_model.compute_masked_perplexity([original_sequence], masked_input)
|
80 |
+
cos_sim = calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device)
|
81 |
+
hamming_distance = calculate_hamming_dist(original_sequence, generated_sequence)
|
82 |
+
|
83 |
+
case_results.append([original_sequence, generated_sequence, perplexity, cos_sim, hamming_distance])
|
84 |
+
|
85 |
+
print("perplexity: ", perplexity, "cos sim: ", cos_sim, "hamming: ", hamming_distance)
|
86 |
+
print(f"generated sequence: {generated_sequence}")
|
87 |
+
print(f"original sequence: {original_sequence.upper()}")
|
88 |
+
sys.stdout.flush()
|
89 |
+
|
90 |
+
df = pd.DataFrame(case_results, columns=['Original Sequence', 'Generated Sequence', 'Perplexity', 'Cosine Similarity', 'Hamming Distance'])
|
91 |
+
df.to_csv(path + f'/benchmarks/MLM/mlm_{generate_case}_results.csv', index=False)
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
mdlm_motif_benchmark()
|
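The scaffolding path that is currently commented out masks one case (uppercase or lowercase residues) with `mask_for_scaffold` from mlm_generate_utils.py. A quick illustration on a made-up mixed-case sequence:

```python
from mlm_generate_utils import mask_for_scaffold

seq = "aaKLVgg"  # made-up sequence with lowercase and uppercase regions

# Mask the uppercase residues; the remaining residues are upper-cased for the tokenizer.
print(mask_for_scaffold(seq, "uppercase"))  # AA<mask><mask><mask>GG

# Mask the lowercase residues; uppercase residues are kept as-is.
print(mask_for_scaffold(seq, "lowercase"))  # <mask><mask>KLV<mask><mask>
```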
mlm_generate_utils.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
import config
|
4 |
+
import sys
|
5 |
+
import pandas as pd
|
6 |
+
from esm_utils import get_latents
|
7 |
+
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
|
8 |
+
|
9 |
+
|
10 |
+
def mask_for_de_novo(sequence_length):
|
11 |
+
return "<mask>" * sequence_length
|
12 |
+
|
13 |
+
def generate_de_novo(sequence_length, tokenizer, model):
|
14 |
+
masked_sequence = mask_for_de_novo(sequence_length)
|
15 |
+
inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)
|
16 |
+
|
17 |
+
with torch.no_grad():
|
18 |
+
logits = model(**inputs).logits
|
19 |
+
mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
|
20 |
+
logits_at_masks = logits[0, mask_token_indices]
|
21 |
+
|
22 |
+
pred_tokens = []
|
23 |
+
for i in range(len(mask_token_indices)):
|
24 |
+
topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
|
25 |
+
probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
|
26 |
+
predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
|
27 |
+
predicted_token_id = topk_indices[predicted_index].item()
|
28 |
+
predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
|
29 |
+
pred_tokens.append(predicted_token)
|
30 |
+
|
31 |
+
generated_sequence = ''.join(pred_tokens)
|
32 |
+
perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)
|
33 |
+
|
34 |
+
return (generated_sequence, perplexity)
|
35 |
+
|
36 |
+
|
37 |
+
def mask_for_scaffold(sequence, generate_type):
|
38 |
+
if generate_type == "uppercase":
|
39 |
+
sequence = ''.join(["<mask>" if residue.isupper() else residue.upper() for residue in sequence])
|
40 |
+
elif generate_type == "lowercase":
|
41 |
+
sequence = ''.join(["<mask>" if residue.islower() else residue for residue in sequence])
|
42 |
+
return sequence
|
43 |
+
|
44 |
+
|
45 |
+
def generate_scaffold(sequence, generate_type, tokenizer, model):
|
46 |
+
masked_sequence = mask_for_scaffold(sequence, generate_type)
|
47 |
+
inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
logits = model(**inputs).logits
|
51 |
+
mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
|
52 |
+
logits_at_masks = logits[0, mask_token_indices]
|
53 |
+
|
54 |
+
pred_tokens = []
|
55 |
+
for i in range(len(mask_token_indices)):
|
56 |
+
topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
|
57 |
+
probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
|
58 |
+
predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
|
59 |
+
predicted_token_id = topk_indices[predicted_index].item()
|
60 |
+
predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
|
61 |
+
|
62 |
+
pred_tokens.append('G' if predicted_token == '' else predicted_token)
|
63 |
+
|
64 |
+
generated_sequence = masked_sequence
|
65 |
+
for token in pred_tokens:
|
66 |
+
generated_sequence = generated_sequence.replace("<mask>", token, 1)
|
67 |
+
|
68 |
+
return generated_sequence, mask_token_indices
|
69 |
+
|
70 |
+
|
71 |
+
def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
|
72 |
+
total_loss = 0.0
|
73 |
+
tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
|
74 |
+
|
75 |
+
for i in mask_token_indices:
|
76 |
+
masked_input = tensor_input.clone()
|
77 |
+
masked_input[0, i] = tokenizer.mask_token_id
|
78 |
+
|
79 |
+
labels = torch.full(tensor_input.shape, -100).to(model.device)
|
80 |
+
labels[0, i] = tensor_input[0, i]
|
81 |
+
|
82 |
+
with torch.no_grad():
|
83 |
+
outputs = model(masked_input, labels=labels)
|
84 |
+
total_loss += outputs.loss.item()
|
85 |
+
|
86 |
+
num_mask_tokens = len(mask_token_indices)
|
87 |
+
if num_mask_tokens == 0:
|
88 |
+
perplexity = 10000
|
89 |
+
else:
|
90 |
+
avg_loss = total_loss / num_mask_tokens
|
91 |
+
perplexity = math.exp(avg_loss)
|
92 |
+
|
93 |
+
return perplexity
|
94 |
+
|
95 |
+
|
96 |
+
def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
|
97 |
+
og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
|
98 |
+
new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)
|
99 |
+
|
100 |
+
sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
|
101 |
+
cosine_similarity = torch.mean(sequence_similarity).item()
|
102 |
+
return cosine_similarity
|
103 |
+
|
104 |
+
|
105 |
+
def calculate_hamming_dist(original_sequence, generated_sequence):
|
106 |
+
generated_sequence = generated_sequence.upper()
|
107 |
+
original_sequence = original_sequence.upper()
|
108 |
+
return sum(1 if original_sequence[i] != generated_sequence[i] else 0 for i in range(len(original_sequence)))
|
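For orientation, a minimal sketch of how these helpers combine into one row of the motif-benchmark CSV written above. The checkpoint name and the toy sequence are placeholders, and loading the embedding model with AutoModel is an assumption; this is not the benchmarking entry point itself.

# Hypothetical wiring of the helpers above; checkpoint and sequence are illustrative only.
import torch
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
name = 'facebook/esm2_t30_150M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(name)
mlm_model = AutoModelForMaskedLM.from_pretrained(name).to(device)
esm_model = AutoModel.from_pretrained(name).to(device)  # embedding model handed to get_latents

# With generate_type='lowercase', lowercase residues are regenerated while
# uppercase residues are kept as the fixed motif.
original = 'MKTllilavVAGliv'
generated, mask_idx = generate_scaffold(original, 'lowercase', tokenizer, mlm_model)
row = (original,
       generated,
       calculate_perplexity(mlm_model, tokenizer, generated, mask_idx),
       calculate_cosine_sim(original, generated, tokenizer, esm_model, device),
       calculate_hamming_dist(original, generated))
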
noise_schedule.py
ADDED
@@ -0,0 +1,153 @@
import abc

import torch
import torch.nn as nn

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_noise(config, dtype=torch.float32):
    # The early return below hard-codes the log-linear schedule; the
    # config-driven dispatch that follows is kept for reference but never runs.
    return LogLinearNoise()

    if config.noise.type == 'geometric':
        return GeometricNoise(config.noise.sigma_min,
                              config.noise.sigma_max)
    elif config.noise.type == 'loglinear':
        return LogLinearNoise()
    elif config.noise.type == 'cosine':
        return CosineNoise()
    elif config.noise.type == 'cosinesqr':
        return CosineSqrNoise()
    elif config.noise.type == 'linear':
        return Linear(config.noise.sigma_min,
                      config.noise.sigma_max,
                      dtype)
    else:
        raise ValueError(f'{config.noise.type} is not a valid noise')


def binary_discretization(z):
    z_hard = torch.sign(z)
    z_soft = z / torch.norm(z, dim=-1, keepdim=True)
    return z_soft + (z_hard - z_soft).detach()


class Noise(abc.ABC, nn.Module):
    """
    Baseline forward method to get the total + rate of noise at a timestep
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        """
        Rate of change of noise ie g(t)
        """
        pass

    @abc.abstractmethod
    def total_noise(self, t):
        """
        Total noise ie \int_0^t g(t) dt + g(0)
        """
        pass


class CosineNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2)
        return - torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2) ** 2
        return - torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)


class GeometricNoise(Noise):
    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and
    ~1 when t varies from 0 to 1. Total noise is
    -log(1 - (1 - eps) * t), so the sigma will be
    (1 - eps) * t.
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        t = - torch.expm1(- sigma_t) / (1 - self.eps)
        return t

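The LogLinearNoise docstring above can be checked numerically. The snippet below is an illustrative sanity check, not part of the training code: with sigma(t) = -log1p(-(1 - eps) * t), the induced masking probability 1 - exp(-sigma(t)) is exactly (1 - eps) * t.

import torch

noise = LogLinearNoise(eps=1e-3)
t = torch.tensor([0.0, 0.25, 0.5, 1.0])
sigma, dsigma = noise(t)             # forward() returns total noise and its rate
mask_prob = 1 - torch.exp(-sigma)    # chance a token is absorbed/masked at time t
assert torch.allclose(mask_prob, (1 - 1e-3) * t)
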
pl_data_loader.py
ADDED
@@ -0,0 +1,819 @@
import functools
import itertools
import json
import math
import os
import re
import shutil
import typing
import urllib
import zipfile

import datasets
import fsspec
import requests
import tokenizers
import torch
import transformers

import utils

LOGGER = utils.get_logger(__name__)


def wt_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string


def ptb_detokenizer(x):
    x = x.replace(" 's", "'s")
    x = x.replace("s ' ", "s' ")
    x = x.replace(" n't", "n't")
    x = x.replace(" \n ", "\n")
    x = x.replace("\\/", "/")
    for _ in range(10):
        x = x.replace(" N ", " 1 ")
    x = x.replace("$ 1", "$1")
    x = x.replace("# 1", "#1")
    x = x.replace("<unk>", "?")
    return x


def lm1b_detokenizer(x):
    x = x.replace('http : / / ', 'http://')
    x = x.replace('https : / / ', 'https://')
    x = re.sub(r' \'(\w+)', r"'\1", x)
    x = re.sub(r' (\w+) \. ', r' \1. ', x)
    x = re.sub(r' (\w+) \.$', r' \1.', x)
    x = x.replace(' ? ', '? ')
    x = re.sub(r' \?$', '?', x)
    x = x.replace(' ! ', '! ')
    x = re.sub(r' \!$', '!', x)
    x = x.replace(' , ', ', ')
    x = x.replace(' : ', ': ')
    x = x.replace(' ; ', '; ')
    x = x.replace(' / ', '/')
    x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
    x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
    x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
    x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
    x = x.replace('$ ', '$')
    x = x.replace('£ ', '£')
    return x


def lambada_detokenizer(text):
    text = text.replace("“", '"')
    text = text.replace("”", '"')
    return '\n' + text.strip()


def scientific_papers_detokenizer(x):
    x = wt_detokenizer(x)
    x = lm1b_detokenizer(x)
    return x


class Text8Tokenizer(transformers.PreTrainedTokenizer):
    def __init__(
            self,
            bos_token='[BOS]',
            eos_token='[EOS]',
            sep_token='[SEP]',
            cls_token='[CLS]',
            pad_token='[PAD]',
            mask_token='[MASK]',
            unk_token='[UNK]',
            **kwargs):
        self.characters = list('abcdefghijklmnopqrstuvwxyz ')
        self._vocab_str_to_int = {
            '[CLS]': 0,
            '[SEP]': 1,
            '[BOS]': 2,
            '[EOS]': 3,
            '[MASK]': 4,
            '[PAD]': 5,
            '[RESERVED]': 6,
            '[UNK]': 7,
            ** {ch: i + 8 for i, ch in enumerate(self.characters)}}
        self._vocab_int_to_str = {
            v: k for k, v in self._vocab_str_to_int.items()}
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            **kwargs)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
        return list(text.lower())

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(
            token, self._vocab_str_to_int['[UNK]'])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return ''.join(tokens)

    def get_vocab(self) -> typing.Dict[str, int]:
        return self._vocab_str_to_int


def get_lambada_test_dataset():
    url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"

    def read_jsonl_to_list(url):
        response = requests.get(url, stream=True)
        data_list = []

        # Process each line in the response content
        for line in response.iter_lines(decode_unicode=True):
            if line:
                data = json.loads(line)
                data_list.append(data)

        return data_list

    lambada_data = read_jsonl_to_list(url)
    dataset = datasets.Dataset.from_list(lambada_data)
    return dataset


def get_text8_dataset(cache_dir, max_seq_length=256,
                      drop_last=True, crop_train=False):
    """Adapted from:
    https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344

    Args:
      cache_dir: str, path to cache directory.
      max_seq_length: int, maximum length of sequences.
        (default: 256, as in D3PM codebase.)
      drop_last: bool, whether to drop the last incomplete
        batch. (default: True, as in D3PM codebase.)
      crop_train: bool, whether to subsample contiguous
        subsequences from training example. serves to
        make sure transformer models with absolute position
        embeddings do not have incorrect position-wise
        marginals. (default: False, but necessary to match D3PM AR)

    Returns:
      dataset: dataset.DatasetDict, with keys 'train',
        'valid', 'test'.
    """
    url = 'http://mattmahoney.net/dc/text8.zip'
    if not crop_train:
        cache_dir = f'{cache_dir}/text8'
    else:
        cache_dir = f'{cache_dir}/text8-crop-train'
    split_names = ['train', 'validation', 'test']
    if not all([
            utils.fsspec_exists(os.path.join(cache_dir, split))
            for split in split_names]):
        # Check if raw data exists
        raw_cache_dir = os.path.join(cache_dir, 'raw_data')
        if not all([
                utils.fsspec_exists(
                    os.path.join(raw_cache_dir, f'text8.{split}.txt'))
                for split in split_names]):
            if not utils.fsspec_exists(
                    os.path.join(raw_cache_dir, 'text8.zip')):
                utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
                LOGGER.info('Downloading text8 from URL {}.'.format(url))
                with urllib.request.urlopen(url) as in_stream:
                    with open(os.path.join(raw_cache_dir, 'text8.zip'), 'wb') as out_file:
                        shutil.copyfileobj(in_stream, out_file)

            with fsspec.open(
                    os.path.join(raw_cache_dir, 'text8.zip'),
                    'rb') as f:
                rawdata = zipfile.ZipFile(f).read(
                    'text8').decode('utf-8')

            # Splits taken from D3PM codebase
            splits = {
                'train': rawdata[:90000000],
                'validation': rawdata[90000000: 95000000],
                'test': rawdata[95000000:],
            }

            for split, data in splits.items():
                _path = os.path.join(raw_cache_dir,
                                     f'text8.{split}.txt')
                with fsspec.open(_path, 'w') as f:
                    f.write(data)
        else:
            splits = {}
            for split in split_names:
                _path = os.path.join(raw_cache_dir,
                                     f'text8.{split}.txt')
                with fsspec.open(_path, 'r') as f:
                    splits[split] = f.read()

        # Chunk and save as datasets.DatasetDict
        def chunks(lst, n):
            """Yield successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield lst[i:i + n]

        dataset_dict = {}
        for k, v in splits.items():
            if k == 'train' and crop_train == True:
                chunk_size = 2 * max_seq_length
            else:
                chunk_size = max_seq_length
            text = list(chunks(v, chunk_size))
            if drop_last and len(text[-1]) < chunk_size:
                text = text[:-1]
            dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
        dataset = datasets.DatasetDict(dataset_dict)
        dataset.save_to_disk(cache_dir)
    else:
        dataset = datasets.load_from_disk(cache_dir)

    return dataset


def _group_texts(examples, block_size, bos, eos):
    # Concatenate all texts.
    concatenated_examples = list(itertools.chain(* examples['input_ids']))
    total_length = len(concatenated_examples)
    # TODO(yair): look into not dropping the remainder but rather padding it.
    # We drop the small remainder, and if the total_length < block_size - 2
    # we exclude this batch and return an empty dict.
    # We could add padding if the model supported it instead of
    # this drop, you can customize this part to your needs.
    new_block_size = block_size - 2  # [BOS] and [EOS] to be added
    total_length = (total_length // new_block_size) * new_block_size
    # Split by chunks of max_len.
    result = {}
    _values = []
    _attn_masks = []
    for i in range(0, total_length, new_block_size):
        _values.append(
            [bos]
            + concatenated_examples[i : i + new_block_size]
            + [eos])
        _attn_masks.append(torch.ones(block_size))
    result['input_ids'] = _values
    result['attention_mask'] = _attn_masks
    return result

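To make the wrapping concrete, here is what _group_texts does to a toy batch with a small block size (illustrative only; real runs use the model's block size):

# With block_size=8, bos=0, eos=2: token ids are concatenated, re-split into
# chunks of block_size - 2, and each chunk is wrapped as [BOS] ... [EOS].
example = {'input_ids': [[5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16]]}
blocks = _group_texts(example, block_size=8, bos=0, eos=2)
# blocks['input_ids']      -> [[0, 5, 6, 7, 8, 9, 10, 2], [0, 11, 12, 13, 14, 15, 16, 2]]
# blocks['attention_mask'] -> two all-ones masks of length 8
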
def get_dataset(
        dataset_name, tokenizer, wrap, mode, cache_dir,
        block_size=1024, num_proc=len(os.sched_getaffinity(0)), streaming=False):
    if wrap:
        filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped.dat'
    else:
        filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped.dat'
    _path = os.path.join(cache_dir, filename)

    if utils.fsspec_exists(_path):
        LOGGER.info(f'Loading data from: {_path}')
        return datasets.load_from_disk(_path).with_format('torch')
    LOGGER.info(f'Generating new data at: {_path}')

    crop_train = dataset_name == 'text8-crop'
    if mode == 'train' and crop_train:
        # double block size for sub-sampling
        block_size *= 2

    if dataset_name == 'wikitext103':
        dataset = datasets.load_dataset(
            'wikitext',
            name='wikitext-103-raw-v1',
            cache_dir=cache_dir)
    elif dataset_name == 'wikitext2':
        dataset = datasets.load_dataset(
            'wikitext',
            name='wikitext-2-raw-v1',
            cache_dir=cache_dir)
    elif dataset_name == 'ptb':
        dataset = datasets.load_dataset(
            'ptb_text_only', cache_dir=cache_dir)
    elif dataset_name == 'lambada':
        dataset = get_lambada_test_dataset()
    elif dataset_name == 'text8':
        assert wrap
        dataset = get_text8_dataset(
            cache_dir, max_seq_length=block_size)
    elif dataset_name == 'text8-crop':
        dataset = get_text8_dataset(
            cache_dir, max_seq_length=block_size, crop_train=True)
    elif dataset_name == 'openwebtext-train':
        dataset = datasets.load_dataset(
            'openwebtext',
            split='train[:-100000]',
            cache_dir=cache_dir,
            streaming=streaming)
    elif dataset_name == 'openwebtext-valid':
        dataset = datasets.load_dataset(
            'openwebtext',
            split='train[-100000:]',
            cache_dir=cache_dir,
            streaming=streaming)
    elif dataset_name == 'scientific_papers_arxiv':
        dataset = datasets.load_dataset(
            'scientific_papers', 'arxiv',
            trust_remote_code=True,
            cache_dir=cache_dir,
            streaming=streaming)
    elif dataset_name == 'scientific_papers_pubmed':
        dataset = datasets.load_dataset(
            'scientific_papers', 'pubmed',
            trust_remote_code=True,
            cache_dir=cache_dir,
            streaming=streaming)
    elif dataset_name == 'ag_news':
        dataset = datasets.load_dataset(
            'ag_news',
            cache_dir=cache_dir,
            streaming=streaming)
    else:
        dataset = datasets.load_dataset(
            dataset_name,
            cache_dir=cache_dir,
            streaming=streaming)

    if dataset_name in ['lambada', 'openwebtext-train',
                        'openwebtext-valid']:
        data = dataset
    else:
        data = dataset[mode]

    if dataset_name.startswith('wikitext'):
        detokenizer = wt_detokenizer
    elif dataset_name == 'ptb':
        detokenizer = ptb_detokenizer
    elif dataset_name == 'lm1b':
        detokenizer = lm1b_detokenizer
    elif dataset_name == 'lambada':
        detokenizer = lambada_detokenizer
    elif dataset_name.startswith('scientific_papers'):
        detokenizer = scientific_papers_detokenizer
    else:
        detokenizer = None

    def _apply_detokenizer(detokenizer):
        def detok(text):
            for i, t in enumerate(text, 0):
                text[i] = detokenizer(t)
            return text
        return detok

    EOS = tokenizer.encode(tokenizer.eos_token)[0]
    BOS = tokenizer.encode(tokenizer.bos_token)[0]

    def preprocess_and_tokenize(example):
        if dataset_name == 'ptb':
            text = example['sentence']
        elif 'scientific_papers' in dataset_name:
            text = example['article']
        else:
            text = example['text']

        if detokenizer is not None:
            text = _apply_detokenizer(detokenizer)(text)

        tokenizer.padding_side = 'right'
        tokenizer.truncation_side = 'right'

        if wrap:
            tokens = tokenizer(text,
                               add_special_tokens=False,
                               return_attention_mask=False,
                               return_token_type_ids=False)
            tokens = {'input_ids':
                      [t + [EOS] for t in tokens['input_ids']]}
            # Still missing BOS, but will be added in group_texts
        else:
            tokens = tokenizer(text,
                               max_length=block_size,
                               padding='max_length',
                               truncation=True,
                               add_special_tokens=True,
                               return_attention_mask=True,
                               return_token_type_ids=True)
        return tokens

    if streaming:
        tokenized_dataset = data.map(
            preprocess_and_tokenize,
            batched=True,
            desc='Tokenizing')
    else:
        tokenized_dataset = data.map(
            preprocess_and_tokenize,
            batched=True,
            num_proc=num_proc,
            load_from_cache_file=True,
            desc='Tokenizing')
    if dataset_name == 'ptb':
        tokenized_dataset = tokenized_dataset.remove_columns(
            'sentence')
    elif 'scientific_papers' in dataset_name:
        tokenized_dataset = tokenized_dataset.remove_columns([
            'article', 'abstract', 'section_names'])
    elif dataset_name == 'ag_news':
        tokenized_dataset = tokenized_dataset.remove_columns(
            ['text', 'label'])
    else:
        tokenized_dataset = tokenized_dataset.remove_columns(
            'text')

    if not wrap:
        tokenized_dataset.save_to_disk(_path)
        return tokenized_dataset.with_format('torch')

    group_texts = functools.partial(
        _group_texts, block_size=block_size, bos=BOS, eos=EOS)
    if streaming:
        chunked_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
            desc='Grouping')
    else:
        chunked_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
            num_proc=num_proc,
            load_from_cache_file=True,
            desc='Grouping')
        chunked_dataset.save_to_disk(_path)
    chunked_dataset = chunked_dataset.with_format('torch')
    return chunked_dataset


def get_tokenizer(config):
    if config.data.tokenizer_name_or_path == 'text8':
        tokenizer = Text8Tokenizer()
    elif config.data.tokenizer_name_or_path == 'bert-base-uncased':
        tokenizer = transformers.BertTokenizer.\
            from_pretrained('bert-base-uncased')
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            config.data.tokenizer_name_or_path)

    if (isinstance(tokenizer, transformers.GPT2TokenizerFast)
            or isinstance(tokenizer, transformers.GPT2Tokenizer)):
        tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
            (tokenizer.bos_token, tokenizer.bos_token_id),
            (tokenizer.eos_token, tokenizer.eos_token_id))

    # For wrapped batches:
    #  [BOS] sent1 [EOS] sent2-fragment [EOS]
    #  [BOS] sent2-fragment [EOS] sent3 [EOS]
    if tokenizer.bos_token is None:
        if tokenizer.cls_token is None:
            raise AttributeError(
                'Tokenizer must have a bos_token or '
                f'cls_token: {tokenizer}')
        tokenizer.bos_token = tokenizer.cls_token
    if tokenizer.eos_token is None:
        if tokenizer.sep_token is None:
            raise AttributeError(
                'Tokenizer must have a eos_token '
                f'or sep_token: {tokenizer}')
        tokenizer.eos_token = tokenizer.sep_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    return tokenizer


def get_dataloaders(config, tokenizer, skip_train=False,
                    skip_valid=False, valid_seed=None):
    num_gpus = torch.cuda.device_count()
    assert (config.loader.global_batch_size
            == (config.loader.batch_size
                * config.trainer.num_nodes
                * num_gpus
                * config.trainer.accumulate_grad_batches))
    if config.loader.global_batch_size % (
            num_gpus * config.trainer.accumulate_grad_batches) != 0:
        raise ValueError(
            f'Train Batch Size {config.training.batch_size}'
            f'not divisible by {num_gpus} gpus with accumulation '
            f'{config.trainer.accumulate_grad_batches}.')
    if config.loader.eval_global_batch_size % num_gpus != 0:
        raise ValueError(
            f'Eval Batch Size for {config.eval.batch_size} '
            f'not divisible by {num_gpus}.')
    if skip_train:
        train_set = None
    else:
        # NOTE: get_dataset takes cache_dir as a required positional argument, so the
        # commented-out kwarg below must be restored before this helper is used.
        train_set = get_dataset(
            config.data.train,
            tokenizer,
            mode='train',
            wrap=config.data.wrap,
            #cache_dir=config.data.cache_dir,
            block_size=config.model.length)

    if config.data.valid in ['text8', 'lm1b', 'ag_news']:
        validation_split = 'test'
    else:
        validation_split = 'validation'
    if skip_valid:
        valid_set = None
    else:
        valid_set = get_dataset(
            config.data.valid,
            tokenizer,
            wrap=config.data.wrap,
            mode=validation_split,
            #cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            streaming=False)

    if skip_train:
        train_loader = None
    else:
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config.loader.batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=not config.data.streaming,
            persistent_workers=True)
        train_loader.tokenizer = tokenizer
    if skip_valid:
        valid_loader = None
    else:
        if valid_seed is None:
            shuffle_valid = False
            generator = None
        else:
            shuffle_valid = True
            generator = torch.Generator().manual_seed(valid_seed)
        valid_loader = torch.utils.data.DataLoader(
            valid_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)
        # Will be used in generative perplexity calculation
        valid_loader.tokenizer = tokenizer

    return train_loader, valid_loader


# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py


class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):

    def __init__(self, *args, generator=None, **kwargs):
        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
        # which should be reproducible if pl.seed_everything was called beforehand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        kwargs.pop('shuffle', None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {'random_state': self.generator.get_state(),
                'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.generator.set_state(state_dict.get('random_state'))
        self.counter = state_dict['counter']
        # self.start_counter = self.counter
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {'epoch': self.epoch, 'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.counter = state_dict['counter']
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.
    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(
                    padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0


from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from functools import partial
import sys


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = int(self.indices[idx])
        item = self.dataset[actual_idx]
        return item


def membrane_collate_fn(batch, tokenizer):
    """Custom data collator that masks TM/soluble residues for focused training"""
    MAX_LENGTH = 1024
    sequences = [item['Sequence'].upper() for item in batch]

    masks = []
    for item in batch:
        if item["Label"] == 0:
            mask = [1 if i.isupper() else 0 for i in item["Sequence"]]
        else:
            mask = [0 if i.isupper() else 1 for i in item["Sequence"]]
        mask = [1] + mask
        if len(mask) > MAX_LENGTH:  # Truncate
            mask = mask[:MAX_LENGTH]
        elif len(mask) < MAX_LENGTH:  # Pad
            mask += [1] * (MAX_LENGTH - len(mask))

        masks.append(torch.as_tensor(mask))

    mask_t = torch.stack(masks, dim=0)
    tokens = tokenizer(sequences, return_tensors='pt', padding='max_length', truncation=True, max_length=MAX_LENGTH)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'mask': mask_t
    }

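A toy illustration of the focus mask built above (the record is made up; what Label encodes is read off the two branches, not asserted beyond that):

item = {'Sequence': 'mkTMLiv', 'Label': 1}
mask = [0 if c.isupper() else 1 for c in item['Sequence']]   # -> [1, 1, 0, 0, 0, 1, 1]
mask = [1] + mask                                            # extra 1 for the leading special token
# membrane_collate_fn then pads (or truncates) this list to MAX_LENGTH=1024 with 1s
# and stacks it into the 'mask' tensor returned next to input_ids / attention_mask.
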
def wrap_collate_fn(batch, tokenizer):
    """Standard data collator that wraps sequences instead of padding them"""
    # Define sequence size
    chunk_size = 1024
    eos_placeholder = "k"
    eos = "<eos>"

    # Wrap sequences by collecting and splitting them into chunks
    # From MDLM paper: insert <eos> at start/end of chunks and in between sequences
    sequences = eos_placeholder.join([item['Sequence'].upper() for item in batch])
    sequences = eos_placeholder + sequences + eos_placeholder
    wrapped_sequences = [sequences[i:i+chunk_size] for i in range(0, len(sequences), chunk_size)]
    for idx, seq in enumerate(wrapped_sequences):
        wrapped_sequences[idx] = seq.replace(eos_placeholder, eos)

    # Tokenize for input ids and attention masks
    tokens = tokenizer(wrapped_sequences, return_tensors='pt', padding=True)

    return {
        "input_ids": tokens['input_ids'],
        "attention_mask": tokens['attention_mask']
    }

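A small worked example of the wrapping above, with a tiny chunk size so the effect is visible (the collator itself uses 1024):

batch = ['MKTAYIAK', 'GKV']
joined = 'k' + 'k'.join(batch) + 'k'      # 'kMKTAYIAKkGKVk'
chunk_size = 7
chunks = [joined[i:i + chunk_size] for i in range(0, len(joined), chunk_size)]
chunks = [c.replace('k', '<eos>') for c in chunks]
# chunks -> ['<eos>MKTAYI', 'AK<eos>GKV<eos>']
# Sequences are packed back-to-back with <eos> separators and cut into fixed windows
# instead of padding each one; the lowercase placeholder cannot collide with the
# uppercased residues, which is why 'k' is safe to use.
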
def collate_fn(batch, tokenizer):
    """Standard data collator that truncates/pads sequences based on max_length"""
    sequences = [item['Sequence'].upper() for item in batch]
    max_len = max([len(seq) for seq in sequences])
    #labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)

    tokens = tokenizer(sequences, return_tensors='pt', padding='max_length', truncation=True, max_length=1024)

    #attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size: int=8, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)

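A minimal sketch of how these pieces can be wired together. The CSV layout with a 'Sequence' column is inferred from the collators above, and the file names and tokenizer checkpoint are placeholders; this is not the project's training entry point.

import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t30_150M_UR50D')
splits = {name: Dataset.from_pandas(pd.read_csv(f'{name}.csv'))
          for name in ('train', 'val', 'test')}

datamodule = CustomDataModule(splits['train'], splits['val'], splits['test'],
                              tokenizer, batch_size=8,
                              collate_fn=wrap_collate_fn)  # or collate_fn / membrane_collate_fn
batch = next(iter(datamodule.train_dataloader()))          # dict of input_ids / attention_mask
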
utils.py
ADDED
@@ -0,0 +1,230 @@
"""Console logger utilities.

Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""

import logging
import math

import fsspec
import lightning
import torch
from timm.scheduler import CosineLRScheduler


def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """Listdir in manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    fs.makedirs(dirname, exist_ok=exist_ok)


def print_nans(tensor, name):
    if torch.isnan(tensor).any():
        print(name, tensor)


class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Wrap timm.scheduler.CosineLRScheduler
    Enables calling scheduler.step() without passing in epoch.
    Supports resuming as well.
    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on
        # whether we're using the scheduler every epoch or every
        # step.
        # Otherwise, lightning will always call step (i.e.,
        # meant for each epoch), and if we set scheduler
        # interval to "step", then the learning rate update will
        # be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)


class LoggingContext:
    """Context manager for selective logging."""
    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ('debug', 'info', 'warning', 'error',
                  'exception', 'fatal', 'critical'):
        setattr(logger,
                level,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level)))

    return logger


class Sampler:
    def __init__(self, shape):
        self.shape = shape

    def _sampling_noise(self):
        pass

    def _hard_sample(self, logits):
        pass

    def _soft_sample(self, logits):
        return 0

    def sample(self, logits):
        noise = self._sampling_noise()
        noise = noise[: logits.shape[0], :]
        logits = logits + noise.to(
            dtype=logits.dtype, device=logits.device)
        hard_sample = self._hard_sample(logits)
        soft_sample = self._soft_sample(logits)
        return soft_sample + (hard_sample - soft_sample).detach()


class TopKSampler(Sampler):
    def __init__(self, k, shape, gamma_tau=1.0):
        super().__init__(shape)
        self.k = k
        self.gamma_tau = gamma_tau
        self.num_betas = 10
        self.sampler = torch.distributions.gamma.Gamma(
            1 / k * torch.ones(self.num_betas, * self.shape), 1.0)

    def _sampling_noise(self):
        noise = self.sampler.sample()
        beta = self.k / torch.arange(1, self.num_betas + 1, 1,
                                     dtype=torch.float32)
        beta = beta[:, None, None]
        assert beta.ndim == noise.ndim
        s = noise / beta
        s = torch.sum(s, axis=0)
        s = s - math.log(10.0)
        s = self.gamma_tau * (s / self.k)
        return s

    def _hard_sample(self, logits):
        assert logits.ndim == 2
        thresholds, _ = torch.sort(logits, dim=-1)
        thresholds = thresholds[:, - self.k][:, None]
        return (logits >= thresholds).type(logits.dtype)

    def _soft_sample(self, logits):
        soft_top_k = logits - torch.mean(logits, dim=-1,
                                         keepdim=True)
        return soft_top_k / torch.norm(soft_top_k, dim=-1,
                                       keepdim=True)


class DeterministicTopK(TopKSampler):
    def __init__(self, k):
        super().__init__(k, shape=(1, 1))

    def _sampling_noise(self):
        return 0

    def discreize(self, x):
        hard_sample = self._hard_sample(x)
        soft_sample = self._soft_sample(x)
        return soft_sample + (hard_sample - soft_sample).detach()


class GumbelSampler(Sampler):

    def __init__(self, shape, temperature=1.0):
        super().__init__(shape)
        self.temperature = temperature

    def _sampling_noise(self):
        return - (1e-10 - (
            torch.rand(* self.shape) + 1e-10).log()).log()

    def _hard_sample(self, logits):
        assert logits.ndim == 2
        indices = torch.argmax(logits, dim=-1)
        zeros = logits * 0
        ones = torch.ones_like(logits[:, :, :1])
        return torch.scatter(zeros, -1, indices[:, :, None],
                             ones)

    def _soft_sample(self, logits):
        return torch.nn.functional.softmax(
            logits / self.temperature, dim=-1)


class BinarySampler(GumbelSampler):

    def sample(self, probs):
        # TODO(subhamsahoo): use the temperature parameter.
        pos_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        neg_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        del_noise_exp = (neg_noise - pos_noise).exp()
        hard_sample = (probs * (1 + del_noise_exp)
                       > 1).to(probs.dtype)
        soft_sample = probs / (probs + (1 - probs) * del_noise_exp)
        return soft_sample + (hard_sample - soft_sample).detach()


class GaussianSampler:
    def __init__(self):
        self.softplus = torch.nn.Softplus()

    def sample(self, x):
        assert x.ndim == 2
        n = x.shape[-1] // 2
        mu = x[:, :n]
        sigma = self.softplus(x[:, n:]).sqrt()
        return mu + sigma * torch.randn_like(mu)

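As a pointer to how the samplers above behave, a small straight-through example with BinarySampler (illustrative only; shapes and probabilities are made up):

import torch

probs = torch.full((3, 5), 0.7, requires_grad=True)  # element-wise Bernoulli probabilities
sampler = BinarySampler(shape=(3, 5))
bits = sampler.sample(probs)   # forward values are exactly 0.0 / 1.0
bits.sum().backward()          # gradients flow through the soft relaxation
print(probs.grad.shape)        # torch.Size([3, 5])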