import torch
import config
import math
import sys
import os
from tqdm import tqdm
from torch.optim import AdamW
from transformers import AutoTokenizer
from diffusion import WrapESM, Diffusion
from data_loader import get_dataloaders

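# Fine-tune a partially unfrozen ESM model under a diffusion training objective and
# report average loss and perplexity on the train, validation, and test splits.
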
def save_hyperparams(ckpt_dir):
    """Write every upper-case attribute of the config module to hyperparameters.txt in ckpt_dir."""
    hyperparams_txt_file = os.path.join(ckpt_dir, "hyperparameters.txt")
    with open(hyperparams_txt_file, 'w') as f:
        for k, v in vars(config).items():
            if k.isupper():
                f.write(f"{k}: {v}\n")

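# NOTE: train_and_validate() and test() take their loss from `diffusion_model`, which is
# created at module level in the __main__ block below, so the script relies on being run
# directly rather than imported.
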
def train_and_validate(model, optimizer, device, train_loader, val_loader, num_epochs, ckpt_dir):
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()

        print(f"EPOCH {epoch+1}/{num_epochs}")
        sys.stderr.flush()
        total_loss = 0.0
        train_tokens = 0
        weighted_total_train_loss = 0.0

        # Update the progress bar in quarter-epoch chunks; max(1, ...) avoids a modulo-by-zero
        # on very small loaders.
        train_update_interval = max(1, len(train_loader) // 4)

        with tqdm(enumerate(train_loader), desc="Training batch", total=len(train_loader), leave=True, position=0, ncols=100) as trainbar:
            for step, inputs in trainbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                optimizer.zero_grad()
                outputs = model(**inputs)
                train_loss = diffusion_model.compute_loss(inputs['input_ids'], inputs['attention_mask'],
                                                          val=False).loss
                train_loss.backward()
                optimizer.step()

                total_loss += train_loss.item()
                weighted_total_train_loss += train_loss.item() * inputs['input_ids'].shape[1]
                train_tokens += inputs['input_ids'].shape[1]

                if (step+1) % train_update_interval == 0:
                    trainbar.update(train_update_interval)

        # Perplexity: exp of the sequence-length-weighted average batch loss, treated as a per-token NLL.
        avg_train_loss = total_loss / len(train_loader)
        avg_train_neg_log_likelihood = weighted_total_train_loss / train_tokens
        train_perplexity = math.exp(avg_train_neg_log_likelihood)

        # Save a checkpoint (and the hyperparameters used) after every epoch.
        train_ckpt_path = os.path.join(ckpt_dir, f'epoch{epoch+1}')
        model.save_model(train_ckpt_path)
        save_hyperparams(train_ckpt_path)

        if val_loader:
            model.eval()
            total_val_loss = 0.0
            weighted_total_val_loss = 0.0
            val_tokens = 0

            with torch.no_grad():
                val_update_interval = max(1, len(val_loader) // 4)

                with tqdm(enumerate(val_loader), desc='Validation batch', total=len(val_loader), leave=True, position=0) as valbar:
                    for step, inputs in valbar:
                        inputs = {k: v.to(device) for k, v in inputs.items()}
                        outputs = model(**inputs)
                        val_loss = diffusion_model.compute_loss(inputs['input_ids'], inputs['attention_mask'],
                                                                val=True).loss.item()

                        total_val_loss += val_loss
                        weighted_total_val_loss += val_loss * inputs['input_ids'].shape[1]
                        val_tokens += inputs['input_ids'].shape[1]

                        if (step+1) % val_update_interval == 0:
                            valbar.update(val_update_interval)

            avg_val_loss = total_val_loss / len(val_loader)
            avg_val_neg_log_likelihood = weighted_total_val_loss / val_tokens
            val_perplexity = math.exp(avg_val_neg_log_likelihood)

            # Keep a separate checkpoint of the best model seen so far by validation loss.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                val_ckpt_path = os.path.join(ckpt_dir, "best_model_epoch")
                model.save_model(val_ckpt_path)
                save_hyperparams(val_ckpt_path)

        print(f"Average train loss: {avg_train_loss}")
        print(f"Average train perplexity: {train_perplexity}\n")
        sys.stdout.flush()

        print(f"Average validation loss: {avg_val_loss}")
        print(f"Average validation perplexity: {val_perplexity}\n")
        sys.stdout.flush()

    return avg_train_loss, train_perplexity, avg_val_loss, val_perplexity

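# Evaluate on the held-out test split with the same diffusion loss in validation mode.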
def test(model, test_loader, device):
    model.to(device).eval()
    total_test_loss = 0.0
    weighted_total_test_loss = 0.0
    test_tokens = 0

    with torch.no_grad():
        for step, inputs in enumerate(test_loader):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            test_loss = diffusion_model.compute_loss(inputs['input_ids'], inputs['attention_mask'],
                                                     val=True).loss.item()

            total_test_loss += test_loss
            weighted_total_test_loss += test_loss * inputs['input_ids'].shape[1]
            test_tokens += inputs['input_ids'].shape[1]

    avg_test_loss = total_test_loss / len(test_loader)
    avg_test_neg_log_likelihood = weighted_total_test_loss / test_tokens
    test_perplexity = math.exp(avg_test_neg_log_likelihood)

    return avg_test_loss, test_perplexity

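# The `config` module is assumed to expose at least the attributes used below
# (names taken from this script, not from the actual config file):
#   config.MODEL_NAME            - model/tokenizer identifier passed to AutoTokenizer
#   config.Optim.LR              - learning rate for AdamW
#   config.Training.NUM_EPOCHS   - number of training epochs
#   config.Eval.CHECKPOINT_PATH  - directory where checkpoints are saved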
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)

    esm_model = WrapESM()
    diffusion_model = Diffusion(config, tokenizer=tokenizer)

    print(f'Trainable params before unfreezing: {sum(p.numel() for p in esm_model.parameters() if p.requires_grad)}')

    esm_model.to(device)
    diffusion_model.to(device)

    esm_model.freeze_model()
    esm_model.unfreeze_n_layers()

    print(f'Trainable params after unfreezing: {sum(p.numel() for p in esm_model.parameters() if p.requires_grad)}')

    train_loader, val_loader, test_loader = get_dataloaders(config)
    optimizer = AdamW(filter(lambda p: p.requires_grad, esm_model.parameters()), lr=config.Optim.LR)

    avg_train_loss, train_ppl, avg_val_loss, val_ppl = train_and_validate(esm_model, optimizer, device, train_loader, val_loader, config.Training.NUM_EPOCHS, config.Eval.CHECKPOINT_PATH)
    avg_test_loss, test_ppl = test(esm_model, test_loader, device)

    results_dict = {"Average train loss": avg_train_loss,
                    "Average train perplexity": train_ppl,
                    "Average val loss": avg_val_loss,
                    "Average val perplexity": val_ppl,
                    "Average test loss": avg_test_loss,
                    "Average test perplexity": test_ppl,
                    }

    print("TRAIN AND TEST RESULTS")
    for k, v in results_dict.items():
        print(f"{k}: {v}\n")