sgoel30 committed on
Commit d061944 · verified · 1 Parent(s): 360c784

Upload 12 files

Files changed (12)
  1. config.yaml +127 -0
  2. diffusion.py +1434 -0
  3. dit.py +388 -0
  4. ema.py +97 -0
  5. esm_utils.py +15 -0
  6. generate.py +60 -0
  7. main.py +250 -0
  8. mdlm_motif_benchmarking.py +96 -0
  9. mlm_generate_utils.py +108 -0
  10. noise_schedule.py +153 -0
  11. pl_data_loader.py +819 -0
  12. utils.py +230 -0
config.yaml ADDED
@@ -0,0 +1,127 @@
+ defaults:
+ - _self_
+ - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
+ - /model: small
+ - /strategy: ddp
+ - /noise: loglinear
+ - /lr_scheduler: constant_warmup
+
+ mode: sample_eval # train / ppl_eval / sample_eval
+ diffusion: absorbing_state
+ backbone: membrane_esm_finetune # dit / dimamba / ar / vanilla_esm_pretrain / membrane_esm_finetune
+ parameterization: subs # subs / d3pm / sedd
+ time_conditioning: False
+ T: 0 # 0 (continuous time) / 1000
+ subs_masking: False
+
+ seed: 42
+
+ data:
+   train:
+     vanilla_esm_train_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/train.csv
+     membrane_esm_train_path: /workspace/sg666/MDpLM/data/membrane/train.csv
+     wrap: null
+   test:
+     vanilla_esm_test_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/test.csv
+     membrane_esm_test_path: /workspace/sg666/MDpLM/data/membrane/test.csv
+     wrap: null
+   valid:
+     vanilla_esm_valid_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/val.csv
+     membrane_esm_valid_path: /workspace/sg666/MDpLM/data/membrane/val.csv
+     wrap: null
+   wrapping: True
+
+ loader:
+   global_batch_size: 8
+   eval_global_batch_size: ${.global_batch_size}
+   # Note: batch_size and eval_batch_size are **per machine**
+   batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
+   pin_memory: True
+
+ sampling:
+   predictor: ddpm_cache # analytic, ddpm, ddpm_cache
+   steps: 128
+   noise_removal: True
+   # TODO(yair): @subham, why aren't these params under `eval`?
+   num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+   num_sample_log: 2
+   semi_ar: False
+   stride_length: 1
+   num_strides: 1
+
+ training:
+   ema: 0.9999
+   antithetic_sampling: True
+   importance_sampling: False
+   sampling_eps: 1e-3
+   change_of_variables: False
+   mlm_model_path: /workspace/sg666/MDpLM/benchmarks/MLM/model_ckpts_650M/best_model_epoch
+   esm_model_path: facebook/esm2_t30_150M_UR50D
+   focus_mask: False
+
+ eval:
+   checkpoint_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/eos-wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/checkpoints/best.ckpt # Used to evaluate a checkpoint after training.
+   disable_ema: False
+   compute_generative_perplexity: False
+   perplexity_batch_size: 8
+   compute_perplexity_on_sanity: False
+   gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
+   generate_samples: True
+   generation_model: /workspace/sg666/MDpLM/checkpoints/membrane_automodel/epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
+
+ optim:
+   weight_decay: 0.075
+   lr: 3e-4
+   beta1: 0.9
+   beta2: 0.999
+   eps: 1e-8
+
+ Model:
+   hidden_size: 1280
+   cond_dim: 256
+   n_heads: 20
+   n_blocks: 4
+   dropout: 0.5
+   length: null #512
+   scale_by_sigma: True
+
+ trainer:
+   _target_: lightning.Trainer
+   accelerator: cuda
+   num_nodes: 1
+   devices: ${device_count:}
+   accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+   gradient_clip_val: 1.0
+   precision: bf16
+   num_sanity_val_steps: 2
+   max_epochs: 60
+   max_steps: 1_000_000
+   log_every_n_steps: 10
+   limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
+   limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
+   val_check_interval: 955
+
+ wandb:
+   project: MDpLM_finetune_membrane_200k-seqs
+   notes: null
+   group: programmablebio
+   job_type: null
+   name: dit_test #dit_wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16
+   id: ${.name}_${seed}
+
+ hydra:
+   run:
+     dir: /workspace/sg666/MDpLM/outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
+   job:
+     chdir: true
+
+ checkpointing:
+   # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+   save_dir: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
+   # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+   resume_from_ckpt: false
+   resume_ckpt_path: ${.save_dir}/epochs30_lr3e-4_bsz8_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params_no-compile/checkpoints/last.ckpt #/checkpoints/last.ckpt
+   pretrained_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/vanilla_esm_pretrained_automodel/epochs10_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
+   finetuned_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
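The `loader` and `trainer` blocks above use custom OmegaConf resolvers (`div_up`, `eval`, `device_count`) that are not built into Hydra or OmegaConf; they must be registered in Python before the config is composed. The actual registration presumably lives elsewhere in this upload (e.g. main.py or utils.py); the snippet below is only a minimal sketch, under that assumption, of how such resolvers could be defined and how the per-device batch size falls out of the interpolations.

# Illustrative sketch only -- not part of the uploaded files. The resolver
# names match the interpolations in config.yaml; their exact definitions in
# the repo may differ.
import math
import torch
from omegaconf import OmegaConf

OmegaConf.register_new_resolver('div_up', lambda x, y: int(math.ceil(int(x) / int(y))))
OmegaConf.register_new_resolver('eval', eval)  # evaluates the quoted Python expression
OmegaConf.register_new_resolver('device_count', lambda: torch.cuda.device_count())

# Example: global_batch_size=8 on 1 node with 4 GPUs resolves to
#   loader.batch_size              = div_up(8, 4)         = 2 per device
#   trainer.accumulate_grad_batches = div_up(8, 4 * 2 * 1) = 1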
diffusion.py ADDED
@@ -0,0 +1,1434 @@
1
+ import itertools
2
+ import math
3
+ import os
4
+ import sys
5
+ import typing
6
+ from dataclasses import dataclass
7
+
8
+ import hydra.utils
9
+ import lightning as L
10
+ import numpy as np
11
+ import torch.nn as nn
12
+ import torch
13
+ # import dit
14
+ import ema
15
+ import time
16
+ import gc
17
+ import pl_data_loader as dataloader
18
+ import torch.nn.functional as F
19
+ import torchmetrics
20
+ import transformers
21
+ from torch import Tensor
22
+ from torch.optim.lr_scheduler import _LRScheduler
23
+ from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
24
+
25
+ import utils
26
+ import noise_schedule
27
+
28
+ LOG2 = math.log(2)
29
+
30
+ class CosineWarmup(_LRScheduler):
31
+ def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
32
+ self.warmup_steps = warmup_steps
33
+ self.total_steps = total_steps
34
+ self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate
35
+ super(CosineWarmup, self).__init__(optimizer, last_epoch)
36
+
37
+ def get_lr(self):
38
+ if self.last_epoch < self.warmup_steps:
39
+ return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
40
+
41
+ progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
42
+ cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
43
+ decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
44
+
45
+ return [decayed_lr * base_lr for base_lr in self.base_lrs]
46
+
47
+
48
+ def _sample_categorical(categorical_probs):
49
+ gumbel_norm = (
50
+ 1e-10
51
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
52
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
53
+
54
+
55
+ def _unsqueeze(x, reference):
56
+ return x.view(
57
+ * x.shape,
58
+ * ((1,) * (len(reference.shape) - len(x.shape))))
59
+
60
+
61
+ @dataclass
62
+ class Loss:
63
+ loss: torch.FloatTensor
64
+ nlls: torch.FloatTensor
65
+ token_mask: torch.FloatTensor
66
+
67
+
68
+ class NLL(torchmetrics.aggregation.MeanMetric):
69
+ pass
70
+
71
+
72
+ class BPD(NLL):
73
+ def compute(self) -> Tensor:
74
+ """Computes the bits per dimension.
75
+
76
+ Returns:
77
+ bpd
78
+ """
79
+ return self.mean_value / self.weight / LOG2
80
+
81
+
82
+ class Perplexity(NLL):
83
+ def compute(self) -> Tensor:
84
+ """Computes the Perplexity.
85
+
86
+ Returns:
87
+ Perplexity
88
+ """
89
+ return torch.exp(self.mean_value / self.weight)
90
+
91
+
92
+ class WrapVanillaESM(nn.Module):
93
+ def __init__(self, bert_model_path):
94
+ super(WrapVanillaESM, self).__init__()
95
+ #self.bert_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
96
+ #self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path).to(self.bert_model_device)
97
+ self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
98
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)
99
+
100
+
101
+ def __call__(self, *args, **kwargs):
102
+ return self.model(*args, **kwargs)
103
+
104
+ def unfreeze_attn_layers(self):
105
+ model_layers = len(self.model.esm.encoder.layer)
106
+
107
+ for i, layer in enumerate(self.model.esm.encoder.layer):
108
+ if i >= model_layers-5: # fine-tune only last n layers
109
+ for module in layer.attention.self.key.modules():
110
+ for param in module.parameters():
111
+ param.requires_grad = True
112
+ for module in layer.attention.self.query.modules():
113
+ for param in module.parameters():
114
+ param.requires_grad = True
115
+ for module in layer.attention.self.value.modules():
116
+ for param in module.parameters():
117
+ param.requires_grad = True
118
+
119
+ def unfreeze_all_layers(self):
120
+ for param in self.model.parameters():
121
+ param.requires_grad = True
122
+
123
+ def forward(self, inputs, sigma, attention_mask):
124
+ logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
125
+ return logits
126
+
127
+ def save_model(self, save_dir):
128
+ self.model.save_pretrained(save_dir)
129
+ self.tokenizer.save_pretrained(save_dir)
130
+
131
+ def load_model(self, load_dir):
132
+ self.model = AutoModel.from_pretrained(load_dir)
133
+ self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
134
+
135
+ class WrapMembraneESM(nn.Module):
136
+ def __init__(self, bert_model_path):
137
+ super(WrapMembraneESM, self).__init__()
138
+ #self.bert_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
139
+ #self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path).to(self.bert_model_device)
140
+ self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
141
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)
142
+
143
+ def __call__(self, *args, **kwargs):
144
+ return self.model(*args, **kwargs)
145
+
146
+ def freeze_model(self):
147
+ for param in self.model.parameters():
148
+ param.requires_grad = False
149
+
150
+ def unfreeze_all_layers(self):
151
+ for param in self.model.parameters():
152
+ param.requires_grad = True
153
+
154
+ def unfreeze_attn_layers(self):
155
+ model_layers = len(self.model.esm.encoder.layer)
156
+
157
+ for i, layer in enumerate(self.model.esm.encoder.layer):
158
+ if i >= model_layers-11: # fine-tune only last n layers
159
+ for module in layer.attention.self.key.modules():
160
+ for param in module.parameters():
161
+ param.requires_grad = True
162
+ for module in layer.attention.self.query.modules():
163
+ for param in module.parameters():
164
+ param.requires_grad = True
165
+ for module in layer.attention.self.value.modules():
166
+ for param in module.parameters():
167
+ param.requires_grad = True
168
+
169
+ def forward(self, inputs, sigma, attention_mask):
170
+ logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
171
+ return logits
172
+
173
+ def save_model(self, save_dir):
174
+ self.model.save_pretrained(save_dir)
175
+ self.tokenizer.save_pretrained(save_dir)
176
+
177
+ def load_model(self, load_dir):
178
+ self.model = AutoModel.from_pretrained(load_dir)
179
+ self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
180
+
181
+ class Diffusion(L.LightningModule):
182
+ def __init__(
183
+ self,
184
+ config,
185
+ tokenizer: transformers.PreTrainedTokenizer):
186
+ super().__init__()
187
+ self.save_hyperparameters()
188
+ self.config = config
189
+
190
+ self.tokenizer = tokenizer
191
+ self.vocab_size = self.tokenizer.vocab_size
192
+ self.sampler = self.config.sampling.predictor
193
+ self.gen_ppl_eval_model_name_or_path = self.config.eval.\
194
+ gen_ppl_eval_model_name_or_path
195
+ self.antithetic_sampling = self.config.training.antithetic_sampling
196
+ self.importance_sampling = self.config.training.importance_sampling
197
+ self.change_of_variables = self.config.training.change_of_variables
198
+ if (not hasattr(self.tokenizer, 'mask_token')
199
+ or self.tokenizer.mask_token is None):
200
+ self.mask_index = self.vocab_size
201
+ self.vocab_size += 1
202
+ else:
203
+ self.mask_index = self.tokenizer.mask_token_id
204
+ self.parameterization = self.config.parameterization
205
+
206
+
207
+ # if self.config.backbone == 'dit':
208
+ # self.backbone = dit.DIT(
209
+ # self.config, vocab_size=self.vocab_size, mlm_model_path=config.training.mlm_model_path)
210
+ if self.config.backbone == "vanilla_esm_pretrain":
211
+ self.backbone = WrapVanillaESM(bert_model_path=self.config.training.esm_model_path)
212
+ self.backbone.unfreeze_all_layers()
213
+ self.backbone = torch.compile(self.backbone)
214
+ elif self.config.backbone == 'membrane_esm_finetune':
215
+ self.backbone = WrapMembraneESM(bert_model_path=self.config.checkpointing.pretrained_esm_mdlm_automodel_path)
216
+ self.backbone.unfreeze_all_layers()
217
+ # self.backbone = torch.compile(self.backbone)
218
+
219
+ # elif self.config.backbone == 'dimamba':
220
+ # self.backbone = dimamba.DiMamba(
221
+ # self.config,
222
+ # vocab_size=self.vocab_size,
223
+ # pad_token_id=self.tokenizer.pad_token_id)
224
+ # elif self.config.backbone == 'ar':
225
+ # self.backbone = autoregressive.AR(
226
+ # self.config,
227
+ # vocab_size=self.vocab_size,
228
+ # mask_index=self.mask_index)
229
+ # elif self.config.backbone == 'hf_dit':
230
+ # self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
231
+ # config.eval.checkpoint_path, trust_remote_code=True)
232
+ # else:
233
+ # raise ValueError(
234
+ # f'Unknown backbone: {self.config.backbone}')
235
+
236
+ self.T = self.config.T
237
+ self.subs_masking = self.config.subs_masking
238
+
239
+ self.softplus = torch.nn.Softplus()
240
+ # metrics are automatically reset at end of epoch
241
+ metrics = torchmetrics.MetricCollection({
242
+ 'nll': NLL(),
243
+ 'bpd': BPD(),
244
+ 'ppl': Perplexity(),
245
+ })
246
+ metrics.set_dtype(torch.float64)
247
+ self.train_metrics = metrics.clone(prefix='train/')
248
+ self.valid_metrics = metrics.clone(prefix='val/')
249
+ self.test_metrics = metrics.clone(prefix='test/')
250
+
251
+ # generative perplexity
252
+ self.gen_ppl_metric = Perplexity()
253
+ self.eval_model_tokenizer = transformers.AutoTokenizer.\
254
+ from_pretrained(self.gen_ppl_eval_model_name_or_path)
255
+ if self.eval_model_tokenizer.pad_token is None:
256
+ self.eval_model_tokenizer.pad_token =\
257
+ self.eval_model_tokenizer.eos_token
258
+ self.eval_model_tokenizer.pad_token_id =\
259
+ self.eval_model_tokenizer.eos_token_id
260
+
261
+ self.noise = noise_schedule.get_noise(self.config,
262
+ dtype=self.dtype)
263
+ if self.config.training.ema > 0:
264
+ self.ema = ema.ExponentialMovingAverage(
265
+ itertools.chain(self.backbone.parameters(),
266
+ self.noise.parameters()),
267
+ decay=self.config.training.ema)
268
+ else:
269
+ self.ema = None
270
+
271
+ self.lr = self.config.optim.lr
272
+ self.sampling_eps = self.config.training.sampling_eps
273
+ self.time_conditioning = self.config.time_conditioning
274
+ self.neg_infinity = -1000000.0
275
+ self.fast_forward_epochs = None
276
+ self.fast_forward_batches = None
277
+ self._validate_configuration()
278
+
279
+ def _validate_configuration(self):
280
+ assert not (self.change_of_variables
281
+ and self.importance_sampling)
282
+ if self.parameterization == 'sedd':
283
+ assert not self.importance_sampling
284
+ assert not self.change_of_variables
285
+ if self.parameterization == 'd3pm':
286
+ assert self.T > 0
287
+ if self.T > 0:
288
+ assert self.parameterization in {'d3pm', 'subs'}
289
+ if self.subs_masking:
290
+ assert self.parameterization == 'd3pm'
291
+
292
+ def on_load_checkpoint(self, checkpoint):
293
+ if self.ema:
294
+ self.ema.load_state_dict(checkpoint['ema'])
295
+ # Copied from:
296
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
297
+ self.fast_forward_epochs = checkpoint['loops'][
298
+ 'fit_loop']['epoch_progress']['current']['completed']
299
+ self.fast_forward_batches = checkpoint['loops'][
300
+ 'fit_loop']['epoch_loop.batch_progress'][
301
+ 'current']['completed']
302
+
303
+ def on_save_checkpoint(self, checkpoint):
304
+ if self.ema:
305
+ checkpoint['ema'] = self.ema.state_dict()
306
+ # Copied from:
307
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
308
+ # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
309
+ # behind, so we're using the optimizer's progress.
310
+ checkpoint['loops']['fit_loop'][
311
+ 'epoch_loop.batch_progress']['total'][
312
+ 'completed'] = checkpoint['loops']['fit_loop'][
313
+ 'epoch_loop.automatic_optimization.optim_progress'][
314
+ 'optimizer']['step']['total'][
315
+ 'completed'] * self.trainer.accumulate_grad_batches
316
+ checkpoint['loops']['fit_loop'][
317
+ 'epoch_loop.batch_progress']['current'][
318
+ 'completed'] = checkpoint['loops']['fit_loop'][
319
+ 'epoch_loop.automatic_optimization.optim_progress'][
320
+ 'optimizer']['step']['current'][
321
+ 'completed'] * self.trainer.accumulate_grad_batches
322
+ # _batches_that_stepped tracks the number of global steps, not the number
323
+ # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
324
+ checkpoint['loops']['fit_loop'][
325
+ 'epoch_loop.state_dict'][
326
+ '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
327
+ 'epoch_loop.automatic_optimization.optim_progress'][
328
+ 'optimizer']['step']['total']['completed']
329
+ if 'sampler' not in checkpoint.keys():
330
+ checkpoint['sampler'] = {}
331
+ if hasattr(self.trainer.train_dataloader.sampler,
332
+ 'state_dict'):
333
+ sampler_state_dict = self.trainer.\
334
+ train_dataloader.sampler.state_dict()
335
+ checkpoint['sampler'][
336
+ 'random_state'] = sampler_state_dict.get(
337
+ 'random_state', None)
338
+ else:
339
+ checkpoint['sampler']['random_state'] = None
340
+
341
+ self.backbone.save_model(self.config.checkpointing.fine_tuned_esm_mdlm_ckpt_path)
342
+
343
+ def on_train_start(self):
344
+ torch.cuda.empty_cache()
345
+ if self.ema:
346
+ self.ema.move_shadow_params_to_device(self.device)
347
+
348
+ # Adapted from:
349
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
350
+ distributed = (
351
+ self.trainer._accelerator_connector.use_distributed_sampler
352
+ and self.trainer._accelerator_connector.is_distributed)
353
+ if distributed:
354
+ sampler_cls = dataloader.FaultTolerantDistributedSampler
355
+ else:
356
+ sampler_cls = dataloader.RandomFaultTolerantSampler
357
+ updated_dls = []
358
+ for dl in self.trainer.fit_loop._combined_loader.flattened:
359
+ if hasattr(dl.sampler, 'shuffle'):
360
+ dl_sampler = sampler_cls(
361
+ dl.dataset, shuffle=dl.sampler.shuffle)
362
+ else:
363
+ dl_sampler = sampler_cls(dl.dataset)
364
+ if (distributed
365
+ and self.fast_forward_epochs is not None
366
+ and self.fast_forward_batches is not None):
367
+ dl_sampler.load_state_dict({
368
+ 'epoch': self.fast_forward_epochs,
369
+ 'counter': (self.fast_forward_batches
370
+ * self.config.loader.batch_size)})
371
+
372
+ from functools import partial
373
+ from pl_data_loader import collate_fn
374
+ collate_partial = partial(collate_fn, tokenizer=self.tokenizer)
375
+ torch.cuda.empty_cache()
376
+
377
+ updated_dls.append(
378
+ torch.utils.data.DataLoader(
379
+ dl.dataset,
380
+ batch_size=self.config.loader.batch_size,
381
+ num_workers=self.config.loader.num_workers,
382
+ pin_memory=self.config.loader.pin_memory,
383
+ sampler=dl_sampler,
384
+ shuffle=False,
385
+ persistent_workers=False,
386
+ collate_fn=collate_partial))
387
+ self.trainer.fit_loop._combined_loader.flattened = updated_dls
388
+
389
+ def optimizer_step(self, *args, **kwargs):
390
+ super().optimizer_step(*args, **kwargs)
391
+
392
+ gc.collect()
393
+ torch.cuda.empty_cache()
394
+
395
+ if self.ema:
396
+ self.ema.update(itertools.chain(
397
+ self.backbone.parameters(),
398
+ self.noise.parameters()))
399
+
400
+ # optimizer_closure = kwargs.get('optimizer_closure', None)
401
+
402
+ # params_with_grad = [p for p in itertools.chain(
403
+ # self.backbone.parameters(),
404
+ # self.noise.parameters()
405
+ # ) if p.requires_grad and p.grad_fn is not None]
406
+
407
+ # # if params_with_grad:
408
+ # # super().optimizer_step(closure=optimizer_closure)
409
+
410
+ # if self.ema:
411
+ # self.ema.update(params_with_grad)
412
+
413
+ # super().optimizer_step(*args, **kwargs)
414
+
415
+ def _subs_parameterization(self, logits, xt):
416
+ # log prob at the mask index = - infinity
417
+ logits = logits.logits
418
+ logits[:, :, self.mask_index] += self.neg_infinity
419
+ # logits[:, :, self.tokenizer.eos_token_id] += self.neg_infinity
420
+ # logits[:, :, self.tokenizer.cls_token_id] += self.neg_infinity
421
+
422
+ # Normalize the logits such that x.exp() is
423
+ # a probability distribution over vocab_size.
424
+ logits = logits - torch.logsumexp(logits, dim=-1,
425
+ keepdim=True)
426
+
427
+ # Apply updates directly in the logits matrix.
428
+ # For the logits of the unmasked tokens, set all values
429
+ # to -infinity except for the indices corresponding to
430
+ # the unmasked tokens.
431
+ unmasked_indices = (xt != self.mask_index)
432
+ logits[unmasked_indices] = self.neg_infinity
433
+ logits[unmasked_indices, xt[unmasked_indices]] = 0
434
+ return logits
435
+
436
+ def _d3pm_parameterization(self, logits):
437
+ if self.subs_masking:
438
+ logits[:, :, self.mask_index] += self.neg_infinity
439
+ logits = logits - torch.logsumexp(logits, dim=-1,
440
+ keepdim=True)
441
+ return logits
442
+
443
+ def _sedd_parameterization(self, logits, xt, sigma):
444
+ esigm1_log = torch.where(
445
+ sigma < 0.5,
446
+ torch.expm1(sigma),
447
+ sigma.exp() - 1).log().to(logits.dtype)
448
+ # logits shape
449
+ # (batch_size, diffusion_model_input_length, vocab_size)
450
+ logits = logits - esigm1_log[:, None, None] - np.log(
451
+ logits.shape[-1] - 1)
452
+ # The below scatter operation sets the log score
453
+ # for the input word to 0.
454
+ logits = torch.scatter(logits, -1, xt[..., None],
455
+ torch.zeros_like(logits[..., :1]))
456
+ return logits
457
+
458
+ def _process_sigma(self, sigma):
459
+ if sigma is None:
460
+ assert self.parameterization == 'ar'
461
+ return sigma
462
+ if sigma.ndim > 1:
463
+ sigma = sigma.squeeze(-1)
464
+ if not self.time_conditioning:
465
+ sigma = torch.zeros_like(sigma)
466
+ assert sigma.ndim == 1, sigma.shape
467
+ return sigma
468
+
469
+ def forward(self, x, sigma, attention_mask, print_logits=False):
470
+ """Returns log score."""
471
+ sigma = self._process_sigma(sigma)
472
+ with torch.amp.autocast("cuda", dtype=torch.float32):
473
+ logits = self.backbone(x, attention_mask)
474
+ # if print_logits:
475
+ # torch.set_printoptions(profile="full")
476
+ # print(logits)
477
+ # torch.set_printoptions(profile="default")
478
+ if self.parameterization == 'subs':
479
+ return self._subs_parameterization(logits=logits, xt=x)
480
+ return logits
481
+
482
+ def _d3pm_loss(self, model_output, xt, x0, t, attention_mask):
483
+ dt = 1 / self.T
484
+
485
+ if torch.is_tensor(t):
486
+ t = t[:, None]
487
+ assert t.ndim == 2
488
+ t = t.clamp(0., 1. - 1e-4)
489
+ alpha_t = 1 - t + torch.zeros_like(xt)
490
+ alpha_s = 1 - (t - dt) + torch.zeros_like(xt)
491
+
492
+ log_x_theta_at_x0 = torch.gather(
493
+ model_output, -1, x0[:, :, None]).squeeze(-1)
494
+ log_x_theta_at_m = model_output[:, :, self.mask_index]
495
+ x_theta_at_m = log_x_theta_at_m.exp()
496
+
497
+ term_1_coef = dt / t
498
+ term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
499
+ term_1_log_dr = log_x_theta_at_x0
500
+
501
+ term_2_coef = 1 - dt / t
502
+ term_2_log_nr = term_1_log_nr
503
+ term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)
504
+
505
+ L_vb_masked = (
506
+ term_1_coef * (term_1_log_nr - term_1_log_dr)
507
+ + term_2_coef * (term_2_log_nr - term_2_log_dr))
508
+
509
+ L_vb = L_vb_masked * (xt == self.mask_index)
510
+
511
+ return self.T * L_vb
512
+
513
+ def _compute_loss(self, batch, prefix):
514
+ if 'attention_mask' in batch:
515
+ attention_mask = batch['attention_mask']
516
+ else:
517
+ attention_mask = None
518
+ if 'mask' in batch: mask = batch['mask']
519
+ else: mask = None
520
+
521
+ losses = self._loss(batch['input_ids'], attention_mask, mask)
522
+ loss = losses.loss
523
+
524
+ if prefix == 'train':
525
+ self.train_metrics.update(losses.nlls, losses.token_mask)
526
+ metrics = self.train_metrics
527
+ elif prefix == 'val':
528
+ self.valid_metrics.update(losses.nlls, losses.token_mask)
529
+ metrics = self.valid_metrics
530
+ elif prefix == 'test':
531
+ self.test_metrics.update(losses.nlls, losses.token_mask)
532
+ metrics = self.test_metrics
533
+ else:
534
+ raise ValueError(f'Invalid prefix: {prefix}')
535
+
536
+ self.log_dict(metrics,
537
+ on_step=False,
538
+ on_epoch=True,
539
+ sync_dist=True)
540
+ return loss
541
+
542
+ def on_train_epoch_start(self):
543
+ self.backbone.train()
544
+ self.noise.train()
545
+
546
+ def training_step(self, batch, batch_idx):
547
+ # Initialize throughput calculation
548
+ start_time = time.time()
549
+
550
+ loss = self._compute_loss(batch, prefix='train')
551
+ self.log(name='trainer/loss',
552
+ value=loss.item(),
553
+ on_step=True,
554
+ on_epoch=False,
555
+ sync_dist=True)
556
+
557
+ # Calculate throughput
558
+ elapsed_time = time.time() - start_time
559
+ total_tokens = batch['input_ids'].numel()
560
+ throughput = total_tokens / elapsed_time
561
+
562
+ self.log(name='trainer/throughput',
563
+ value=throughput,
564
+ on_step=True,
565
+ on_epoch=False,
566
+ sync_dist=True)
567
+
568
+ return loss
569
+
570
+ def on_validation_epoch_start(self):
571
+ # params_with_grad = [p for p in itertools.chain(
572
+ # self.backbone.parameters(),
573
+ # self.noise.parameters()
574
+ # ) if p.requires_grad]
575
+ # if self.ema:
576
+ # self.ema.store(params_with_grad)
577
+ # self.ema.copy_to(params_with_grad)
578
+
579
+ gc.collect()
580
+ torch.cuda.empty_cache()
581
+ if self.ema:
582
+ self.ema.store(
583
+ itertools.chain(
584
+ self.backbone.parameters(),
585
+ self.noise.parameters()))
586
+ self.ema.copy_to(itertools.chain(
587
+ self.backbone.parameters(),
588
+ self.noise.parameters()))
589
+ self.backbone.eval()
590
+ self.noise.eval()
591
+ assert self.valid_metrics.nll.mean_value == 0
592
+ assert self.valid_metrics.nll.weight == 0
593
+
594
+
595
+ def validation_step(self, batch, batch_idx):
596
+ loss = self._compute_loss(batch, prefix='val')
597
+ self.log(name='trainer/val_loss',
598
+ value=loss.item(),
599
+ on_step=True,
600
+ on_epoch=False,
601
+ prog_bar=True,
602
+ sync_dist=True)
603
+ return loss
604
+
605
+ def on_validation_epoch_end(self):
606
+ # params_with_grad = [p for p in itertools.chain(
607
+ # self.backbone.parameters(),
608
+ # self.noise.parameters()
609
+ # ) if p.requires_grad]
610
+ # if ((self.config.eval.compute_perplexity_on_sanity
611
+ # or not self.trainer.sanity_checking)
612
+ # and self.config.eval.generate_samples
613
+ # and not self.parameterization == 'ar'):
614
+ # # (justin): implement sampling and kv cache for AR
615
+ # samples, text_samples = None, None
616
+ # for _ in range(
617
+ # self.config.sampling.num_sample_batches):
618
+ # samples = self._sample()
619
+ # # Decode the samples to be re-tokenized by eval model
620
+ # text_samples = self.tokenizer.batch_decode(samples)
621
+ # if self.config.eval.compute_generative_perplexity:
622
+ # self.compute_generative_perplexity(text_samples)
623
+ # if self.trainer.global_rank == 0 and hasattr(
624
+ # self.trainer.logger, 'log_table'):
625
+ # # Log the last generated samples
626
+ # text_samples = text_samples[
627
+ # : self.config.sampling.num_sample_log]
628
+ # self.trainer.logger.log_table(
629
+ # key=f'samples@global_step{self.global_step}',
630
+ # columns=['Generated Samples'],
631
+ # data=[[s] for s in text_samples])
632
+ # if self.config.eval.compute_generative_perplexity:
633
+ # self.log('val/gen_ppl',
634
+ # self.gen_ppl_metric,
635
+ # on_epoch=True,
636
+ # on_step=False,
637
+ # sync_dist=True)
638
+
639
+ gc.collect()
640
+ torch.cuda.empty_cache()
641
+ if self.ema:
642
+ self.ema.restore(
643
+ itertools.chain(
644
+ self.backbone.parameters(),
645
+ self.noise.parameters()))
646
+
647
+ def test_step(self, batch, batch_idx):
648
+ loss = self._compute_loss(batch, prefix='test')
649
+ self.log('test/loss',
650
+ value=loss.item(),
651
+ on_step=False,
652
+ on_epoch=True,
653
+ sync_dist=True)
654
+
655
+ if self.config.eval.compute_generative_perplexity:
656
+ samples, text_samples = None, None
657
+ for _ in range(
658
+ self.config.sampling.num_sample_batches):
659
+ samples = self._sample()
660
+ # Decode the samples to be re-tokenized by eval model
661
+ text_samples = self.tokenizer.batch_decode(samples)
662
+ if self.config.eval.compute_generative_perplexity:
663
+ self.compute_generative_perplexity(text_samples)
664
+ if self.trainer.global_rank == 0 and hasattr(
665
+ self.trainer.logger, 'log_table'):
666
+ # Log the last generated samples
667
+ text_samples = text_samples[
668
+ : self.config.sampling.num_sample_log]
669
+ self.trainer.logger.log_table(
670
+ key=f'samples@global_step{self.global_step}',
671
+ columns=['Generated Samples'],
672
+ data=[[s] for s in text_samples])
673
+ if self.config.eval.compute_generative_perplexity:
674
+ self.log('test/gen_ppl',
675
+ self.gen_ppl_metric,
676
+ on_epoch=False,
677
+ on_step=True,
678
+ sync_dist=True)
679
+
680
+ def on_test_epoch_start(self):
681
+ # params_with_grad = [p for p in itertools.chain(
682
+ # self.backbone.parameters(),
683
+ # self.noise.parameters()
684
+ # ) if p.requires_grad]
685
+
686
+ if self.ema:
687
+ self.ema.store(itertools.chain(
688
+ self.backbone.parameters(),
689
+ self.noise.parameters()))
690
+ self.ema.copy_to(itertools.chain(
691
+ self.backbone.parameters(),
692
+ self.noise.parameters()))
693
+
694
+ self.backbone.eval()
695
+ self.noise.eval()
696
+ self.test_metrics.reset()
697
+
698
+ def on_test_epoch_end(self):
699
+ # params_with_grad = [p for p in itertools.chain(
700
+ # self.backbone.parameters(),
701
+ # self.noise.parameters()
702
+ # ) if p.requires_grad]
703
+
704
+ if self.ema:
705
+ self.ema.restore(itertools.chain(
706
+ self.backbone.parameters(),
707
+ self.noise.parameters()))
708
+
709
+ for metric_name, metric_value in self.test_metrics.compute().items():
710
+ self.log(metric_name, metric_value, sync_dist=True)
711
+
712
+ def configure_optimizers(self):
713
+ # (yair): Lightning currently giving this warning when using `fp16`:
714
+ # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
715
+ # Not clear if this is a problem or not.
716
+ # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
717
+
718
+ # params_with_grad = [p for p in itertools.chain(
719
+ # self.backbone.parameters(),
720
+ # self.noise.parameters()
721
+ # ) if p.requires_grad]
722
+
723
+ optimizer = torch.optim.AdamW(
724
+ itertools.chain(self.backbone.parameters(),
725
+ self.noise.parameters()),
726
+ lr=self.config.optim.lr,
727
+ betas=(self.config.optim.beta1,
728
+ self.config.optim.beta2),
729
+ eps=self.config.optim.eps,
730
+ weight_decay=self.config.optim.weight_decay
731
+ )
732
+
733
+ # scheduler = hydra.utils.instantiate(
734
+ # self.config.lr_scheduler, optimizer=optimizer)
735
+ # scheduler_dict = {
736
+ # 'scheduler': scheduler,
737
+ # 'interval': 'step',
738
+ # 'monitor': 'val/loss',
739
+ # 'name': 'trainer/lr',
740
+ # }
741
+
742
+ self.total_steps = self.config.trainer.max_steps
743
+ scheduler = CosineWarmup(optimizer,
744
+ warmup_steps=self.config.lr_scheduler.num_warmup_steps,
745
+ total_steps=self.total_steps)
746
+
747
+ scheduler_dict = {
748
+ 'scheduler': scheduler,
749
+ 'interval': 'step',
750
+ 'frequency': 1,
751
+ 'monitor': 'val/loss',
752
+ 'name': 'trainer/lr'
753
+ }
754
+
755
+ return [optimizer], [scheduler_dict]
756
+
757
+ @torch.no_grad()
758
+ def eval_retokenize(self, text_samples, max_length):
759
+ """Retokenizes samples for the eval model.
760
+
761
+ Args:
762
+ text_samples: List of sentences generated by the model.
763
+ Returns:
764
+ samples: Samples re-tokenized for the eval model
765
+ attn_mask: Attention mask for the eval model
766
+ eval_context_size: Size of the context for the eval model
767
+ """
768
+ if 'llama2' in self.gen_ppl_eval_model_name_or_path:
769
+ tokenizer_kwargs = {
770
+ 'text_samples': text_samples,
771
+ 'return_tensors': 'pt',
772
+ 'return_token_type_ids': False,
773
+ 'return_attention_mask': True,
774
+ 'truncation': True,
775
+ 'padding': True,
776
+ 'max_length': max_length,
777
+ }
778
+ eval_context_size = 4096
779
+ else:
780
+ tokenizer_kwargs = {
781
+ 'return_tensors': 'pt',
782
+ 'return_token_type_ids': False,
783
+ 'return_attention_mask': True,
784
+ 'truncation': True,
785
+ 'padding': True,
786
+ 'max_length': max_length,
787
+ }
788
+ eval_context_size = 1024
789
+ samples = self.eval_model_tokenizer(
790
+ text_samples, ** tokenizer_kwargs)
791
+ attn_mask = samples['attention_mask']
792
+ samples = samples['input_ids']
793
+ if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
794
+ attn_mask = attn_mask.to(self.device)
795
+ samples = samples.to(self.device)
796
+ return samples, attn_mask, eval_context_size
797
+
798
+ # @torch.no_grad()
799
+ # def compute_generative_perplexity(
800
+ # self,
801
+ # text_samples: typing.List[str],
802
+ # retokenize: bool = True,
803
+ # max_length: typing.Optional[int] = None) -> None:
804
+ # """Compute the generative perplexity of the model.
805
+
806
+ # Args:
807
+ # text_samples: List of sentences generated by the model.
808
+
809
+ # Returns:
810
+ # Perplexity of the generated text under a different
811
+ # pre-trained AR model (e.g., GPT2).
812
+ # """
813
+ # os.environ['TOKENIZERS_PARALLELISM'] = 'false'
814
+ # eval_model = transformers.AutoModelForCausalLM.from_pretrained(
815
+ # self.gen_ppl_eval_model_name_or_path).eval()
816
+ # if max_length is None:
817
+ # max_length = self.config.model.length
818
+ # if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
819
+ # eval_model = eval_model.to(self.device)
820
+ # # Re-tokenize using eval model's tokenizer
821
+ # if retokenize:
822
+ # (samples, attn_mask,
823
+ # eval_context_size) = self.eval_retokenize(
824
+ # text_samples, max_length=max_length)
825
+ # else:
826
+ # samples = text_samples
827
+ # attn_mask = torch.ones(samples.shape).to(self.device)
828
+ # eval_context_size = samples.shape[-1]
829
+ # batch_size = min(
830
+ # self.config.eval.perplexity_batch_size,
831
+ # samples.shape[0])
832
+ # num_batches = samples.shape[0] // batch_size
833
+ # for i in range(num_batches):
834
+ # _samples = torch.split(
835
+ # samples[i * batch_size: (i + 1) * batch_size],
836
+ # eval_context_size,
837
+ # dim=-1)
838
+ # _attn_mask = torch.split(
839
+ # attn_mask[i * batch_size: (i + 1) * batch_size],
840
+ # eval_context_size,
841
+ # dim=-1)
842
+ # for (sample_chunk, attn_mask_chunk) in zip(
843
+ # _samples, _attn_mask):
844
+ # logits = eval_model(
845
+ # sample_chunk, attention_mask=attn_mask_chunk)[0]
846
+ # logits = logits.transpose(-1, -2)
847
+
848
+ # nlls = F.cross_entropy(logits[..., :-1],
849
+ # sample_chunk[..., 1:],
850
+ # reduction='none')
851
+ # first_eos = (sample_chunk == self.eval_model_tokenizer\
852
+ # .eos_token_id).cumsum(-1) == 1
853
+ # token_mask = (
854
+ # sample_chunk
855
+ # != self.eval_model_tokenizer.eos_token_id)
856
+ # self.gen_ppl_metric.update(
857
+ # nlls, first_eos[..., 1:] + token_mask[..., 1:])
858
+
859
+
860
+ @torch.no_grad()
861
+ def compute_masked_perplexity(self, sequences, masked):
862
+ """Compute the pseudo-perplexity of the generated protein sequences."""
863
+ total_nll = 0
864
+ total_tokens = 0
865
+
866
+ for sequence in sequences:
867
+ # Tokenize the sequence
868
+ input_ids = self.tokenizer(masked, return_tensors="pt").input_ids.to(self.device)
869
+ gt_ids = self.tokenizer(sequence.upper(), return_tensors="pt").input_ids.to(self.device)
870
+
871
+ # print(input_ids.shape)
872
+ # print(gt_ids.shape)
873
+
874
+ # Forward pass through the ESM model
875
+ attention_mask = torch.ones_like(input_ids)
876
+ if self.config.mode in ['train', 'ppl_eval']:
877
+ outputs = self.backbone.model.forward(input_ids=input_ids, attention_mask=attention_mask)
878
+ elif self.config.mode == "sample_eval":
879
+ outputs = self.backbone.model.forward(input_ids)
880
+ logits = outputs[-1] # B, L, V
881
+
882
+ # Compute loss
883
+ # shift_logits = logits[:, :-1, :].contiguous() # remove eos
884
+ # shift_labels = input_ids[:, 1:].contiguous()
885
+ # print(masked)
886
+ # print(gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1))
887
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
888
+ gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1),
889
+ reduction='sum')
890
+
891
+ total_nll += loss.item()
892
+ #total_tokens += (input_ids != self.tokenizer.pad_token_id).sum().item() - 1 # -1 for the first token
893
+ total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # count in bos and eos
894
+ # Compute pseudo-perplexity
895
+ # print(total_nll, ",;,", total_tokens)
896
+ pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
897
+ self.gen_ppl_metric.update(pseudo_perplexity)
898
+
899
+ return pseudo_perplexity.item()
900
+
901
+ @torch.no_grad()
902
+ def compute_generative_perplexity(
903
+ self,
904
+ text_samples: typing.List[str],
905
+ retokenize: bool = True,
906
+ max_length: typing.Optional[int] = None) -> None:
907
+ """Compute the generative perplexity of the model.
908
+
909
+ Args:
910
+ text_samples: List of sentences generated by the model.
911
+
912
+ Returns:
913
+ Perplexity of the generated text under a different
914
+ pre-trained AR model (e.g., GPT2).
915
+ """
916
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
917
+ eval_model = transformers.AutoModelForCausalLM.from_pretrained(
918
+ self.gen_ppl_eval_model_name_or_path).eval()
919
+ if max_length is None:
920
+ max_length = self.config.model.length
921
+ if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
922
+ eval_model = eval_model.to(self.device)
923
+ # Re-tokenize using eval model's tokenizer
924
+ if retokenize:
925
+ (samples, attn_mask,
926
+ eval_context_size) = self.eval_retokenize(
927
+ text_samples, max_length=max_length)
928
+ else:
929
+ samples = text_samples
930
+ attn_mask = torch.ones(samples.shape).to(self.device)
931
+ eval_context_size = samples.shape[-1]
932
+ batch_size = min(
933
+ self.config.eval.perplexity_batch_size,
934
+ samples.shape[0])
935
+ num_batches = samples.shape[0] // batch_size
936
+ for i in range(num_batches):
937
+ _samples = torch.split(
938
+ samples[i * batch_size: (i + 1) * batch_size],
939
+ eval_context_size,
940
+ dim=-1)
941
+ _attn_mask = torch.split(
942
+ attn_mask[i * batch_size: (i + 1) * batch_size],
943
+ eval_context_size,
944
+ dim=-1)
945
+ for (sample_chunk, attn_mask_chunk) in zip(
946
+ _samples, _attn_mask):
947
+ logits = eval_model(
948
+ sample_chunk, attention_mask=attn_mask_chunk)[0]
949
+ logits = logits.transpose(-1, -2)
950
+
951
+ nlls = F.cross_entropy(logits[..., :-1],
952
+ sample_chunk[..., 1:],
953
+ reduction='none')
954
+ first_eos = (sample_chunk == self.eval_model_tokenizer\
955
+ .eos_token_id).cumsum(-1) == 1
956
+ token_mask = (
957
+ sample_chunk
958
+ != self.eval_model_tokenizer.eos_token_id)
959
+ self.gen_ppl_metric.update(
960
+ nlls, first_eos[..., 1:] + token_mask[..., 1:])
961
+
962
+ def q_xt(self, x, move_chance):
963
+ """Computes the noisy sample xt.
964
+
965
+ Args:
966
+ x: int torch.Tensor with shape (batch_size,
967
+ diffusion_model_input_length), input.
968
+ move_chance: float torch.Tensor with shape (batch_size, 1).
969
+ """
970
+
971
+ actual_seq_length = (x != 1).sum(dim=1, keepdim=True)
972
+
973
+ max_mask_length = (actual_seq_length * 0.75).long()
974
+
975
+ move_indices = torch.rand(*x.shape, device=x.device) < move_chance
976
+
977
+ restricted_move_indices = torch.zeros_like(move_indices, dtype=torch.bool)
978
+
979
+ for i in range(x.shape[0]):
980
+ true_positions = torch.where(move_indices[i])[0]
981
+ if len(true_positions) > max_mask_length[i]:
982
+ selected_positions = true_positions[:max_mask_length[i].item()]
983
+ restricted_move_indices[i, selected_positions] = True
984
+ else:
985
+ restricted_move_indices[i] = move_indices[i]
986
+ xt = torch.where(restricted_move_indices, self.mask_index, x)
987
+
988
+ return xt
989
+
990
+ def _sample_prior(self, *batch_dims):
991
+ return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64)
992
+
993
+ def _ddpm_caching_update(self, x, t, dt, p_x0=None, attention_mask=None):
994
+ assert self.config.noise.type == 'loglinear'
995
+ sigma_t, _ = self.noise(t)
996
+ if t.ndim > 1:
997
+ t = t.squeeze(-1)
998
+ assert t.ndim == 1
999
+ move_chance_t = t[:, None, None]
1000
+ move_chance_s = (t - dt)[:, None, None]
1001
+ assert move_chance_t.ndim == 3, move_chance_t.shape
1002
+ if p_x0 is None:
1003
+ p_x0 = self.forward(x, sigma_t, attention_mask).exp()
1004
+
1005
+ assert move_chance_t.ndim == p_x0.ndim
1006
+ q_xs = p_x0 * (move_chance_t - move_chance_s)
1007
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
1008
+ _x = _sample_categorical(q_xs)
1009
+
1010
+ copy_flag = (x != self.mask_index).to(x.dtype)
1011
+ return p_x0, copy_flag * x + (1 - copy_flag) * _x
1012
+
1013
+ def _ddpm_update(self, x, t, dt, attention_mask):
1014
+ sigma_t, _ = self.noise(t)
1015
+ sigma_s, _ = self.noise(t - dt)
1016
+ if sigma_t.ndim > 1:
1017
+ sigma_t = sigma_t.squeeze(-1)
1018
+ if sigma_s.ndim > 1:
1019
+ sigma_s = sigma_s.squeeze(-1)
1020
+ assert sigma_t.ndim == 1, sigma_t.shape
1021
+ assert sigma_s.ndim == 1, sigma_s.shape
1022
+ move_chance_t = 1 - torch.exp(-sigma_t)
1023
+ move_chance_s = 1 - torch.exp(-sigma_s)
1024
+ move_chance_t = move_chance_t[:, None, None]
1025
+ move_chance_s = move_chance_s[:, None, None]
1026
+ unet_conditioning = sigma_t
1027
+ log_p_x0 = self.forward(x, unet_conditioning, attention_mask)
1028
+ assert move_chance_t.ndim == log_p_x0.ndim
1029
+ # Technically, this isn't q_xs since there's a division
1030
+ # term that is missing. This division term doesn't affect
1031
+ # the samples.
1032
+ q_xs = log_p_x0.exp() * (move_chance_t
1033
+ - move_chance_s)
1034
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
1035
+ _x = _sample_categorical(q_xs)
1036
+
1037
+ copy_flag = (x != self.mask_index).to(x.dtype)
1038
+ return copy_flag * x + (1 - copy_flag) * _x
1039
+
1040
+ def _ar_sampler(self, bsz):
1041
+ # precompute token buffer
1042
+ num_pred_tokens = self.config.model.length - 1
1043
+ x = torch.zeros(
1044
+ (bsz, num_pred_tokens + 1),
1045
+ dtype=torch.long,
1046
+ device=self.device)
1047
+ x[:, 0] = self.tokenizer.bos_token_id
1048
+ # precompute noise
1049
+ noise = (torch.distributions.Gumbel(0, 1)
1050
+ .sample((bsz, num_pred_tokens, self.vocab_size))
1051
+ .to(self.device))
1052
+ for i in range(num_pred_tokens):
1053
+ next_logits = self.forward(x[:, :i + 1], None)[:, -1]
1054
+ y = (next_logits + noise[:, i]).argmax(-1)
1055
+ x[:, i + 1] = y
1056
+ return x
1057
+
1058
+ @torch.no_grad()
1059
+ def _sample(self, num_steps=None, eps=1e-5, x_input = None):
1060
+ """Generate samples from the model."""
1061
+ batch_size_per_gpu = self.config.eval.perplexity_batch_size
1062
+ if self.parameterization == 'ar':
1063
+ return self._ar_sampler(batch_size_per_gpu)
1064
+ # Lightning auto-casting is not working in this method for some reason
1065
+ if num_steps is None:
1066
+ num_steps = self.config.sampling.steps
1067
+ if x_input is not None:
1068
+ x = x_input.input_ids
1069
+ attention_mask = x_input.attention_mask
1070
+ else:
1071
+ x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
1072
+ attention_mask = torch.ones_like(x)
1073
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
1074
+ dt = (1 - eps) / num_steps
1075
+ p_x0_cache = None
1076
+
1077
+ for i in range(num_steps):
1078
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
1079
+ if self.sampler == 'ddpm':
1080
+ x = self._ddpm_update(x, t, dt)
1081
+ elif self.sampler == 'ddpm_cache':
1082
+ p_x0_cache, x_next = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache, attention_mask=attention_mask)
1083
+ if (not torch.allclose(x_next, x) or self.time_conditioning):
1084
+ # Disable caching
1085
+ p_x0_cache = None
1086
+ x = x_next
1087
+ # print(self.tokenizer.decode(x.squeeze()))
1088
+ else:
1089
+ x = self._analytic_update(x, t, dt, attention_mask)
1090
+
1091
+ if self.config.sampling.noise_removal:
1092
+ t = timesteps[-1] * torch.ones(x.shape[0], 1,
1093
+ device=self.device)
1094
+ if self.sampler == 'analytic':
1095
+ x = self._denoiser_update(x, t)
1096
+ else:
1097
+ unet_conditioning = self.noise(t)[0]
1098
+ x = self.forward(x, unet_conditioning, attention_mask, print_logits=True).argmax(dim=-1)
1099
+ # print(self.tokenizer.decode(x.squeeze()))
1100
+ return x
1101
+
1102
+ def restore_model_and_sample(self, num_steps, eps=1e-5):
1103
+ """Generate samples from the model."""
1104
+ # Lightning auto-casting is not working in this method for some reason
1105
+ # params_with_grad = [p for p in itertools.chain(
1106
+ # self.backbone.parameters(),
1107
+ # self.noise.parameters()
1108
+ # ) if p.requires_grad]
1109
+
1110
+ if self.ema:
1111
+ self.ema.store(itertools.chain(self.backbone.parameters(),
1112
+ self.noise.parameters()))
1113
+ self.ema.copy_to(itertools.chain(self.backbone.parameters(),
1114
+ self.noise.parameters()))
1115
+ self.backbone.eval()
1116
+ self.noise.eval()
1117
+ samples = self._sample(num_steps=num_steps, eps=eps)
1118
+ if self.ema:
1119
+ self.ema.restore(itertools.chain(self.backbone.parameters(),
1120
+ self.noise.parameters()))
1121
+ self.backbone.train()
1122
+ self.noise.train()
1123
+ return samples
1124
+
1125
+ def get_score(self, x, sigma, attention_mask=None):
1126
+ model_output = self.forward(x, sigma, attention_mask)
1127
+ if self.parameterization == 'subs':
1128
+ # score(x, t) = p_t(y) / p_t(x)
1129
+ # => log score(x, t) = log p_t(y) - log p_t(x)
1130
+
1131
+ # case 1: x = masked
1132
+ # (i) y = unmasked
1133
+ # log score(x, t) = log p_\theta(x)|_y + log k
1134
+ # where k = exp(- sigma) / (1 - exp(- sigma))
1135
+ # (ii) y = masked
1136
+ # log score(x, t) = 0
1137
+
1138
+ # case 2: x = unmasked
1139
+ # (i) y != masked, y != x
1140
+ # log score(x_i, t) = - inf
1141
+ # (ii) y = x
1142
+ # log score(x_i, t) = 0
1143
+ # (iii) y = masked token
1144
+ # log score(x_i, t) = - log k
1145
+ # where k = exp(- sigma) / (1 - exp(- sigma))
1146
+
1147
+ log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
1148
+ assert log_k.ndim == 1
1149
+
1150
+ masked_score = model_output + log_k[:, None, None]
1151
+ masked_score[:, :, self.mask_index] = 0
1152
+
1153
+ unmasked_score = self.neg_infinity * torch.ones_like(
1154
+ model_output)
1155
+ unmasked_score = torch.scatter(
1156
+ unmasked_score,
1157
+ -1,
1158
+ x[..., None],
1159
+ torch.zeros_like(unmasked_score[..., :1]))
1160
+ unmasked_score[:, :, self.mask_index] = - (
1161
+ log_k[:, None] * torch.ones_like(x))
1162
+
1163
+ masked_indices = (x == self.mask_index).to(
1164
+ model_output.dtype)[:, :, None]
1165
+ model_output = (
1166
+ masked_score * masked_indices
1167
+ + unmasked_score * (1 - masked_indices))
1168
+ return model_output.exp()
1169
+
1170
+ def _staggered_score(self, score, dsigma):
1171
+ score = score.clone()
1172
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
1173
+ score *= dsigma.exp()[:, None]
1174
+ score[..., self.mask_index] += extra_const
1175
+ return score
1176
+
1177
+ def _analytic_update(self, x, t, step_size, attention_mask=None):
1178
+ curr_sigma, _ = self.noise(t)
1179
+ next_sigma, _ = self.noise(t - step_size)
1180
+ dsigma = curr_sigma - next_sigma
1181
+ score = self.get_score(x, curr_sigma, attention_mask)
1182
+ stag_score = self._staggered_score(score, dsigma)
1183
+ probs = stag_score * self._transp_transition(x, dsigma)
1184
+ return _sample_categorical(probs)
1185
+
1186
+ def _denoiser_update(self, x, t):
1187
+ sigma, _ = self.noise(t)
1188
+ score = self.get_score(x, sigma)
1189
+ stag_score = self._staggered_score(score, sigma)
1190
+ probs = stag_score * self._transp_transition(x, sigma)
1191
+ probs[..., self.mask_index] = 0
1192
+ samples = _sample_categorical(probs)
1193
+ return samples
1194
+
1195
+ def _transp_transition(self, i, sigma):
1196
+ sigma = _unsqueeze(sigma, reference=i[..., None])
1197
+ edge = torch.exp(-sigma) * F.one_hot(
1198
+ i, num_classes=self.vocab_size)
1199
+ edge += torch.where(i == self.mask_index,
1200
+ 1 - torch.exp(-sigma).squeeze(-1),
1201
+ 0)[..., None]
1202
+ return edge
1203
+
1204
+ def _sample_t(self, n, device):
1205
+ _eps_t = torch.rand(n, device=device)
1206
+ if self.antithetic_sampling:
1207
+ offset = torch.arange(n, device=device) / n
1208
+ _eps_t = (_eps_t / n + offset) % 1
1209
+ t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
1210
+ if self.importance_sampling:
1211
+ return self.noise.importance_sampling_transformation(t)
1212
+ return t
1213
+
1214
+ def _maybe_sub_sample(self, x0, attention_mask):
1215
+ # seqlen = x0.shape[1]
1216
+ # if seqlen > self.config.model.length:
1217
+ # assert seqlen == 2 * self.config.model.length
1218
+ # # cropping is needed for text8-crop dataset
1219
+ # # try the same starting point for now
1220
+ # start = np.random.choice(self.config.model.length)
1221
+ # end = start + self.config.model.length
1222
+ # input_tokens = x0[:, start: end]
1223
+ # output_tokens = x0[:, start + 1: end + 1]
1224
+ # new_attention_mask = attention_mask[:, start: end]
1225
+
1226
+ # # Helps with validation PPL, since the val
1227
+ # # examples will all start and end with BOS/EOS
1228
+ # input_tokens[:, 0] = self.tokenizer.bos_token_id
1229
+ # output_tokens[:, -1] = self.tokenizer.eos_token_id
1230
+ # elif self.parameterization == 'ar':
1231
+ # input_tokens = x0[:, :-1]
1232
+ # output_tokens = x0[:, 1:]
1233
+ # new_attention_mask = attention_mask[:, 1:]
1234
+ # else:
1235
+ input_tokens = x0
1236
+ output_tokens = None
1237
+ new_attention_mask = attention_mask
1238
+ return input_tokens, output_tokens, new_attention_mask
1239
+
1240
+ def _reconstruction_loss(self, x0, attention_mask):
1241
+ t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
1242
+ device=self.device)
1243
+ assert self.config.noise.type == 'loglinear'
1244
+ # The above assert is for d3pm parameterization
1245
+ unet_conditioning = self.noise(t0)[0][:, None]
1246
+ model_output_t0 = self.forward(x0, unet_conditioning, attention_mask)
1247
+ return - torch.gather(input=model_output_t0,
1248
+ dim=-1,
1249
+ index=x0[:, :, None]).squeeze(-1)
1250
+
1251
+ def _forward_pass_diffusion(self, x0, attention_mask, mask=None):
1252
+ t = self._sample_t(x0.shape[0], x0.device)
1253
+ if self.T > 0:
1254
+ t = (t * self.T).to(torch.int)
1255
+ t = t / self.T
1256
+ # t \in {1/T, 2/T, ..., 1}
1257
+ t += (1 / self.T)
1258
+
1259
+ if self.change_of_variables:
1260
+ unet_conditioning = t[:, None]
1261
+ f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
1262
+ f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
1263
+ move_chance = torch.exp(f_0 + t * (f_T - f_0))
1264
+ move_chance = move_chance[:, None]
1265
+ else:
1266
+ sigma, dsigma = self.noise(t)
1267
+ unet_conditioning = sigma[:, None]
1268
+ move_chance = 1 - torch.exp(-sigma[:, None])
1269
+
1270
+ if mask is None: xt = self.q_xt(x0, move_chance)
1271
+ else: xt = x0.where(mask==1, torch.full_like(x0, self.tokenizer.mask_token_id))
1272
+ model_output = self.forward(xt, unet_conditioning, attention_mask)
1273
+ # print(self.tokenizer.decode(torch.argmax(model_output[0], dim=-1)))
1274
+
1275
+ utils.print_nans(model_output, 'model_output')
1276
+
1277
+ if self.parameterization == 'sedd':
1278
+ return dsigma[:, None] * self._score_entropy(
1279
+ model_output, sigma[:, None], xt, x0)
1280
+
1281
+ if self.T > 0:
1282
+ diffusion_loss = self._d3pm_loss(
1283
+ model_output=model_output, xt=xt, x0=x0, t=t)
1284
+ if self.parameterization == 'd3pm':
1285
+ reconstruction_loss = self._reconstruction_loss(x0)
1286
+ elif self.parameterization == 'subs':
1287
+ reconstruction_loss = 0
1288
+ return reconstruction_loss + diffusion_loss
1289
+
1290
+ # SUBS parameterization, continuous time.
1291
+ log_p_theta = torch.gather(
1292
+ input=model_output,
1293
+ dim=-1,
1294
+ index=x0[:, :, None]).squeeze(-1)
1295
+
1296
+ if self.change_of_variables or self.importance_sampling:
1297
+ return log_p_theta * torch.log1p(
1298
+ - torch.exp(- self.noise.sigma_min))
1299
+
1300
+ return - log_p_theta * (
1301
+ dsigma / torch.expm1(sigma))[:, None]
1302
+
1303
+ def _loss(self, x0, attention_mask, mask=None):
1304
+ (input_tokens, output_tokens,
1305
+ attention_mask) = self._maybe_sub_sample(
1306
+ x0, attention_mask)
1307
+
1308
+ if self.parameterization == 'ar':
1309
+ logprobs = self.backbone(input_tokens, None, attention_mask)
1310
+ loss = - logprobs.gather(
1311
+ -1, output_tokens[:, :, None])[:, :, 0]
1312
+ else:
1313
+ loss = self._forward_pass_diffusion(input_tokens, attention_mask, mask)
1314
+
1315
+ nlls = loss * attention_mask
1316
+ count = attention_mask.sum()
1317
+
1318
+ batch_nll = nlls.sum()
1319
+ token_nll = batch_nll / count
1320
+
1321
+ return Loss(loss=token_nll,
1322
+ nlls=nlls,
1323
+ token_mask=attention_mask)
1324
+
1325
+ def _score_entropy(self, log_score, sigma, xt, x0):
1326
+ """Computes the SEDD loss.
1327
+
1328
+ Args:
1329
+ log_score: float torch.Tensor with shape (batch_size,
1330
+ diffusion_model_input_length, vocab_size),
1331
+ log score, output of the denoising network.
1332
+ xt: int torch.Tensor with shape (batch_size,
1333
+ diffusion_model_input_length), input.
1334
+ x0: int torch.Tensor with shape (batch_size,
1335
+ diffusion_model_input_length), input.
1336
+ sigma: float torch.Tensor with shape (batch_size, 1).
1337
+
1338
+ Returns:
1339
+ loss with shape (batch_size, diffusion_model_input_length)
1340
+ """
1341
+ masked_indices = xt == self.mask_index
1342
+
1343
+ expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
1344
+ q_ratio = 1 / expsig_minus_1[masked_indices]
1345
+
1346
+ words_that_were_masked = x0[masked_indices]
1347
+
1348
+ neg_term = q_ratio * torch.gather(
1349
+ log_score[masked_indices],
1350
+ -1,
1351
+ words_that_were_masked[..., None]).squeeze(-1)
1352
+ score = log_score[masked_indices].exp()
1353
+ if self.mask_index == self.vocab_size - 1:
1354
+ pos_term = score[:, :-1].sum(dim=-1)
1355
+ else:
1356
+ pos_term = score[:, : self.mask_index].sum(
1357
+ dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
1358
+ const = q_ratio * (q_ratio.log() - 1)
1359
+
1360
+ entropy = torch.zeros(* xt.shape, device=xt.device)
1361
+ entropy[masked_indices] += pos_term - neg_term + const
1362
+ return entropy
1363
+
1364
+ @torch.no_grad
1365
+ def sample_subs_guidance(
1366
+ self, n_samples, stride_length, num_strides, dt=0.001):
1367
+ ones = torch.ones(n_samples, dtype=self.dtype,
1368
+ device=self.device)
1369
+
1370
+ num_steps = int(1 / dt)
1371
+ sampling_steps = 0
1372
+ intermediate_tokens = []
1373
+ target = None
1374
+ for _ in range(num_strides + 1):
1375
+ p_x0_cache = None
1376
+ x = self._sample_prior(
1377
+ n_samples,
1378
+ self.config.model.length).to(self.device)
1379
+ if target is not None:
1380
+ x[:, : -stride_length] = target
1381
+ for i in range(num_steps + 1):
1382
+ p_x0_cache, x_next = self._ddpm_caching_update(
1383
+ x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
1384
+ if (not torch.allclose(x_next, x)
1385
+ or self.time_conditioning):
1386
+ p_x0_cache = None
1387
+ sampling_steps += 1
1388
+ x = x_next
1389
+ x = self.forward(x, 0 * ones).argmax(dim=-1)
1390
+ intermediate_tokens.append(
1391
+ x[:, :stride_length].cpu().numpy())
1392
+ target = x[:, stride_length:]
1393
+
1394
+ intermediate_tokens.append(target.cpu().numpy())
1395
+ intermediate_text_samples = []
1396
+ sequence_lengths = ((
1397
+ np.concatenate(intermediate_tokens, axis=1)[:, 1:]
1398
+ == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
1399
+ for i in range(2, len(intermediate_tokens) + 1):
1400
+ intermediate_text_samples.append(
1401
+ self.tokenizer.batch_decode(
1402
+ np.concatenate(intermediate_tokens[:i], axis=1)))
1403
+ return (sampling_steps, intermediate_text_samples,
1404
+ sequence_lengths)
1405
+
1406
+ def restore_model_and_semi_ar_sample(
1407
+ self, stride_length, num_strides, dt=0.001):
1408
+ """Generate samples from the model."""
1409
+ # Lightning auto-casting is not working in this method for some reason
1410
+
1411
+ # params_with_grad = [p for p in itertools.chain(
1412
+ # self.backbone.parameters(),
1413
+ # self.noise.parameters()
1414
+ # ) if p]
1415
+
1416
+ if self.ema:
1417
+ self.ema.store(itertools.chain(self.backbone.parameters(),
1418
+ self.noise.parameters()))
1419
+ self.ema.copy_to(itertools.chain(self.backbone.parameters(),
1420
+ self.noise.parameters()))
1421
+ self.backbone.eval()
1422
+ self.noise.eval()
1423
+ (sampling_steps, samples,
1424
+ sequence_lengths) = self.sample_subs_guidance(
1425
+ n_samples=self.config.loader.eval_batch_size,
1426
+ stride_length=stride_length,
1427
+ num_strides=num_strides,
1428
+ dt=dt)
1429
+ if self.ema:
1430
+ self.ema.restore(itertools.chain(self.backbone.parameters(),
1431
+ self.noise.parameters()))
1432
+ self.backbone.train()
1433
+ self.noise.train()
1434
+ return sampling_steps, samples, sequence_lengths
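A minimal sketch (not part of diffusion.py itself) of how the semi-autoregressive sampler above is invoked, mirroring the `generate_samples` path in main.py further down in this upload; `config` is assumed to be the composed Hydra config used throughout.

from transformers import AutoTokenizer
from diffusion import Diffusion

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = Diffusion.load_from_checkpoint(
    config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
model.eval()

# stride_length / num_strides / dt match what generate_samples passes in
steps, samples, lengths = model.restore_model_and_semi_ar_sample(
    stride_length=config.sampling.stride_length,
    num_strides=config.sampling.num_strides,
    dt=1 / config.sampling.steps)
text_samples = samples[-1]  # final entry holds the fully generated sequences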
dit.py ADDED
@@ -0,0 +1,388 @@
1
+ import math
2
+ import typing
3
+
4
+ import flash_attn
5
+ import flash_attn.layers.rotary
6
+ import huggingface_hub
7
+ import omegaconf
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ from transformers import AutoModel
14
+
15
+ # Flags required to enable jit fusion kernels
16
+ torch._C._jit_set_profiling_mode(False)
17
+ torch._C._jit_set_profiling_executor(False)
18
+ torch._C._jit_override_can_fuse_on_cpu(True)
19
+ torch._C._jit_override_can_fuse_on_gpu(True)
20
+
21
+
22
+ def bias_dropout_add_scale(
23
+ x: torch.Tensor,
24
+ bias: typing.Optional[torch.Tensor],
25
+ scale: torch.Tensor,
26
+ residual: typing.Optional[torch.Tensor],
27
+ prob: float,
28
+ training: bool) -> torch.Tensor:
29
+ if bias is not None:
30
+ out = scale * F.dropout(x + bias, p=prob, training=training)
31
+ else:
32
+ out = scale * F.dropout(x, p=prob, training=training)
33
+
34
+ if residual is not None:
35
+ out = residual + out
36
+ return out
37
+
38
+
39
+ def get_bias_dropout_add_scale(training):
40
+ def _bias_dropout_add(x, bias, scale, residual, prob):
41
+ return bias_dropout_add_scale(
42
+ x, bias, scale, residual, prob, training)
43
+
44
+ return _bias_dropout_add
45
+
46
+
47
+ # function overload
48
+ def modulate(x: torch.Tensor,
49
+ shift: torch.Tensor,
50
+ scale: torch.Tensor) -> torch.Tensor:
51
+ return x * (1 + scale) + shift
52
+
53
+
54
+ @torch.jit.script
55
+ def bias_dropout_add_scale_fused_train(
56
+ x: torch.Tensor,
57
+ bias: typing.Optional[torch.Tensor],
58
+ scale: torch.Tensor,
59
+ residual: typing.Optional[torch.Tensor],
60
+ prob: float) -> torch.Tensor:
61
+ return bias_dropout_add_scale(
62
+ x, bias, scale, residual, prob, True)
63
+
64
+
65
+ @torch.jit.script
66
+ def bias_dropout_add_scale_fused_inference(
67
+ x: torch.Tensor,
68
+ bias: typing.Optional[torch.Tensor],
69
+ scale: torch.Tensor,
70
+ residual: typing.Optional[torch.Tensor],
71
+ prob: float) -> torch.Tensor:
72
+ return bias_dropout_add_scale(
73
+ x, bias, scale, residual, prob, False)
74
+
75
+
76
+ @torch.jit.script
77
+ def modulate_fused(x: torch.Tensor,
78
+ shift: torch.Tensor,
79
+ scale: torch.Tensor) -> torch.Tensor:
80
+ return modulate(x, shift, scale)
81
+
82
+
83
+ class Rotary(torch.nn.Module):
84
+ def __init__(self, dim, base=10_000):
85
+ super().__init__()
86
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
87
+ self.register_buffer('inv_freq', inv_freq)
88
+ self.seq_len_cached = None
89
+ self.cos_cached = None
90
+ self.sin_cached = None
91
+
92
+ def forward(self, x, seq_dim=1):
93
+ seq_len = x.shape[seq_dim]
94
+ if seq_len != self.seq_len_cached:
95
+ self.seq_len_cached = seq_len
96
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
97
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
98
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
99
+ # dims are: batch, seq_len, qkv, head, dim
100
+ self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
101
+ self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
102
+ # This makes the transformation on v an identity.
103
+ self.cos_cached[:,:,2,:,:].fill_(1.)
104
+ self.sin_cached[:,:,2,:,:].fill_(0.)
105
+
106
+ return self.cos_cached, self.sin_cached
107
+
108
+
109
+ def rotate_half(x):
110
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
111
+ return torch.cat((-x2, x1), dim=-1)
112
+
113
+
114
+ def apply_rotary_pos_emb(qkv, cos, sin):
115
+ cos = cos[0,:,0,0,:cos.shape[-1]//2]
116
+ sin = sin[0,:,0,0,:sin.shape[-1]//2]
117
+ return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
118
+
119
+
120
+ # function overload
121
+ def modulate(x, shift, scale):
122
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
123
+
124
+
125
+ #################################################################################
126
+ # Layers #
127
+ #################################################################################
128
+ class LayerNorm(nn.Module):
129
+ def __init__(self, dim):
130
+ super().__init__()
131
+ self.weight = nn.Parameter(torch.ones([dim]))
132
+ self.dim = dim
133
+ def forward(self, x):
134
+ with torch.cuda.amp.autocast(enabled=False):
135
+ x = F.layer_norm(x.float(), [self.dim])
136
+ return x * self.weight[None,None,:]
137
+
138
+
139
+ def residual_linear(x, W, x_skip, residual_scale):
140
+ """x_skip + residual_scale * W @ x"""
141
+ dim_out, dim_in = W.shape[0], W.shape[1]
142
+ return torch.addmm(
143
+ x_skip.view(-1, dim_out),
144
+ x.view(-1, dim_in),
145
+ W.T,
146
+ alpha=residual_scale).view(*x.shape[:-1], dim_out)
147
+
148
+
149
+ #################################################################################
150
+ # Embedding Layers for Timesteps and Class Labels #
151
+ #################################################################################
152
+ class TimestepEmbedder(nn.Module):
153
+ """
154
+ Embeds scalar timesteps into vector representations.
155
+ """
156
+ def __init__(self, hidden_size, frequency_embedding_size=256):
157
+ super().__init__()
158
+ self.mlp = nn.Sequential(
159
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
160
+ nn.SiLU(),
161
+ nn.Linear(hidden_size, hidden_size, bias=True))
162
+ self.frequency_embedding_size = frequency_embedding_size
163
+
164
+ @staticmethod
165
+ def timestep_embedding(t, dim, max_period=10000):
166
+ """
167
+ Create sinusoidal timestep embeddings.
168
+ :param t: a 1-D Tensor of N indices, one per batch element.
169
+ These may be fractional.
170
+ :param dim: the dimension of the output.
171
+ :param max_period: controls the minimum frequency of the embeddings.
172
+ :return: an (N, D) Tensor of positional embeddings.
173
+ """
174
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
175
+ half = dim // 2
176
+ freqs = torch.exp(- math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
177
+
178
+ if t.ndim == 1:
179
+ t = t.unsqueeze(1)
180
+
181
+ args = t.float() * freqs[None, :]
182
+ #args = t[:, None].float() * freqs[None]
183
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
184
+ if dim % 2:
185
+ embedding = torch.cat(
186
+ [embedding,
187
+ torch.zeros_like(embedding[:, :1])], dim=-1)
188
+ return embedding
189
+
190
+ def forward(self, t):
191
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
192
+ t_emb = self.mlp(t_freq)
193
+ return t_emb
194
+
195
+
196
+ class LabelEmbedder(nn.Module):
197
+ """Embeds class labels into vector representations.
198
+
199
+ Also handles label dropout for classifier-free guidance.
200
+ """
201
+ def __init__(self, num_classes, cond_size):
202
+ super().__init__()
203
+ self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
204
+ self.num_classes = num_classes
205
+
206
+ # TODO think of initializing with 0.02 std deviation like in original DiT paper
207
+
208
+ def forward(self, labels):
209
+ embeddings = self.embedding_table(labels)
210
+ return embeddings
211
+
212
+
213
+ #################################################################################
214
+ # Core Model #
215
+ #################################################################################
216
+
217
+
218
+ class DDiTBlock(nn.Module):
219
+ def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
220
+ super().__init__()
221
+ self.n_heads = n_heads
222
+
223
+ self.norm1 = LayerNorm(dim)
224
+ self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
225
+ self.attn_out = nn.Linear(dim, dim, bias=False)
226
+ self.dropout1 = nn.Dropout(dropout)
227
+
228
+ self.norm2 = LayerNorm(dim)
229
+ self.mlp = nn.Sequential(
230
+ nn.Linear(dim, mlp_ratio * dim, bias=True),
231
+ nn.GELU(approximate='tanh'),
232
+ nn.Linear(mlp_ratio * dim, dim, bias=True))
233
+ self.dropout2 = nn.Dropout(dropout)
234
+ self.dropout = dropout
235
+
236
+ self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
237
+ self.adaLN_modulation.weight.data.zero_()
238
+ self.adaLN_modulation.bias.data.zero_()
239
+
240
+
241
+ def _get_bias_dropout_scale(self):
242
+ if self.training:
243
+ return bias_dropout_add_scale_fused_train
244
+ else:
245
+ return bias_dropout_add_scale_fused_inference
246
+
247
+
248
+ def forward(self, x, rotary_cos_sin, c, seqlens=None):
249
+ batch_size, seq_len = x.shape[0], x.shape[1]
250
+
251
+ bias_dropout_scale_fn = self._get_bias_dropout_scale()
252
+
253
+ (shift_msa, scale_msa, gate_msa, shift_mlp,
254
+ scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
255
+
256
+ # attention operation
257
+ x_skip = x
258
+ x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
259
+
260
+ qkv = self.attn_qkv(x)
261
+ qkv = rearrange(qkv,
262
+ 'b s (three h d) -> b s three h d',
263
+ three=3,
264
+ h=self.n_heads)
265
+ with torch.cuda.amp.autocast(enabled=False):
266
+ cos, sin = rotary_cos_sin
267
+ qkv = apply_rotary_pos_emb(
268
+ qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
269
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
270
+ if seqlens is None:
271
+ cu_seqlens = torch.arange(
272
+ 0, (batch_size + 1) * seq_len, step=seq_len,
273
+ dtype=torch.int32, device=qkv.device)
274
+ else:
275
+ cu_seqlens = seqlens.cumsum(-1)
276
+ x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
277
+ qkv, cu_seqlens, seq_len, 0., causal=False)
278
+
279
+ x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)
280
+
281
+ x = bias_dropout_scale_fn(self.attn_out(x),
282
+ None,
283
+ gate_msa,
284
+ x_skip,
285
+ self.dropout)
286
+
287
+ # mlp operation
288
+ x = bias_dropout_scale_fn(
289
+ self.mlp(modulate_fused(
290
+ self.norm2(x), shift_mlp, scale_mlp)),
291
+ None, gate_mlp, x, self.dropout)
292
+ return x
293
+
294
+
295
+
296
+ class EmbeddingLayer(nn.Module):
297
+ def __init__(self, dim, vocab_dim):
298
+ super().__init__()
299
+ self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
300
+ torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
301
+
302
+ def forward(self, x):
303
+ return self.embedding[x]
304
+
305
+
306
+ class DDitFinalLayer(nn.Module):
307
+ def __init__(self, hidden_size, out_channels, cond_dim):
308
+ super().__init__()
309
+ self.norm_final = LayerNorm(hidden_size)
310
+ self.linear = nn.Linear(hidden_size, out_channels)
311
+ self.linear.weight.data.zero_()
312
+ self.linear.bias.data.zero_()
313
+
314
+ self.adaLN_modulation = nn.Linear(cond_dim,
315
+ 2 * hidden_size,
316
+ bias=True)
317
+ self.adaLN_modulation.weight.data.zero_()
318
+ self.adaLN_modulation.bias.data.zero_()
319
+
320
+
321
+ def forward(self, x, c):
322
+ shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
323
+ x = modulate_fused(self.norm_final(x), shift, scale)
324
+ x = self.linear(x)
325
+ return x
326
+
327
+
328
+ class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
329
+ def __init__(self, config, vocab_size: int, mlm_model_path):
330
+ super().__init__()
331
+ if type(config) == dict:
332
+ config = omegaconf.OmegaConf.create(config)
333
+
334
+ self.config = config
335
+ self.vocab_size = vocab_size
336
+
337
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
338
+ vocab_size)
339
+ self.sigma_map = TimestepEmbedder(config.model.cond_dim)
340
+ self.rotary_emb = Rotary(
341
+ config.model.hidden_size // config.model.n_heads)
342
+
343
+ blocks = []
344
+ for _ in range(config.model.n_blocks):
345
+ blocks.append(DDiTBlock(config.model.hidden_size,
346
+ config.model.n_heads,
347
+ config.model.cond_dim,
348
+ dropout=config.model.dropout))
349
+ self.blocks = nn.ModuleList(blocks)
350
+
351
+ self.output_layer = DDitFinalLayer(
352
+ config.model.hidden_size,
353
+ vocab_size,
354
+ config.model.cond_dim)
355
+ self.scale_by_sigma = config.model.scale_by_sigma
356
+
357
+ self.mlm_model = AutoModel.from_pretrained(mlm_model_path, device_map='cpu')
358
+
359
+ def _get_bias_dropout_scale(self):
360
+ if self.training:
361
+ return bias_dropout_add_scale_fused_train
362
+ else:
363
+ return bias_dropout_add_scale_fused_inference
364
+
365
+ def forward(self, indices, sigma):
366
+ x = self.vocab_embed(indices)
367
+ c_sigma = F.silu(self.sigma_map(sigma))
368
+
369
+ rotary_cos_sin = self.rotary_emb(x)
370
+
371
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
372
+ for i in range(len(self.blocks)):
373
+ x = self.blocks[i](x, rotary_cos_sin, c_sigma, seqlens=None)
374
+ x = self.output_layer(x, c_sigma)
375
+
376
+ # Extract membrane-specific embeddings from final encoder layer
377
+ # of fine-tuned ESM model
378
+ # with torch.no_grad():
379
+ # membrane_embedding = self.mlm_model(input_ids=, attention_mask=).last_hidden_state.squeeze(0)
380
+
381
+ # Fuse MLM embeddings with conditioning vector
382
+ # c = torch.cat([c_sigma, membrane_embedding], dim=-1)
383
+
384
+ # print(membrane_embedding.size())
385
+ # print(c_sigma.size())
386
+
387
+ return x
388
+
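A shape-check sketch for the DIT backbone above (not part of dit.py): token indices plus a per-example noise level go in, per-token logits come out. The config values and the 150M ESM path are placeholder assumptions, not the repo's actual model settings, and running it needs a CUDA GPU with flash-attn installed (it will also download the ESM checkpoint).

import torch
from dit import DIT

cfg = {'model': {'hidden_size': 768, 'cond_dim': 256, 'n_heads': 12,
                 'n_blocks': 12, 'dropout': 0.1, 'scale_by_sigma': False}}
model = DIT(cfg, vocab_size=33,
            mlm_model_path='facebook/esm2_t30_150M_UR50D').cuda()

indices = torch.randint(0, 33, (2, 128), device='cuda')  # (batch, seq_len) token ids
sigma = torch.rand(2, device='cuda')                      # one noise level per example
logits = model(indices, sigma)                            # (batch, seq_len, vocab_size)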
ema.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+
3
+
4
+ class ExponentialMovingAverage:
5
+ """
6
+ Maintains (exponential) moving average of a set of parameters.
7
+ """
8
+
9
+ def __init__(self, parameters, decay, use_num_updates=True):
10
+ """
11
+ Args:
12
+ parameters: Iterable of `torch.nn.Parameter`; usually the result of
13
+ `model.parameters()`.
14
+ decay: The exponential decay.
15
+ use_num_updates: Whether to use number of updates when computing
16
+ averages.
17
+ """
18
+ if decay < 0.0 or decay > 1.0:
19
+ raise ValueError('Decay must be between 0 and 1')
20
+ self.decay = decay
21
+ self.num_updates = 0 if use_num_updates else None
22
+ self.shadow_params = [p.clone().detach()
23
+ for p in parameters if p.requires_grad]
24
+ self.collected_params = []
25
+
26
+ def move_shadow_params_to_device(self, device):
27
+ self.shadow_params = [i.to(device) for i in self.shadow_params]
28
+
29
+ def update(self, parameters):
30
+ """
31
+ Update currently maintained parameters.
32
+
33
+ Call this every time the parameters are updated, such as the result of
34
+ the `optimizer.step()` call.
35
+
36
+ Args:
37
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
38
+ parameters used to initialize this object.
39
+ """
40
+ decay = self.decay
41
+ if self.num_updates is not None:
42
+ self.num_updates += 1
43
+ decay = min(decay, (1 + self.num_updates) /
44
+ (10 + self.num_updates))
45
+ one_minus_decay = 1.0 - decay
46
+ with torch.no_grad():
47
+ parameters = [p for p in parameters if p.requires_grad]
48
+ for s_param, param in zip(self.shadow_params, parameters):
49
+ s_param.sub_(one_minus_decay * (s_param - param))
50
+
51
+ def copy_to(self, parameters):
52
+ """
53
+ Copy current parameters into given collection of parameters.
54
+
55
+ Args:
56
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
57
+ updated with the stored moving averages.
58
+ """
59
+ parameters = [p for p in parameters if p.requires_grad]
60
+ for s_param, param in zip(self.shadow_params, parameters):
61
+ if param.requires_grad:
62
+ param.data.copy_(s_param.data)
63
+
64
+ def store(self, parameters):
65
+ """
66
+ Save the current parameters for restoring later.
67
+
68
+ Args:
69
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
+ temporarily stored.
71
+ """
72
+ self.collected_params = [param.clone() for param in parameters]
73
+
74
+ def restore(self, parameters):
75
+ """
76
+ Restore the parameters stored with the `store` method.
77
+ Useful to validate the model with EMA parameters without affecting the
78
+ original optimization process. Store the parameters before the
79
+ `copy_to` method. After validation (or model saving), use this to
80
+ restore the former parameters.
81
+
82
+ Args:
83
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
84
+ updated with the stored parameters.
85
+ """
86
+ for c_param, param in zip(self.collected_params, parameters):
87
+ param.data.copy_(c_param.data)
88
+
89
+ def state_dict(self):
90
+ return dict(decay=self.decay,
91
+ num_updates=self.num_updates,
92
+ shadow_params=self.shadow_params)
93
+
94
+ def load_state_dict(self, state_dict):
95
+ self.decay = state_dict['decay']
96
+ self.num_updates = state_dict['num_updates']
97
+ self.shadow_params = state_dict['shadow_params']
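The EMA helper above is used in diffusion.py with a store, copy_to, restore pattern around evaluation and sampling; a minimal sketch of that pattern on a toy module (not part of ema.py):

import torch
from ema import ExponentialMovingAverage

net = torch.nn.Linear(4, 4)
ema = ExponentialMovingAverage(net.parameters(), decay=0.9999)

# training loop: call after every optimizer.step()
ema.update(net.parameters())

# evaluate with EMA weights, then put the raw training weights back
ema.store(net.parameters())
ema.copy_to(net.parameters())
# ... run validation / sampling here ...
ema.restore(net.parameters())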
esm_utils.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
3
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
4
+
5
+ def load_esm2_model(model_name):
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
8
+ embedding_model = AutoModel.from_pretrained(model_name)
9
+ return tokenizer, masked_model, embedding_model
10
+
11
+ def get_latents(model, tokenizer, sequence, device):
12
+ inputs = tokenizer(sequence, return_tensors="pt").to(device)
13
+ with torch.no_grad():
14
+ outputs = model(**inputs).last_hidden_state.squeeze(0)
15
+ return outputs
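A small usage sketch for the helpers above (not part of esm_utils.py); the 8M ESM-2 checkpoint is used only to keep the example light.

import torch
from esm_utils import load_esm2_model, get_latents

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer, masked_model, embedding_model = load_esm2_model("facebook/esm2_t6_8M_UR50D")
embedding_model = embedding_model.to(device).eval()

latents = get_latents(embedding_model, tokenizer, "MKTAYIAKQR", device)
print(latents.shape)  # (len(sequence) + 2, hidden_size), BOS/EOS tokens included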
generate.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import random
5
+ import sys
6
+ import pandas as pd
7
+ from mlm_generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
8
+ from diffusion import Diffusion
9
+ import hydra
10
+ from tqdm import tqdm
11
+ from transformers import AutoTokenizer, AutoModel, pipeline
12
+
13
+
14
+ @torch.no_grad()
15
+ def generate_sequence(sequence_length: int, tokenizer, mdlm: Diffusion):
16
+ global masked_sequence
17
+ masked_sequence = mask_for_de_novo(sequence_length)
18
+ inputs = tokenizer(masked_sequence, return_tensors="pt").to(mdlm.device)
19
+ logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
20
+ generated_sequence = tokenizer.decode(logits.squeeze())
21
+
22
+ return generated_sequence
23
+
24
+
25
+ @hydra.main(version_base=None, config_path='configs', config_name='config')
26
+ def mdlm_motif_benchmark(config):
27
+ path = "/workspace/sg666/MDpLM"
28
+
29
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
30
+ mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
31
+
32
+ mdlm_model.eval()
33
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
34
+ mdlm_model.to(device)
35
+
36
+ print("loaded models...")
37
+
38
+ # Get 100 random sequence lengths to generate
39
+ sequence_lengths = [random.randint(50, 1000) for _ in range(100)]
40
+
41
+ generation_results = []
42
+ for seq_length in tqdm(sequence_lengths, desc="Generating sequences"):
43
+ generated_sequence = generate_sequence(seq_length, tokenizer, mdlm_model)
44
+ generated_sequence = generated_sequence[5:-5].replace(" ", "") # Remove bos/eos tokens
45
+
46
+ perplexity = mdlm_model.compute_masked_perplexity([generated_sequence], masked_sequence)
47
+ perplexity = round(perplexity, 4)
48
+
49
+ generation_results.append([generated_sequence, perplexity])
50
+
51
+ print(f"perplexity: {perplexity} | length: {seq_length} | generated sequence: {generated_sequence}")
52
+ sys.stdout.flush()
53
+
54
+ df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'Perplexity'])
55
+ df.to_csv(path + f'/benchmarks/mdlm_de-novo_generation_results.csv', index=False)
56
+
57
+
58
+
59
+ if __name__ == "__main__":
60
+ mdlm_motif_benchmark()
main.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+
3
+ import wandb
4
+ import fsspec
5
+ import hydra
6
+ import lightning as L
7
+ import omegaconf
8
+ import rich.syntax
9
+ import rich.tree
10
+ import torch
11
+
12
+ import pl_data_loader as dataloader
13
+ from diffusion import Diffusion
14
+ import utils
15
+
16
+ from lightning.pytorch.strategies import DDPStrategy
17
+ from transformers import AutoTokenizer
18
+ from datasets import load_from_disk, load_dataset
19
+
20
+ #wandb.login(key="2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f")
21
+ omegaconf.OmegaConf.register_new_resolver(
22
+ 'cwd', os.getcwd)
23
+ omegaconf.OmegaConf.register_new_resolver(
24
+ 'device_count', torch.cuda.device_count)
25
+ omegaconf.OmegaConf.register_new_resolver(
26
+ 'eval', eval)
27
+ omegaconf.OmegaConf.register_new_resolver(
28
+ 'div_up', lambda x, y: (x + y - 1) // y)
29
+
30
+
31
+ def _load_from_checkpoint(config, tokenizer):
32
+ if 'hf' in config.backbone:
33
+ return Diffusion(
34
+ config, tokenizer=tokenizer).to('cuda')
35
+ else:
36
+ model = Diffusion.load_from_checkpoint(
37
+ config.eval.checkpoint_path,
38
+ tokenizer=tokenizer,
39
+ config=config)
40
+
41
+ return model
42
+
43
+ @L.pytorch.utilities.rank_zero_only
44
+ def _print_config(
45
+ config: omegaconf.DictConfig,
46
+ resolve: bool = True,
47
+ save_cfg: bool = True) -> None:
48
+ """Prints content of DictConfig using Rich library and its tree structure.
49
+
50
+ Args:
51
+ config (DictConfig): Configuration composed by Hydra.
52
+ resolve (bool): Whether to resolve reference fields of DictConfig.
53
+ save_cfg (bool): Whether to save the configuration tree to a file.
54
+ """
55
+
56
+ style = 'dim'
57
+ tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
58
+
59
+ fields = config.keys()
60
+ for field in fields:
61
+ branch = tree.add(field, style=style, guide_style=style)
62
+
63
+ config_section = config.get(field)
64
+ branch_content = str(config_section)
65
+ if isinstance(config_section, omegaconf.DictConfig):
66
+ branch_content = omegaconf.OmegaConf.to_yaml(
67
+ config_section, resolve=resolve)
68
+
69
+ branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
70
+ rich.print(tree)
71
+ if save_cfg:
72
+ with fsspec.open(
73
+ '{}/config_tree.txt'.format(
74
+ config.checkpointing.save_dir), 'w') as fp:
75
+ rich.print(tree, file=fp)
76
+
77
+
78
+ @L.pytorch.utilities.rank_zero_only
79
+ def _print_batch(train_ds, valid_ds, tokenizer, k=64):
80
+ #for dl_type, dl in [
81
+ #('train', train_ds), ('valid', valid_ds)]:
82
+ for dl_type, dl in [
83
+ ('train', train_ds)]:
84
+ print(f'Printing {dl_type} dataloader batch.')
85
+ batch = next(iter(dl))
86
+ print('Batch input_ids.shape', batch['input_ids'].shape)
87
+ first = batch['input_ids'][0, :k]
88
+ last = batch['input_ids'][0, -k:]
89
+ print(f'First {k} tokens:', tokenizer.decode(first))
90
+ print('ids:', first)
91
+ print(f'Last {k} tokens:', tokenizer.decode(last))
92
+ print('ids:', last)
93
+
94
+
95
+ def generate_samples(config, logger, tokenizer):
96
+ logger.info('Generating samples.')
97
+ model = _load_from_checkpoint(config=config,
98
+ tokenizer=tokenizer)
99
+ model.gen_ppl_metric.reset()
100
+ if config.eval.disable_ema:
101
+ logger.info('Disabling EMA.')
102
+ model.ema = None
103
+ stride_length = config.sampling.stride_length
104
+ num_strides = config.sampling.num_strides
105
+ for _ in range(config.sampling.num_sample_batches):
106
+ if config.sampling.semi_ar:
107
+ _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
108
+ stride_length=stride_length,
109
+ num_strides=num_strides,
110
+ dt=1 / config.sampling.steps)
111
+ text_samples = intermediate_samples[-1]
112
+ # Note: Samples generated using semi-ar method
113
+ # need to to be processed before computing generative perplexity
114
+ # since these samples contain numerous <|endoftext|> tokens
115
+ # and diffusion.compute_generative_perplexity() discards
116
+ # any text after the first EOS token.
117
+ else:
118
+ samples = model.restore_model_and_sample(
119
+ num_steps=config.sampling.steps)
120
+ text_samples = model.tokenizer.batch_decode(samples)
121
+ model.compute_generative_perplexity(text_samples)
122
+ print('Text samples:', text_samples)
123
+ if not config.sampling.semi_ar:
124
+ print('Generative perplexity:',
125
+ model.gen_ppl_metric.compute())
126
+ return text_samples
127
+
128
+ def _ppl_eval(config, logger, tokenizer, data_module):
129
+ logger.info('Starting Zero Shot Eval.')
130
+
131
+ model = _load_from_checkpoint(config=config,
132
+ tokenizer=tokenizer)
133
+ if config.eval.disable_ema:
134
+ logger.info('Disabling EMA.')
135
+ model.ema = None
136
+
137
+ wandb_logger = None
138
+ if config.get('wandb', None) is not None:
139
+ wandb_logger = L.pytorch.loggers.WandbLogger(
140
+ config=omegaconf.OmegaConf.to_object(config),
141
+ ** config.wandb)
142
+ callbacks = []
143
+ if 'callbacks' in config:
144
+ for _, callback in config.callbacks.items():
145
+ callbacks.append(hydra.utils.instantiate(callback))
146
+ trainer = hydra.utils.instantiate(
147
+ config.trainer,
148
+ default_root_dir=os.getcwd(),
149
+ callbacks=callbacks,
150
+ strategy=DDPStrategy(find_unused_parameters=True),
151
+ logger=wandb_logger)
152
+ # _, valid_ds = dataloader.get_dataloaders(
153
+ # config, tokenizer, skip_train=True, valid_seed=config.seed)
154
+ trainer.test(model, data_module)
155
+
156
+
157
+ def _train(config, logger, tokenizer, data_module):
158
+ logger.info('Starting Training.')
159
+ wandb_logger = None
160
+ if config.get('wandb', None) is not None:
161
+ wandb_logger = L.pytorch.loggers.WandbLogger(
162
+ config=omegaconf.OmegaConf.to_object(config),
163
+ ** config.wandb)
164
+
165
+ if (config.checkpointing.resume_from_ckpt
166
+ and config.checkpointing.resume_ckpt_path is not None
167
+ and utils.fsspec_exists(
168
+ config.checkpointing.resume_ckpt_path)):
169
+ ckpt_path = config.checkpointing.resume_ckpt_path
170
+ else:
171
+ ckpt_path = None
172
+
173
+ # Lightning callbacks
174
+ callbacks = []
175
+ if 'callbacks' in config:
176
+ for _, callback in config.callbacks.items():
177
+ callbacks.append(hydra.utils.instantiate(callback))
178
+ '''
179
+ train_ds, valid_ds = dataloader.get_dataloaders(
180
+ config, tokenizer)
181
+ _print_batch(train_ds, valid_ds, tokenizer)
182
+
183
+ model = diffusion.Diffusion(
184
+ config, tokenizer=valid_ds.tokenizer)
185
+ '''
186
+ trainer = hydra.utils.instantiate(
187
+ config.trainer,
188
+ default_root_dir=os.getcwd(),
189
+ callbacks=callbacks,
190
+ accelerator='cuda',
191
+ strategy=DDPStrategy(find_unused_parameters=True),
192
+ logger=wandb_logger)
193
+
194
+ model = Diffusion(
195
+ config, tokenizer=tokenizer)
196
+
197
+ trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
198
+
199
+ '''
200
+ trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
201
+ '''
202
+
203
+ @hydra.main(version_base=None, config_path='configs', config_name='config')
204
+ def main(config):
205
+ """Main entry point for training."""
206
+ L.seed_everything(config.seed)
207
+ _print_config(config, resolve=True, save_cfg=True)
208
+
209
+ logger = utils.get_logger(__name__)
210
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
211
+
212
+ if config.backbone == "vanilla_esm_pretrain":
213
+ train_dataset = load_dataset('csv', data_files=config.data.train.vanilla_esm_train_path)
214
+ val_dataset = load_dataset('csv', data_files=config.data.valid.vanilla_esm_valid_path)
215
+ test_dataset = load_dataset('csv', data_files=config.data.test.vanilla_esm_test_path)
216
+ elif config.backbone == "membrane_esm_finetune" or config.backbone == "dit":
217
+ train_dataset = load_dataset('csv', data_files=config.data.train.membrane_esm_train_path)
218
+ val_dataset = load_dataset('csv', data_files=config.data.valid.membrane_esm_valid_path)
219
+ test_dataset = load_dataset('csv', data_files=config.data.test.membrane_esm_test_path)
220
+
221
+ lst = [i for i in range(1, 200)]
222
+
223
+ train_dataset = train_dataset['train']#.select(lst)
224
+ val_dataset = val_dataset['train']#.select(lst)
225
+ test_dataset = test_dataset['train']#.select(lst)
226
+
227
+ if config.training.focus_mask:
228
+ collator = dataloader.membrane_collate_fn
229
+ elif config.data.wrapping:
230
+ collator = dataloader.wrap_collate_fn
231
+ else:
232
+ collator = dataloader.collate_fn  # assumes the default collate_fn defined in pl_data_loader
233
+
234
+ data_module = dataloader.CustomDataModule(
235
+ train_dataset, val_dataset, test_dataset,
236
+ tokenizer,
237
+ batch_size=config.loader.batch_size,
238
+ collate_fn=collator
239
+ )
240
+
241
+ if config.mode == 'sample_eval':
242
+ generate_samples(config, logger, tokenizer)
243
+ elif config.mode == 'ppl_eval':
244
+ _ppl_eval(config, logger, tokenizer, data_module)
245
+ else:
246
+ _train(config, logger, tokenizer, data_module)
247
+
248
+
249
+ if __name__ == '__main__':
250
+ main()
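For reference, a sketch (not part of main.py) of composing the same Hydra config programmatically instead of through the CLI entry point above; the override value is only an example.

from hydra import compose, initialize

with initialize(version_base=None, config_path='configs'):
    config = compose(config_name='config', overrides=['mode=ppl_eval'])

print(config.mode)  # 'ppl_eval'; main() branches on this value to pick eval vs. training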
mdlm_motif_benchmarking.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import random
5
+ import sys
6
+ import pandas as pd
7
+ from mlm_generate_utils import mask_for_scaffold, calculate_cosine_sim, calculate_hamming_dist
8
+ from diffusion import Diffusion
9
+ import hydra
10
+ from tqdm import tqdm
11
+ from transformers import AutoTokenizer, AutoModel, pipeline
12
+
13
+ def masking_test(sequence: str, generate_case: str, tokenizer, mask_prob: float = 0.50):
14
+ """
15
+ Masks 50% of the tokens in the sequence.
16
+ """
17
+ tokens = list(sequence.upper())
18
+ num_tokens_to_mask = int(mask_prob * len(tokens)) # Select some fraction of the tokens
19
+ print(num_tokens_to_mask, len(tokens))
20
+
21
+ # Get random indices to mask
22
+ mask_indices = random.sample(range(len(tokens)), num_tokens_to_mask)
23
+
24
+ for idx in mask_indices:
25
+ tokens[idx] = tokenizer.mask_token # Replace with mask token
26
+
27
+ return ''.join(tokens)
28
+
29
+
30
+
31
+ @torch.no_grad()
32
+ def generate_scaffold_mdlm(sequence: str, generate_case: str, tokenizer, mdlm: Diffusion):
33
+ # # Mask soluble or transmembrane domains
34
+ # masked_sequence = mask_for_scaffold(sequence, generate_case)
35
+
36
+ # # Test out different masking rates
37
+ # masked_sequence = masking_test(sequence, generate_case, tokenizer)
38
+
39
+ # 100% masking rate, de novo generation
40
+ masked_sequence = len(sequence) * "<mask>"
41
+
42
+ print(masked_sequence)
43
+
44
+ inputs = tokenizer(masked_sequence, return_tensors="pt").to(mdlm.device)
45
+
46
+ logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
47
+ # logits = mdlm.forward(inputs)
48
+ # print(tokenizer.decode(logits.squeeze(), skip_special_tokens=True))
49
+
50
+ return tokenizer.decode(logits.squeeze()), masked_sequence
51
+
52
+
53
+ @hydra.main(version_base=None, config_path='configs', config_name='config')
54
+ def mdlm_motif_benchmark(config):
55
+ path = "/workspace/sg666/MDpLM"
56
+
57
+ test_sequences = pd.read_csv(path + "/data/membrane/test.csv")['Sequence'].tolist()
58
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
59
+
60
+ mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
61
+ esm_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D") # model used for functionality testing
62
+
63
+ mdlm_model.eval()
64
+ esm_model.eval()
65
+
66
+ print("loaded models...")
67
+
68
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
69
+ mdlm_model.to(device)
70
+ esm_model.to(device)
71
+
72
+ for generate_case in ['uppercase', 'lowercase']:
73
+ case_results = []
74
+ for original_sequence in tqdm(test_sequences, desc=f"scaffolding ({generate_case}): "):
75
+
76
+ generated_sequence, masked_input = generate_scaffold_mdlm(original_sequence, generate_case, tokenizer, mdlm_model)
77
+ generated_sequence = generated_sequence[5:-5].replace(" ", "") # Remove bos/eos tokens
78
+
79
+ perplexity = mdlm_model.compute_masked_perplexity([original_sequence], masked_input)
80
+ cos_sim = calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device)
81
+ hamming_distance = calculate_hamming_dist(original_sequence, generated_sequence)
82
+
83
+ case_results.append([original_sequence, generated_sequence, perplexity, cos_sim, hamming_distance])
84
+
85
+ print("perplexity: ", perplexity, "cos sim: ", cos_sim, "hamming: ", hamming_distance)
86
+ print(f"generated sequence: {generated_sequence}")
87
+ print(f"original sequence: {original_sequence.upper()}")
88
+ sys.stdout.flush()
89
+
90
+ df = pd.DataFrame(case_results, columns=['Original Sequence', 'Generated Sequence', 'Perplexity', 'Cosine Similarity', 'Hamming Distance'])
91
+ df.to_csv(path + f'/benchmarks/MLM/mlm_{generate_case}_results.csv', index=False)
92
+
93
+
94
+
95
+ if __name__ == "__main__":
96
+ mdlm_motif_benchmark()
mlm_generate_utils.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch
2
+ import math
3
4
+ import sys
5
+ import pandas as pd
6
+ from esm_utils import get_latents
7
+ from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
8
+
9
+
10
+ def mask_for_de_novo(sequence_length):
11
+ return "<mask>" * sequence_length
12
+
13
+ def generate_de_novo(sequence_length, tokenizer, model):
14
+ masked_sequence = mask_for_de_novo(sequence_length)
15
+ inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)
16
+
17
+ with torch.no_grad():
18
+ logits = model(**inputs).logits
19
+ mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
20
+ logits_at_masks = logits[0, mask_token_indices]
21
+
22
+ pred_tokens = []
23
+ for i in range(len(mask_token_indices)):  # index positionally into logits_at_masks (cf. generate_scaffold)
24
+ topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
25
+ probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
26
+ predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
27
+ predicted_token_id = topk_indices[predicted_index].item()
28
+ predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
29
+ pred_tokens.append(predicted_token)
30
+
31
+ generated_sequence = ''.join(pred_tokens)
32
+ perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)
33
+
34
+ return (generated_sequence, perplexity)
35
+
36
+
37
+ def mask_for_scaffold(sequence, generate_type):
38
+ if generate_type == "uppercase":
39
+ sequence = ''.join(["<mask>" if residue.isupper() else residue.upper() for residue in sequence])
40
+ elif generate_type == "lowercase":
41
+ sequence = ''.join(["<mask>" if residue.islower() else residue for residue in sequence])
42
+ return sequence
43
+
44
+
45
+ def generate_scaffold(sequence, generate_type, tokenizer, model):
46
+ masked_sequence = mask_for_scaffold(sequence, generate_type)
47
+ inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)
48
+
49
+ with torch.no_grad():
50
+ logits = model(**inputs).logits
51
+ mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
52
+ logits_at_masks = logits[0, mask_token_indices]
53
+
54
+ pred_tokens = []
55
+ for i in range(len(mask_token_indices)):
56
+ topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
57
+ probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
58
+ predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
59
+ predicted_token_id = topk_indices[predicted_index].item()
60
+ predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
61
+
62
+ pred_tokens.append('G' if predicted_token == '' else predicted_token)
63
+
64
+ generated_sequence = masked_sequence
65
+ for token in pred_tokens:
66
+ generated_sequence = generated_sequence.replace("<mask>", token, 1)
67
+
68
+ return generated_sequence, mask_token_indices
69
+
70
+
71
+ def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
72
+ total_loss = 0.0
73
+ tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
74
+
75
+ for i in mask_token_indices:
76
+ masked_input = tensor_input.clone()
77
+ masked_input[0, i] = tokenizer.mask_token_id
78
+
79
+ labels = torch.full(tensor_input.shape, -100).to(model.device)
80
+ labels[0, i] = tensor_input[0, i]
81
+
82
+ with torch.no_grad():
83
+ outputs = model(masked_input, labels=labels)
84
+ total_loss += outputs.loss.item()
85
+
86
+ num_mask_tokens = len(mask_token_indices)
87
+ if num_mask_tokens == 0:
88
+ perplexity = 10000
89
+ else:
90
+ avg_loss = total_loss / num_mask_tokens
91
+ perplexity = math.exp(avg_loss)
92
+
93
+ return perplexity
94
+
95
+
96
+ def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
97
+ og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
98
+ new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)
99
+
100
+ sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
101
+ cosine_similarity = torch.mean(sequence_similarity).item()
102
+ return cosine_similarity
103
+
104
+
105
+ def calculate_hamming_dist(original_sequence, generated_sequence):
106
+ generated_sequence = generated_sequence.upper()
107
+ original_sequence = original_sequence.upper()
108
+ return sum(1 if original_sequence[i] != generated_sequence[i] else 0 for i in range(len(original_sequence)))
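A quick illustration (not part of mlm_generate_utils.py) of how mask_for_scaffold treats case: uppercase and lowercase residues mark the two region types, so the two modes mask complementary parts of the sequence.

from mlm_generate_utils import mask_for_scaffold

seq = "MKTayiAKQR"
print(mask_for_scaffold(seq, "uppercase"))
# -> "<mask><mask><mask>AYI<mask><mask><mask><mask>"  (uppercase residues masked)
print(mask_for_scaffold(seq, "lowercase"))
# -> "MKT<mask><mask><mask>AKQR"                      (lowercase residues masked)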
noise_schedule.py ADDED
@@ -0,0 +1,153 @@
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ # Flags required to enable jit fusion kernels
7
+ torch._C._jit_set_profiling_mode(False)
8
+ torch._C._jit_set_profiling_executor(False)
9
+ torch._C._jit_override_can_fuse_on_cpu(True)
10
+ torch._C._jit_override_can_fuse_on_gpu(True)
11
+
12
+
13
+ def get_noise(config, dtype=torch.float32):
14
+ # NOTE: hard-coded to LogLinearNoise for now; the dispatch on
+ # config.noise.type below is unreachable until this early return is removed.
+ return LogLinearNoise()
15
+
16
+ if config.noise.type == 'geometric':
17
+ return GeometricNoise(config.noise.sigma_min,
18
+ config.noise.sigma_max)
19
+ elif config.noise.type == 'loglinear':
20
+ return LogLinearNoise()
21
+ elif config.noise.type == 'cosine':
22
+ return CosineNoise()
23
+ elif config.noise.type == 'cosinesqr':
24
+ return CosineSqrNoise()
25
+ elif config.noise.type == 'linear':
26
+ return Linear(config.noise.sigma_min,
27
+ config.noise.sigma_max,
28
+ dtype)
29
+ else:
30
+ raise ValueError(f'{config.noise.type} is not a valid noise')
31
+
32
+
33
+ def binary_discretization(z):
34
+ z_hard = torch.sign(z)
35
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
36
+ return z_soft + (z_hard - z_soft).detach()
37
+
38
+
39
+ class Noise(abc.ABC, nn.Module):
40
+ """
41
+ Baseline forward method to get the total + rate of noise at a timestep
42
+ """
43
+ def forward(self, t):
44
+ # Assume time goes from 0 to 1
45
+ return self.total_noise(t), self.rate_noise(t)
46
+
47
+ @abc.abstractmethod
48
+ def rate_noise(self, t):
49
+ """
50
+ Rate of change of noise ie g(t)
51
+ """
52
+ pass
53
+
54
+ @abc.abstractmethod
55
+ def total_noise(self, t):
56
+ """
57
+ Total noise ie \int_0^t g(t) dt + g(0)
58
+ """
59
+ pass
60
+
61
+
62
+ class CosineNoise(Noise):
63
+ def __init__(self, eps=1e-3):
64
+ super().__init__()
65
+ self.eps = eps
66
+
67
+ def rate_noise(self, t):
68
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
69
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
70
+ scale = torch.pi / 2
71
+ return scale * sin / (cos + self.eps)
72
+
73
+ def total_noise(self, t):
74
+ cos = torch.cos(t * torch.pi / 2)
75
+ return - torch.log(self.eps + (1 - self.eps) * cos)
76
+
77
+
78
+ class CosineSqrNoise(Noise):
79
+ def __init__(self, eps=1e-3):
80
+ super().__init__()
81
+ self.eps = eps
82
+
83
+ def rate_noise(self, t):
84
+ cos = (1 - self.eps) * (
85
+ torch.cos(t * torch.pi / 2) ** 2)
86
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
87
+ scale = torch.pi / 2
88
+ return scale * sin / (cos + self.eps)
89
+
90
+ def total_noise(self, t):
91
+ cos = torch.cos(t * torch.pi / 2) ** 2
92
+ return - torch.log(self.eps + (1 - self.eps) * cos)
93
+
94
+
95
+ class Linear(Noise):
96
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
97
+ super().__init__()
98
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
99
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
100
+
101
+ def rate_noise(self, t):
102
+ return self.sigma_max - self.sigma_min
103
+
104
+ def total_noise(self, t):
105
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
106
+
107
+ def importance_sampling_transformation(self, t):
108
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
109
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
110
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
111
+ return (sigma_t - self.sigma_min) / (
112
+ self.sigma_max - self.sigma_min)
113
+
114
+
115
+ class GeometricNoise(Noise):
116
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
117
+ super().__init__()
118
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
119
+
120
+ def rate_noise(self, t):
121
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
122
+ self.sigmas[1].log() - self.sigmas[0].log())
123
+
124
+ def total_noise(self, t):
125
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
126
+
127
+
128
+ class LogLinearNoise(Noise):
129
+ """Log Linear noise schedule.
130
+
131
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and
132
+ ~1 when t varies from 0 to 1. Total noise is
133
+ -log(1 - (1 - eps) * t), so the masking probability
+ 1 - exp(-sigma(t)) equals (1 - eps) * t.
135
+ """
136
+ def __init__(self, eps=1e-3):
137
+ super().__init__()
138
+ self.eps = eps
139
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
140
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
141
+
142
+ def rate_noise(self, t):
143
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
144
+
145
+ def total_noise(self, t):
146
+ return -torch.log1p(-(1 - self.eps) * t)
147
+
148
+ def importance_sampling_transformation(self, t):
149
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
150
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
151
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
152
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
153
+ return t
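A small worked check (not part of noise_schedule.py) of the LogLinearNoise relationship described in its docstring: the masking probability 1 - exp(-sigma(t)), which diffusion.py computes as move_chance, grows linearly in t.

import torch
from noise_schedule import LogLinearNoise

noise = LogLinearNoise(eps=1e-3)
t = torch.tensor([0.25, 0.5, 1.0])
sigma, dsigma = noise(t)              # total noise and its rate at time t
move_chance = 1 - torch.exp(-sigma)   # probability a token is masked at time t
print(move_chance)                    # equals (1 - eps) * t: approximately [0.25, 0.50, 0.999]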
pl_data_loader.py ADDED
@@ -0,0 +1,819 @@
1
+ import functools
2
+ import itertools
3
+ import json
4
+ import math
5
+ import os
6
+ import re
7
+ import shutil
8
+ import typing
9
+ import urllib
10
+ import zipfile
11
+
12
+ import datasets
13
+ import fsspec
14
+ import requests
15
+ import tokenizers
16
+ import torch
17
+ import transformers
18
+
19
+ import utils
20
+
21
+ LOGGER = utils.get_logger(__name__)
22
+
23
+
24
+ def wt_detokenizer(string):
25
+ # contractions
26
+ string = string.replace("s '", "s'")
27
+ string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
28
+ # number separators
29
+ string = string.replace(" @-@ ", "-")
30
+ string = string.replace(" @,@ ", ",")
31
+ string = string.replace(" @.@ ", ".")
32
+ # punctuation
33
+ string = string.replace(" : ", ": ")
34
+ string = string.replace(" ; ", "; ")
35
+ string = string.replace(" . ", ". ")
36
+ string = string.replace(" ! ", "! ")
37
+ string = string.replace(" ? ", "? ")
38
+ string = string.replace(" , ", ", ")
39
+ # double brackets
40
+ string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
41
+ string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
42
+ string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
43
+ string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
44
+ string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
45
+ # miscellaneous
46
+ string = string.replace("= = = =", "====")
47
+ string = string.replace("= = =", "===")
48
+ string = string.replace("= =", "==")
49
+ string = string.replace(" " + chr(176) + " ", chr(176))
50
+ string = string.replace(" \n", "\n")
51
+ string = string.replace("\n ", "\n")
52
+ string = string.replace(" N ", " 1 ")
53
+ string = string.replace(" 's", "'s")
54
+ return string
55
+
56
+
57
+ def ptb_detokenizer(x):
58
+ x = x.replace(" 's", "'s")
59
+ x = x.replace("s ' ", "s' ")
60
+ x = x.replace(" n't", "n't")
61
+ x = x.replace(" \n ", "\n")
62
+ x = x.replace("\\/", "/")
63
+ for _ in range(10):
64
+ x = x.replace(" N ", " 1 ")
65
+ x = x.replace("$ 1", "$1")
66
+ x = x.replace("# 1", "#1")
67
+ x = x.replace("<unk>", "?")
68
+ return x
69
+
70
+
71
+ def lm1b_detokenizer(x):
72
+ x = x.replace('http : / / ', 'http://')
73
+ x = x.replace('https : / / ', 'https://')
74
+ x = re.sub(r' \'(\w+)', r"'\1", x)
75
+ x = re.sub(r' (\w+) \. ', r' \1. ', x)
76
+ x = re.sub(r' (\w+) \.$', r' \1.', x)
77
+ x = x.replace(' ? ', '? ')
78
+ x = re.sub(r' \?$', '?', x)
79
+ x = x.replace(' ! ', '! ')
80
+ x = re.sub(r' \!$', '!', x)
81
+ x = x.replace(' , ', ', ')
82
+ x = x.replace(' : ', ': ')
83
+ x = x.replace(' ; ', '; ')
84
+ x = x.replace(' / ', '/')
85
+ x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
86
+ x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
87
+ x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
88
+ x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
89
+ x = x.replace('$ ', '$')
90
+ x = x.replace('£ ', '£')
91
+ return x
92
+
93
+
94
+ def lambada_detokenizer(text):
95
+ text = text.replace("“", '"')
96
+ text = text.replace("”", '"')
97
+ return '\n'+text.strip()
98
+
99
+
100
+ def scientific_papers_detokenizer(x):
101
+ x = wt_detokenizer(x)
102
+ x = lm1b_detokenizer(x)
103
+ return x
104
+
105
+
106
+ class Text8Tokenizer(transformers.PreTrainedTokenizer):
107
+ def __init__(
108
+ self,
109
+ bos_token='[BOS]',
110
+ eos_token='[EOS]',
111
+ sep_token='[SEP]',
112
+ cls_token='[CLS]',
113
+ pad_token='[PAD]',
114
+ mask_token='[MASK]',
115
+ unk_token='[UNK]',
116
+ **kwargs):
117
+ self.characters = list('abcdefghijklmnopqrstuvwxyz ')
118
+ self._vocab_str_to_int = {
119
+ '[CLS]': 0,
120
+ '[SEP]': 1,
121
+ '[BOS]': 2,
122
+ '[EOS]': 3,
123
+ '[MASK]': 4,
124
+ '[PAD]': 5,
125
+ '[RESERVED]': 6,
126
+ '[UNK]': 7,
127
+ ** {ch: i + 8 for i, ch in enumerate(self.characters)}}
128
+ self._vocab_int_to_str = {
129
+ v: k for k, v in self._vocab_str_to_int.items()}
130
+ super().__init__(
131
+ bos_token=bos_token,
132
+ eos_token=eos_token,
133
+ sep_token=sep_token,
134
+ cls_token=cls_token,
135
+ pad_token=pad_token,
136
+ mask_token=mask_token,
137
+ unk_token=unk_token,
138
+ **kwargs)
139
+
140
+ @property
141
+ def vocab_size(self) -> int:
142
+ return len(self._vocab_str_to_int)
143
+
144
+ def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
145
+ return list(text.lower())
146
+
147
+ def _convert_token_to_id(self, token: str) -> int:
148
+ return self._vocab_str_to_int.get(
149
+ token, self._vocab_str_to_int['[UNK]'])
150
+
151
+ def _convert_id_to_token(self, index: int) -> str:
152
+ return self._vocab_int_to_str[index]
153
+
154
+ def convert_tokens_to_string(self, tokens):
155
+ return ''.join(tokens)
156
+
157
+ def get_vocab(self) -> typing.Dict[str, int]:
158
+ return self._vocab_str_to_int
159
+
160
+
161
+ def get_lambada_test_dataset():
162
+ url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"
163
+
164
+ def read_jsonl_to_list(url):
165
+ response = requests.get(url, stream=True)
166
+ data_list = []
167
+
168
+ # Process each line in the response content
169
+ for line in response.iter_lines(decode_unicode=True):
170
+ if line:
171
+ data = json.loads(line)
172
+ data_list.append(data)
173
+
174
+ return data_list
175
+
176
+ lambada_data = read_jsonl_to_list(url)
177
+ dataset = datasets.Dataset.from_list(lambada_data)
178
+ return dataset
179
+
180
+ def get_text8_dataset(cache_dir, max_seq_length=256,
181
+ drop_last=True, crop_train=False):
182
+ """Adapted from:
183
+ https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344
184
+
185
+ Args:
186
+ cache_dir: str, path to cache directory.
187
+ max_seq_length: int, maximum length of sequences.
188
+ (default: 256, as in D3PM codebase.)
189
+ drop_last: bool, whether to drop the last incomplete
190
+ batch. (default: True, as in D3PM codebase.)
191
+ crop_train: bool, whether to subsample contiguous
192
+ subsequences from training example. serves to
193
+ make sure transformer models with absolute position
194
+ embeddings do not have incorrect position-wise
195
+ marginals. (default: False, but necessary to match D3PM AR)
196
+
197
+ Returns:
198
+ dataset: dataset.DatasetDict, with keys 'train',
199
+ 'valid', 'test'.
200
+ """
201
+ url = 'http://mattmahoney.net/dc/text8.zip'
202
+ if not crop_train:
203
+ cache_dir = f'{cache_dir}/text8'
204
+ else:
205
+ cache_dir = f'{cache_dir}/text8-crop-train'
206
+ split_names = ['train', 'validation', 'test']
207
+ if not all([
208
+ utils.fsspec_exists(os.path.join(cache_dir, split))
209
+ for split in split_names
210
+ ]):
211
+ # Check if raw data exists
212
+ raw_cache_dir = os.path.join(cache_dir, 'raw_data')
213
+ if not all([
214
+ utils.fsspec_exists(
215
+ os.path.join(raw_cache_dir, f'text8.{split}.txt'))
216
+ for split in split_names
217
+ ]):
218
+ if not utils.fsspec_exists(
219
+ os.path.join(raw_cache_dir, 'text8.zip')):
220
+ utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
221
+ LOGGER.info('Downloading text8 from URL {}.'.format(url))
222
+ with urllib.request.urlopen(url) as in_stream:
223
+ with open(os.path.join(raw_cache_dir, 'text8.zip'), 'wb') as out_file:
224
+ shutil.copyfileobj(in_stream, out_file)
225
+
226
+ with fsspec.open(
227
+ os.path.join(raw_cache_dir, 'text8.zip'),
228
+ 'rb') as f:
229
+ rawdata = zipfile.ZipFile(f).read(
230
+ 'text8').decode('utf-8')
231
+
232
+ # Splits taken from D3PM codebase
233
+ splits = {
234
+ 'train': rawdata[:90000000],
235
+ 'validation': rawdata[90000000: 95000000],
236
+ 'test': rawdata[95000000:],
237
+ }
238
+
239
+ for split, data in splits.items():
240
+ _path = os.path.join(raw_cache_dir,
241
+ f'text8.{split}.txt')
242
+ with fsspec.open(_path, 'w') as f:
243
+ f.write(data)
244
+ else:
245
+ splits = {}
246
+ for split in split_names:
247
+ _path = os.path.join(raw_cache_dir,
248
+ f'text8.{split}.txt')
249
+ with fsspec.open(_path, 'r') as f:
250
+ splits[split] = f.read()
251
+
252
+ # Chunk and save as datasets.DatasetDict
253
+ def chunks(lst, n):
254
+ """Yield successive n-sized chunks from lst."""
255
+ for i in range(0, len(lst), n):
256
+ yield lst[i:i + n]
257
+
258
+ dataset_dict = {}
259
+ for k, v in splits.items():
260
+ if k == 'train' and crop_train == True:
261
+ chunk_size = 2 * max_seq_length
262
+ else:
263
+ chunk_size = max_seq_length
264
+ text = list(chunks(v, chunk_size))
265
+ if drop_last and len(text[-1]) < chunk_size:
266
+ text = text[:-1]
267
+ dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
268
+ dataset = datasets.DatasetDict(dataset_dict)
269
+ dataset.save_to_disk(cache_dir)
270
+ else:
271
+ dataset = datasets.load_from_disk(cache_dir)
272
+
273
+ return dataset
274
+
275
+
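+ # Usage sketch: materialize the text8 DatasetDict with the helper above. The
+ # cache path is a placeholder; the first call downloads and unzips text8
+ # (roughly 100 MB of raw text), later calls reload the saved DatasetDict.
+ def _text8_splits_example(cache_dir='/tmp/text8_cache'):
+     dataset = get_text8_dataset(cache_dir, max_seq_length=256)
+     # Each record is a 256-character chunk, e.g. dataset['train'][0]['text'].
+     return {split: dataset[split].num_rows for split in dataset}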
276
+ def _group_texts(examples, block_size, bos, eos):
277
+ # Concatenate all texts.
278
+ concatenated_examples = list(itertools.chain(* examples['input_ids']))
279
+ total_length = len(concatenated_examples)
280
+ # TODO(yair): look into not dropping the remainder but rather padding it.
281
+ # We drop the small remainder, and if the total_length < block_size - 2
282
+ # we exclude this batch and return an empty dict.
283
+ # We could add padding if the model supported it instead of
284
+ # this drop, you can customize this part to your needs.
285
+ new_block_size = block_size - 2 # [BOS] and [EOS] to be added
286
+ total_length = (total_length // new_block_size) * new_block_size
287
+ # Split by chunks of max_len.
288
+ result = {}
289
+ _values = []
290
+ _attn_masks = []
291
+ for i in range(0, total_length, new_block_size):
292
+ _values.append(
293
+ [bos]
294
+ + concatenated_examples[i : i + new_block_size]
295
+ + [eos])
296
+ _attn_masks.append(torch.ones(block_size))
297
+ result['input_ids'] = _values
298
+ result['attention_mask'] = _attn_masks
299
+ return result
300
+
301
+
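+ # Worked toy example of the wrapping above (illustrative ids; 0/1 stand in
+ # for the real BOS/EOS): the tokenized examples are concatenated and
+ # re-chunked into blocks of exactly `block_size` tokens, each bracketed by
+ # BOS/EOS.
+ def _group_texts_example():
+     examples = {'input_ids': [[10, 11, 12], [13, 14], [15, 16, 17, 18]]}
+     out = _group_texts(examples, block_size=5, bos=0, eos=1)
+     # out['input_ids'] == [[0, 10, 11, 12, 1], [0, 13, 14, 15, 1], [0, 16, 17, 18, 1]]
+     return out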
302
+ def get_dataset(
303
+ dataset_name, tokenizer, wrap, mode, cache_dir,
304
+ block_size=1024, num_proc=len(os.sched_getaffinity(0)), streaming=False):
305
+ if wrap:
306
+ filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped.dat'
307
+ else:
308
+ filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped.dat'
309
+ _path = os.path.join(cache_dir, filename)
310
+
311
+ if utils.fsspec_exists(_path):
312
+ LOGGER.info(f'Loading data from: {_path}')
313
+ return datasets.load_from_disk(_path).with_format('torch')
314
+ LOGGER.info(f'Generating new data at: {_path}')
315
+
316
+ crop_train = dataset_name == 'text8-crop'
317
+ if mode == 'train' and crop_train:
318
+ # double block size for sub-sampling
319
+ block_size *= 2
320
+
321
+ if dataset_name == 'wikitext103':
322
+ dataset = datasets.load_dataset(
323
+ 'wikitext',
324
+ name='wikitext-103-raw-v1',
325
+ cache_dir=cache_dir)
326
+ elif dataset_name == 'wikitext2':
327
+ dataset = datasets.load_dataset(
328
+ 'wikitext',
329
+ name='wikitext-2-raw-v1',
330
+ cache_dir=cache_dir)
331
+ elif dataset_name == 'ptb':
332
+ dataset = datasets.load_dataset(
333
+ 'ptb_text_only', cache_dir=cache_dir)
334
+ elif dataset_name == 'lambada':
335
+ dataset = get_lambada_test_dataset()
336
+ elif dataset_name == 'text8':
337
+ assert wrap
338
+ dataset = get_text8_dataset(
339
+ cache_dir, max_seq_length=block_size)
340
+ elif dataset_name == 'text8-crop':
341
+ dataset = get_text8_dataset(
342
+ cache_dir, max_seq_length=block_size, crop_train=True)
343
+ elif dataset_name == 'openwebtext-train':
344
+ dataset = datasets.load_dataset(
345
+ 'openwebtext',
346
+ split='train[:-100000]',
347
+ cache_dir=cache_dir,
348
+ streaming=streaming)
349
+ elif dataset_name == 'openwebtext-valid':
350
+ dataset = datasets.load_dataset(
351
+ 'openwebtext',
352
+ split='train[-100000:]',
353
+ cache_dir=cache_dir,
354
+ streaming=streaming)
355
+ elif dataset_name == 'scientific_papers_arxiv':
356
+ dataset = datasets.load_dataset(
357
+ 'scientific_papers', 'arxiv',
358
+ trust_remote_code=True,
359
+ cache_dir=cache_dir,
360
+ streaming=streaming)
361
+ elif dataset_name == 'scientific_papers_pubmed':
362
+ dataset = datasets.load_dataset(
363
+ 'scientific_papers', 'pubmed',
364
+ trust_remote_code=True,
365
+ cache_dir=cache_dir,
366
+ streaming=streaming)
367
+ elif dataset_name == 'ag_news':
368
+ dataset = datasets.load_dataset(
369
+ 'ag_news',
370
+ cache_dir=cache_dir,
371
+ streaming=streaming)
372
+ else:
373
+ dataset = datasets.load_dataset(
374
+ dataset_name,
375
+ cache_dir=cache_dir,
376
+ streaming=streaming)
377
+
378
+ if dataset_name in ['lambada', 'openwebtext-train',
379
+ 'openwebtext-valid']:
380
+ data = dataset
381
+ else:
382
+ data = dataset[mode]
383
+
384
+ if dataset_name.startswith('wikitext'):
385
+ detokenizer = wt_detokenizer
386
+ elif dataset_name == 'ptb':
387
+ detokenizer = ptb_detokenizer
388
+ elif dataset_name == 'lm1b':
389
+ detokenizer = lm1b_detokenizer
390
+ elif dataset_name == 'lambada':
391
+ detokenizer = lambada_detokenizer
392
+ elif dataset_name.startswith('scientific_papers'):
393
+ detokenizer = scientific_papers_detokenizer
394
+ else:
395
+ detokenizer = None
396
+
397
+ def _apply_detokenizer(detokenizer):
398
+ def detok(text):
399
+ for i, t in enumerate(text, 0):
400
+ text[i] = detokenizer(t)
401
+ return text
402
+ return detok
403
+
404
+ EOS = tokenizer.encode(tokenizer.eos_token)[0]
405
+ BOS = tokenizer.encode(tokenizer.bos_token)[0]
406
+
407
+ def preprocess_and_tokenize(example):
408
+ if dataset_name == 'ptb':
409
+ text = example['sentence']
410
+ elif 'scientific_papers' in dataset_name:
411
+ text = example['article']
412
+ else:
413
+ text = example['text']
414
+
415
+ if detokenizer is not None:
416
+ text = _apply_detokenizer(detokenizer)(text)
417
+
418
+ tokenizer.padding_side = 'right'
419
+ tokenizer.truncation_side = 'right'
420
+
421
+ if wrap:
422
+ tokens = tokenizer(text,
423
+ add_special_tokens=False,
424
+ return_attention_mask=False,
425
+ return_token_type_ids=False)
426
+ tokens = {'input_ids':
427
+ [t + [EOS] for t in tokens['input_ids']]}
428
+ # Still missing BOS, but will be added in group_texts
429
+ else:
430
+ tokens = tokenizer(text,
431
+ max_length=block_size,
432
+ padding='max_length',
433
+ truncation=True,
434
+ add_special_tokens=True,
435
+ return_attention_mask=True,
436
+ return_token_type_ids=True)
437
+ return tokens
438
+
439
+ if streaming:
440
+ tokenized_dataset = data.map(
441
+ preprocess_and_tokenize,
442
+ batched=True,
443
+ desc='Tokenizing')
444
+ else:
445
+ tokenized_dataset = data.map(
446
+ preprocess_and_tokenize,
447
+ batched=True,
448
+ num_proc=num_proc,
449
+ load_from_cache_file=True,
450
+ desc='Tokenizing')
451
+ if dataset_name == 'ptb':
452
+ tokenized_dataset = tokenized_dataset.remove_columns(
453
+ 'sentence')
454
+ elif 'scientific_papers' in dataset_name:
455
+ tokenized_dataset = tokenized_dataset.remove_columns([
456
+ 'article', 'abstract', 'section_names'])
457
+ elif dataset_name == 'ag_news':
458
+ tokenized_dataset = tokenized_dataset.remove_columns(
459
+ ['text', 'label'])
460
+ else:
461
+ tokenized_dataset = tokenized_dataset.remove_columns(
462
+ 'text')
463
+
464
+ if not wrap:
465
+ tokenized_dataset.save_to_disk(_path)
466
+ return tokenized_dataset.with_format('torch')
467
+
468
+ group_texts = functools.partial(
469
+ _group_texts, block_size=block_size, bos=BOS, eos=EOS)
470
+ if streaming:
471
+ chunked_dataset = tokenized_dataset.map(
472
+ group_texts,
473
+ batched=True,
474
+ desc='Grouping')
475
+ else:
476
+ chunked_dataset = tokenized_dataset.map(
477
+ group_texts,
478
+ batched=True,
479
+ num_proc=num_proc,
480
+ load_from_cache_file=True,
481
+ desc='Grouping')
482
+ chunked_dataset.save_to_disk(_path)
483
+ chunked_dataset = chunked_dataset.with_format('torch')
484
+ return chunked_dataset
485
+
486
+
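+ # Usage sketch: build a wrapped WikiText-2 train split with the loader above.
+ # The GPT-2 tokenizer and the cache path are illustrative choices, not values
+ # taken from this repo's configs.
+ def _wikitext2_example(cache_dir='/tmp/mdlm_data'):
+     tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
+     return get_dataset('wikitext2', tokenizer, wrap=True, mode='train',
+                        cache_dir=cache_dir, block_size=1024)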
487
+ def get_tokenizer(config):
488
+ if config.data.tokenizer_name_or_path == 'text8':
489
+ tokenizer = Text8Tokenizer()
490
+ elif config.data.tokenizer_name_or_path == 'bert-base-uncased':
491
+ tokenizer = transformers.BertTokenizer.\
492
+ from_pretrained('bert-base-uncased')
493
+ else:
494
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
495
+ config.data.tokenizer_name_or_path)
496
+
497
+ if (isinstance(tokenizer, transformers.GPT2TokenizerFast)
498
+ or isinstance(tokenizer, transformers.GPT2Tokenizer)):
499
+ tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
500
+ (tokenizer.bos_token, tokenizer.bos_token_id),
501
+ (tokenizer.eos_token, tokenizer.eos_token_id))
502
+
503
+ # For wrapped batches:
504
+ # [BOS] sent1 [EOS] sent2-fragment [EOS]
505
+ # [BOS] sent2-fragment [EOS] sent3 [EOS]
506
+ if tokenizer.bos_token is None:
507
+ if tokenizer.cls_token is None:
508
+ raise AttributeError(
509
+ 'Tokenizer must have a bos_token or '
510
+ f'cls_token: {tokenizer}')
511
+ tokenizer.bos_token = tokenizer.cls_token
512
+ if tokenizer.eos_token is None:
513
+ if tokenizer.sep_token is None:
514
+ raise AttributeError(
515
+ 'Tokenizer must have an eos_token '
516
+ f'or sep_token: {tokenizer}')
517
+ tokenizer.eos_token = tokenizer.sep_token
518
+ if tokenizer.pad_token is None:
519
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
520
+
521
+ return tokenizer
522
+
523
+
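+ # Usage sketch: the fallback above lets a BERT-style tokenizer (which has no
+ # native BOS/EOS) reuse [CLS]/[SEP] as sequence delimiters. The config object
+ # here is a stand-in exposing only the field that get_tokenizer reads.
+ def _get_tokenizer_example():
+     import types
+     cfg = types.SimpleNamespace(
+         data=types.SimpleNamespace(tokenizer_name_or_path='bert-base-uncased'))
+     tokenizer = get_tokenizer(cfg)
+     return tokenizer.bos_token, tokenizer.eos_token  # ('[CLS]', '[SEP]')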
524
+ def get_dataloaders(config, tokenizer, skip_train=False,
525
+ skip_valid=False, valid_seed=None):
526
+ num_gpus = torch.cuda.device_count()
527
+ assert (config.loader.global_batch_size
528
+ == (config.loader.batch_size
529
+ * config.trainer.num_nodes
530
+ * num_gpus
531
+ * config.trainer.accumulate_grad_batches))
532
+ if config.loader.global_batch_size % (
533
+ num_gpus * config.trainer.accumulate_grad_batches) != 0:
534
+ raise ValueError(
535
+ f'Train Batch Size {config.loader.global_batch_size} '
536
+ f'not divisible by {num_gpus} gpus with accumulation '
537
+ f'{config.trainer.accumulate_grad_batches}.')
538
+ if config.loader.eval_global_batch_size % num_gpus != 0:
539
+ raise ValueError(
540
+ f'Eval Batch Size {config.loader.eval_global_batch_size} '
541
+ f'not divisible by {num_gpus}.')
542
+ if skip_train:
543
+ train_set = None
544
+ else:
545
+ train_set = get_dataset(
546
+ config.data.train,
547
+ tokenizer,
548
+ mode='train',
549
+ wrap=config.data.wrap,
550
+ #cache_dir=config.data.cache_dir,
551
+ block_size=config.model.length)
552
+
553
+ if config.data.valid in ['text8', 'lm1b', 'ag_news']:
554
+ validation_split = 'test'
555
+ else:
556
+ validation_split = 'validation'
557
+ if skip_valid:
558
+ valid_set = None
559
+ else:
560
+ valid_set = get_dataset(
561
+ config.data.valid,
562
+ tokenizer,
563
+ wrap=config.data.wrap,
564
+ mode=validation_split,
565
+ #cache_dir=config.data.cache_dir,
566
+ block_size=config.model.length,
567
+ streaming=False)
568
+
569
+ if skip_train:
570
+ train_loader = None
571
+ else:
572
+ train_loader = torch.utils.data.DataLoader(
573
+ train_set,
574
+ batch_size=config.loader.batch_size,
575
+ num_workers=config.loader.num_workers,
576
+ pin_memory=config.loader.pin_memory,
577
+ shuffle=not config.data.streaming,
578
+ persistent_workers=True)
579
+ train_loader.tokenizer = tokenizer
580
+ if skip_valid:
581
+ valid_loader = None
582
+ else:
583
+ if valid_seed is None:
584
+ shuffle_valid = False
585
+ generator = None
586
+ else:
587
+ shuffle_valid = True
588
+ generator = torch.Generator().manual_seed(valid_seed)
589
+ valid_loader = torch.utils.data.DataLoader(
590
+ valid_set,
591
+ batch_size=config.loader.eval_batch_size,
592
+ num_workers=config.loader.num_workers,
593
+ pin_memory=config.loader.pin_memory,
594
+ shuffle=shuffle_valid,
595
+ generator=generator)
596
+ # Will be used in generative perplexity calculation
597
+ valid_loader.tokenizer = tokenizer
598
+
599
+ return train_loader, valid_loader
600
+
601
+
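+ # Worked example of the consistency check above: with a global batch size of
+ # 8, one node, two GPUs and no gradient accumulation, each device must see
+ # 8 / (2 * 1 * 1) = 4 samples per step. The helper below restates that
+ # relation; its argument names are illustrative.
+ def _per_device_batch_size(global_batch_size, num_gpus, num_nodes=1,
+                            accumulate_grad_batches=1):
+     return global_batch_size // (num_gpus * num_nodes * accumulate_grad_batches)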
602
+ # Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
603
+
604
+
605
+ class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
606
+
607
+ def __init__(self, *args, generator=None, **kwargs):
608
+ # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
609
+ # which should be reproducible if pl.seed_everything was called beforehand.
610
+ # This means that changing the seed of the experiment will also change the
611
+ # sampling order.
612
+ if generator is None:
613
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
614
+ generator = torch.Generator().manual_seed(seed)
615
+ kwargs.pop('shuffle', None)
616
+ super().__init__(*args, generator=generator, **kwargs)
617
+ self.counter = 0
618
+ self.restarting = False
619
+
620
+ def state_dict(self):
621
+ return {'random_state': self.state,
622
+ 'counter': self.counter}
623
+
624
+ def load_state_dict(self, state_dict):
625
+ self.generator.set_state(state_dict.get('random_state'))
626
+ self.counter = state_dict['counter']
627
+ # self.start_counter = self.counter
628
+ self.restarting = True
629
+
630
+ # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
631
+ # epoch, and subsequent epoch will have very few batches.
632
+
633
+ def __iter__(self) -> typing.Iterator[int]:
634
+ n = len(self.data_source)
635
+
636
+ self.state = self.generator.get_state()
637
+ indices = torch.randperm(n, generator=self.generator).tolist()
638
+
639
+ if not self.restarting:
640
+ self.counter = 0
641
+ else:
642
+ indices = indices[self.counter:]
643
+ self.restarting = False
644
+
645
+ for index in indices:
646
+ self.counter += 1
647
+ yield index
648
+
649
+ self.counter = 0
650
+
651
+
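+ # Usage sketch: checkpoint the sampler mid-epoch and resume. After
+ # load_state_dict, __iter__ draws a permutation from the restored generator
+ # state and skips the first `counter` indices, so the epoch continues rather
+ # than restarting from scratch.
+ def _fault_tolerant_resume_example():
+     sampler = RandomFaultTolerantSampler(range(10))
+     it = iter(sampler)
+     served = [next(it) for _ in range(3)]
+     state = sampler.state_dict()            # what a checkpoint would store
+     resumed = RandomFaultTolerantSampler(range(10))
+     resumed.load_state_dict(state)
+     remaining = list(resumed)               # 7 indices left to finish the epoch
+     return served, remaining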
652
+ class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
653
+
654
+ def __init__(self, *args, **kwargs):
655
+ super().__init__(*args, **kwargs)
656
+ self.counter = 0
657
+ self.restarting = False
658
+
659
+ def state_dict(self):
660
+ return {'epoch': self.epoch, 'counter': self.counter}
661
+
662
+ def load_state_dict(self, state_dict):
663
+ self.epoch = state_dict['epoch']
664
+ self.counter = state_dict['counter']
665
+ self.restarting = True
666
+
667
+ # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
668
+ # epoch, and subsequent epoch will have very few batches.
669
+ def __iter__(self):
670
+ if self.shuffle:
671
+ # deterministically shuffle based on epoch and seed
672
+ g = torch.Generator()
673
+ g.manual_seed(self.seed + self.epoch)
674
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
675
+ else:
676
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
677
+
678
+ if not self.drop_last:
679
+ # add extra samples to make it evenly divisible
680
+ padding_size = self.total_size - len(indices)
681
+ if padding_size <= len(indices):
682
+ indices += indices[:padding_size]
683
+ else:
684
+ indices += (indices * math.ceil(
685
+ padding_size / len(indices)))[:padding_size]
686
+ else:
687
+ # remove tail of data to make it evenly divisible.
688
+ indices = indices[:self.total_size]
689
+ assert len(indices) == self.total_size
690
+
691
+ # subsample
692
+ indices = indices[self.rank:self.total_size:self.num_replicas]
693
+ assert len(indices) == self.num_samples
694
+
695
+ if not self.restarting:
696
+ self.counter = 0
697
+ else:
698
+ indices = indices[self.counter:]
699
+ self.restarting = False
700
+
701
+ for index in indices:
702
+ self.counter += 1
703
+ yield index
704
+
705
+ self.counter = 0
706
+
707
+ from torch.utils.data import Dataset, DataLoader
708
+ import lightning.pytorch as pl
709
+ from functools import partial
710
+ import sys
711
+
712
+ class CustomDataset(torch.utils.data.Dataset):
713
+ def __init__(self, dataset, indices):
714
+ self.dataset = dataset
715
+ self.indices = indices
716
+
717
+ def __len__(self):
718
+ return len(self.indices)
719
+
720
+ def __getitem__(self, idx):
721
+ actual_idx = int(self.indices[idx])
722
+ item = self.dataset[actual_idx]
723
+ return item
724
+
725
+ def membrane_collate_fn(batch, tokenizer):
726
+ """Custom data collator that masks TM/soluble residues for focused training"""
727
+ MAX_LENGTH = 1024
728
+ sequences = [item['Sequence'].upper() for item in batch]
729
+
730
+ masks = []
731
+ for item in batch:
732
+ if item["Label"] == 0:
733
+ mask = [1 if i.isupper() else 0 for i in item["Sequence"]]
734
+ else:
735
+ mask = [0 if i.isupper() else 1 for i in item["Sequence"]]
736
+ mask = [1] + mask
737
+ if len(mask) > MAX_LENGTH: # Truncate
738
+ mask = mask[:MAX_LENGTH]
739
+ elif len(mask) < MAX_LENGTH: # Pad
740
+ mask += [1] * (MAX_LENGTH - len(mask))
741
+
742
+ masks.append(torch.as_tensor(mask))
743
+
744
+ mask_t = torch.stack(masks, dim=0)
745
+ tokens = tokenizer(sequences, return_tensors='pt', padding='max_length', truncation=True, max_length=MAX_LENGTH)
746
+
747
+ return {
748
+ 'input_ids': tokens['input_ids'],
749
+ 'attention_mask': tokens['attention_mask'],
750
+ 'mask': mask_t
751
+ }
752
+
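+ # Worked toy example of the mask built above (the sequence is synthetic, not
+ # a real protein): letter case marks the annotated region, `Label` selects
+ # which region keeps a 1, and the leading 1 covers the BOS/CLS position added
+ # by the tokenizer. The collator then pads the mask with 1s up to MAX_LENGTH.
+ def _focus_mask_example():
+     item = {'Sequence': 'mmAAAAmm', 'Label': 0}
+     residue_mask = [1 if ch.isupper() else 0 for ch in item['Sequence']]
+     return [1] + residue_mask  # [1, 0, 0, 1, 1, 1, 1, 0, 0]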
753
+ def wrap_collate_fn(batch, tokenizer):
754
+ """Standard data collator that wraps sequences over padding them"""
755
+ # Define sequence size
756
+ chunk_size = 1024
757
+ eos_placeholder = "k"
758
+ eos = "<eos>"
759
+
760
+ # Wrap sequences by collecting and splitting them into chunks
761
+ # From MDLM paper: insert <eos> at start/end of chunks and in between sequences
762
+ sequences = eos_placeholder.join([item['Sequence'].upper() for item in batch])
763
+ sequences = eos_placeholder + sequences + eos_placeholder
764
+ wrapped_sequences = [sequences[i:i+chunk_size] for i in range(0, len(sequences), chunk_size)]
765
+ for idx, seq in enumerate(wrapped_sequences):
766
+ wrapped_sequences[idx] = seq.replace(eos_placeholder, eos)
767
+
768
+ # Tokenize for input ids and attention masks
769
+ tokens = tokenizer(wrapped_sequences, return_tensors='pt', padding=True)
770
+
771
+ return {
772
+ "input_ids": tokens['input_ids'],
773
+ "attention_mask": tokens['attention_mask']
774
+ }
775
+
776
+
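+ # Worked toy example of the wrapping above with a tiny chunk size: sequences
+ # are joined into one stream delimited by a lowercase placeholder (which
+ # cannot collide with the upper-cased residues) and cut into fixed-size
+ # chunks before the placeholder is expanded to <eos>.
+ def _wrap_example(chunk_size=8):
+     batch = [{'Sequence': 'MKT'}, {'Sequence': 'GGA'}]
+     stream = 'k' + 'k'.join(item['Sequence'].upper() for item in batch) + 'k'
+     chunks = [stream[i:i + chunk_size] for i in range(0, len(stream), chunk_size)]
+     # -> ['<eos>MKT<eos>GGA', '<eos>']
+     return [c.replace('k', '<eos>') for c in chunks]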
777
+
778
+ def collate_fn(batch, tokenizer):
779
+ """Standard data collator that truncates/pad sequences based on max_length"""
780
+ sequences = [item['Sequence'].upper() for item in batch]
781
+ max_len = max([len(seq) for seq in sequences])
782
+ #labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
783
+
784
+ tokens = tokenizer(sequences, return_tensors='pt', padding='max_length', truncation=True, max_length=1024)
785
+
786
+ #attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)
787
+
788
+ return {
789
+ 'input_ids': tokens['input_ids'],
790
+ 'attention_mask': tokens['attention_mask']
791
+ }
792
+
793
+ class CustomDataModule(pl.LightningDataModule):
794
+ def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size: int=8, collate_fn=collate_fn):
795
+ super().__init__()
796
+ self.train_dataset = train_dataset
797
+ self.val_dataset = val_dataset
798
+ self.test_dataset = test_dataset
799
+ self.batch_size = batch_size
800
+ self.tokenizer = tokenizer
801
+ self.collate_fn = collate_fn
802
+
803
+ def train_dataloader(self):
804
+ return DataLoader(self.train_dataset, batch_size=self.batch_size,
805
+ collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
806
+ num_workers=8, pin_memory=True)
807
+
808
+
809
+ def val_dataloader(self):
810
+ return DataLoader(self.val_dataset, batch_size=self.batch_size,
811
+ collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
812
+ num_workers=8, pin_memory=True)
813
+
814
+ def test_dataloader(self):
815
+ return DataLoader(self.test_dataset, batch_size=self.batch_size,
816
+ collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
817
+ num_workers=8, pin_memory=True)
818
+
819
+
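+ # Usage sketch: wiring the pieces above together for CSV splits that expose
+ # 'Sequence' and 'Label' columns (an assumption about the data files). The
+ # ESM-2 checkpoint name is an illustrative choice.
+ def _build_membrane_datamodule(train_csv, val_csv, test_csv, batch_size=8):
+     import pandas as pd
+     tokenizer = transformers.AutoTokenizer.from_pretrained(
+         'facebook/esm2_t30_150M_UR50D')
+     train, val, test = (pd.read_csv(path).to_dict('records')
+                         for path in (train_csv, val_csv, test_csv))
+     return CustomDataModule(train, val, test, tokenizer,
+                             batch_size=batch_size,
+                             collate_fn=membrane_collate_fn)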
utils.py ADDED
@@ -0,0 +1,230 @@
1
+ """Console logger utilities.
2
+
3
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
4
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
5
+ """
6
+
7
+ import logging
8
+ import math
9
+
10
+ import fsspec
11
+ import lightning
12
+ import torch
13
+ from timm.scheduler import CosineLRScheduler
14
+
15
+
16
+ def fsspec_exists(filename):
17
+ """Check if a file exists using fsspec."""
18
+ fs, _ = fsspec.core.url_to_fs(filename)
19
+ return fs.exists(filename)
20
+
21
+
22
+ def fsspec_listdir(dirname):
23
+ """Listdir in manner compatible with fsspec."""
24
+ fs, _ = fsspec.core.url_to_fs(dirname)
25
+ return fs.ls(dirname)
26
+
27
+
28
+ def fsspec_mkdirs(dirname, exist_ok=True):
29
+ """Mkdirs in manner compatible with fsspec."""
30
+ fs, _ = fsspec.core.url_to_fs(dirname)
31
+ fs.makedirs(dirname, exist_ok=exist_ok)
32
+
33
+
34
+ def print_nans(tensor, name):
35
+ if torch.isnan(tensor).any():
36
+ print(name, tensor)
37
+
38
+
39
+ class CosineDecayWarmupLRScheduler(
40
+ CosineLRScheduler,
41
+ torch.optim.lr_scheduler._LRScheduler):
42
+ """Wrap timm.scheduler.CosineLRScheduler
43
+ Enables calling scheduler.step() without passing in epoch.
44
+ Supports resuming as well.
45
+ Adapted from:
46
+ https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
47
+ """
48
+
49
+ def __init__(self, *args, **kwargs):
50
+ super().__init__(*args, **kwargs)
51
+ self._last_epoch = -1
52
+ self.step(epoch=0)
53
+
54
+ def step(self, epoch=None):
55
+ if epoch is None:
56
+ self._last_epoch += 1
57
+ else:
58
+ self._last_epoch = epoch
59
+ # We call either step or step_update, depending on
60
+ # whether we're using the scheduler every epoch or every
61
+ # step.
62
+ # Otherwise, lightning will always call step (i.e.,
63
+ # meant for each epoch), and if we set scheduler
64
+ # interval to "step", then the learning rate update will
65
+ # be wrong.
66
+ if self.t_in_epochs:
67
+ super().step(epoch=self._last_epoch)
68
+ else:
69
+ super().step_update(num_updates=self._last_epoch)
70
+
71
+
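+ # Usage sketch: a per-step cosine schedule with linear warmup, returned the
+ # way a LightningModule's configure_optimizers would with interval='step'.
+ # All hyperparameter values below are placeholders.
+ def _cosine_warmup_example(model):
+     optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+     scheduler = CosineDecayWarmupLRScheduler(
+         optimizer, t_initial=10_000, lr_min=1e-6,
+         warmup_t=500, warmup_lr_init=1e-6, t_in_epochs=False)
+     return {'optimizer': optimizer,
+             'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}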
72
+ class LoggingContext:
73
+ """Context manager for selective logging."""
74
+ def __init__(self, logger, level=None, handler=None, close=True):
75
+ self.logger = logger
76
+ self.level = level
77
+ self.handler = handler
78
+ self.close = close
79
+
80
+ def __enter__(self):
81
+ if self.level is not None:
82
+ self.old_level = self.logger.level
83
+ self.logger.setLevel(self.level)
84
+ if self.handler:
85
+ self.logger.addHandler(self.handler)
86
+
87
+ def __exit__(self, et, ev, tb):
88
+ if self.level is not None:
89
+ self.logger.setLevel(self.old_level)
90
+ if self.handler:
91
+ self.logger.removeHandler(self.handler)
92
+ if self.handler and self.close:
93
+ self.handler.close()
94
+
95
+
96
+ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
97
+ """Initializes multi-GPU-friendly python logger."""
98
+
99
+ logger = logging.getLogger(name)
100
+ logger.setLevel(level)
101
+
102
+ # this ensures all logging levels get marked with the rank zero decorator
103
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
104
+ for level in ('debug', 'info', 'warning', 'error',
105
+ 'exception', 'fatal', 'critical'):
106
+ setattr(logger,
107
+ level,
108
+ lightning.pytorch.utilities.rank_zero_only(
109
+ getattr(logger, level)))
110
+
111
+ return logger
112
+
113
+
114
+ class Sampler:
115
+ def __init__(self, shape):
116
+ self.shape = shape
117
+
118
+ def _sampling_noise(self):
119
+ pass
120
+
121
+ def _hard_sample(self, logits):
122
+ pass
123
+
124
+ def _soft_sample(self, logits):
125
+ return 0
126
+
127
+ def sample(self, logits):
128
+ noise = self._sampling_noise()
129
+ noise = noise[: logits.shape[0], :]
130
+ logits = logits + noise.to(
131
+ dtype=logits.dtype, device=logits.device)
132
+ hard_sample = self._hard_sample(logits)
133
+ soft_sample = self._soft_sample(logits)
134
+ return soft_sample + (hard_sample - soft_sample).detach()
135
+
136
+
137
+ class TopKSampler(Sampler):
138
+ def __init__(self, k, shape, gamma_tau=1.0):
139
+ super().__init__(shape)
140
+ self.k = k
141
+ self.gamma_tau = gamma_tau
142
+ self.num_betas = 10
143
+ self.sampler = torch.distributions.gamma.Gamma(
144
+ 1 / k * torch.ones(self.num_betas, * self.shape), 1.0)
145
+
146
+ def _sampling_noise(self):
147
+ noise = self.sampler.sample()
148
+ beta = self.k / torch.arange(1, self.num_betas + 1, 1,
149
+ dtype=torch.float32)
150
+ beta = beta[:, None, None]
151
+ assert beta.ndim == noise.ndim
152
+ s = noise / beta
153
+ s = torch.sum(s, axis=0)
154
+ s = s - math.log(10.0)
155
+ s = self.gamma_tau * (s / self.k)
156
+ return s
157
+
158
+ def _hard_sample(self, logits):
159
+ assert logits.ndim == 2
160
+ thresholds, _ = torch.sort(logits, dim=-1)
161
+ thresholds = thresholds[:, - self.k][:, None]
162
+ return (logits >= thresholds).type(logits.dtype)
163
+
164
+ def _soft_sample(self, logits):
165
+ soft_top_k = logits - torch.mean(logits, dim=-1,
166
+ keepdim=True)
167
+ return soft_top_k / torch.norm(soft_top_k, dim=-1,
168
+ keepdim=True)
169
+
170
+
171
+ class DeterministicTopK(TopKSampler):
172
+ def __init__(self, k):
173
+ super().__init__(k, shape=(1, 1))
174
+
175
+ def _sampling_noise(self):
176
+ return 0
177
+
178
+ def discreize(self, x):
179
+ hard_sample = self._hard_sample(x)
180
+ soft_sample = self._soft_sample(x)
181
+ return soft_sample + (hard_sample - soft_sample).detach()
182
+
183
+ class GumbelSampler(Sampler):
184
+
185
+ def __init__(self, shape, temperature=1.0):
186
+ super().__init__(shape)
187
+ self.temperature = temperature
188
+
189
+ def _sampling_noise(self):
190
+ return - (1e-10 - (
191
+ torch.rand(* self.shape) + 1e-10).log()).log()
192
+
193
+ def _hard_sample(self, logits):
194
+ assert logits.ndim == 2
195
+ indices = torch.argmax(logits, dim=-1)
196
+ zeros = logits * 0
197
+ ones = torch.ones_like(logits[:, :, :1])
198
+ return torch.scatter(zeros, -1, indices[:, :, None],
199
+ ones)
200
+
201
+ def _soft_sample(self, logits):
202
+ return torch.nn.functional.softmax(
203
+ logits / self.temperature, dim=-1)
204
+
205
+
206
+ class BinarySampler(GumbelSampler):
207
+
208
+ def sample(self, probs):
209
+ # TODO(subhamsahoo): use the temperature parameter.
210
+ pos_noise = self._sampling_noise().to(
211
+ dtype=probs.dtype, device=probs.device)
212
+ neg_noise = self._sampling_noise().to(
213
+ dtype=probs.dtype, device=probs.device)
214
+ del_noise_exp = (neg_noise - pos_noise).exp()
215
+ hard_sample = (probs * (1 + del_noise_exp)
216
+ > 1).to(probs.dtype)
217
+ soft_sample = probs / (probs + (1 - probs) * del_noise_exp)
218
+ return soft_sample + (hard_sample - soft_sample).detach()
219
+
220
+
221
+ class GaussianSampler:
222
+ def __init__(self):
223
+ self.softplus = torch.nn.Softplus()
224
+
225
+ def sample(self, x):
226
+ assert x.ndim == 2
227
+ n = x.shape[-1] // 2
228
+ mu = x[:, :n]
229
+ sigma = self.softplus(x[:, n:]).sqrt()
230
+ return mu + sigma * torch.randn_like(mu)
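+ 
+ 
+ # Usage sketch: GaussianSampler applies the reparameterization trick. The
+ # network output is split into means and pre-softplus variance parameters,
+ # and the returned sample stays differentiable with respect to both halves.
+ def _gaussian_sampler_example():
+     sampler = GaussianSampler()
+     x = torch.randn(4, 16, requires_grad=True)   # 8 means + 8 variance params per row
+     z = sampler.sample(x)                        # shape: (4, 8)
+     return z.shape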