Using LoRA to fine-tune Caduceus: LoRA layers inside BiMambaWrapper receive no gradients
Hello,
I am trying to fine-tune the Caduceus-PS model with LoRA to make fine-tuning cheaper, because my compute resources are limited. However, the model does not learn when LoRA is applied. The forward pass does not seem to use the LoRA layers even though they are inserted into the model: after a backward pass, the `grad` attribute of every LoRA weight is `None`. The only layer whose LoRA parameters receive gradients is the `embedding` layer; none of the LoRA layers inside `BiMambaWrapper` appear to take effect. Do you have any insight into why this happens and how it can be fixed?
In the code below, I create a dummy dataset with seq_len=1024 only to test whether gradients are computed for the LoRA layers. When it runs, the model has 194 trainable parameter tensors, but only 2 of the 194 have a non-`None` `grad` after the backward pass: the `lora_A` and `lora_B` weights of the embedding layer.
Package versions:
transformers==4.38.1
accelerate==0.34.0
peft==0.13.2
torch==2.2.0
Code:
import datetime
import datasets
import torch
import yaml
from peft import LoraConfig, get_peft_model
from transformers import (
AutoModelForMaskedLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
set_seed,
)
# Single seed shared by set_seed() and both train_test_split() calls so the
# random dataset and its splits are reproducible across runs.
RANDOM_SEED: int = 3_456
def check_loss_graph(model, batch):
    """Run one forward/backward pass and report which trainable params got grads.

    Args:
        model: A module whose forward accepts ``**batch`` and returns an
            object with a ``loss`` attribute (e.g. a Hugging Face output).
        batch: Mapping of keyword arguments for ``model(**batch)``; it must
            contain whatever the model needs to compute a loss.

    Returns:
        Names of parameters with ``requires_grad=True`` whose ``grad`` is
        still ``None`` after the backward pass. An empty list means every
        trainable parameter is connected to the loss.
    """
    print()
    print("\nChecking loss graph:")
    # Clear gradients left over from any earlier pass; otherwise a stale
    # .grad would be counted as "received a gradient" below.
    model.zero_grad(set_to_none=True)

    outputs = model(**batch)
    loss = outputs.loss
    print(f"    Loss requires grad: {loss.requires_grad}")
    print(f"    Loss graph connected: {loss.grad_fn is not None}")
    print(f"    Loss graph type: {type(loss.grad_fn)}")

    # Force a backward pass
    loss.backward()

    # Check grads after backward
    trainable = [
        (name, param) for name, param in model.named_parameters() if param.requires_grad
    ]
    missing = [name for name, param in trainable if param.grad is None]
    print(f"    Total trainable parameters: {len(trainable)}")
    print(f"    Total trainable parameters with grads: {len(trainable) - len(missing)}")
    # Naming the disconnected parameters shows *which* layers are bypassed in
    # the forward pass, not just how many.
    for name in missing:
        print(f"      no grad: {name}")
    return missing
def load_dataset(seq_length, tokenizer):
    """Build a random-DNA masked-LM dataset with train/test/val splits.

    Generates 100 uniformly random A/C/G/T strings of ``seq_length`` bases,
    splits them 80/10/10, tokenizes them without special tokens, and copies
    ``input_ids`` into a ``labels`` column.

    Side effect: sets ``tokenizer.model_max_length`` to ``seq_length + 1``.

    Returns:
        ``(tokenized_dataset, data_collator)``, where the collator applies MLM
        masking with probability 0.15.
    """
    alphabet = "ACGT"

    def get_random_sequence(length):
        # Uniform indices 0..3 mapped onto the nucleotide alphabet.
        picks = torch.randint(0, 4, (length,))
        return "".join(alphabet[k] for k in picks.tolist())

    def tokenize_function(examples):
        return tokenizer(examples["sequence"], add_special_tokens=False)

    tokenizer.model_max_length = seq_length + 1

    sequences = [get_random_sequence(seq_length) for _ in range(100)]
    raw = datasets.Dataset.from_dict({"sequence": sequences})

    # 80% train; the remaining 20% is halved into "test" and "val".
    outer = raw.train_test_split(test_size=0.2, seed=RANDOM_SEED)
    holdout = outer["test"].train_test_split(test_size=0.5, seed=RANDOM_SEED)
    splits = datasets.DatasetDict(
        {
            "train": outer["train"],
            "test": holdout["train"],
            "val": holdout["test"],
        }
    )

    print("Tokenizing the dataset...")
    tokenized = splits.map(
        tokenize_function,
        batched=True,
        remove_columns=splits["train"].column_names,
    )

    # Add labels to the tokenized dataset
    print("Adding labels to the tokenized dataset...")
    tokenized = tokenized.map(
        lambda rows: {"labels": rows["input_ids"]},
        batched=True,
    )

    # Create a data collator for masked language modeling
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=0.15,
    )
    return tokenized, collator
set_seed(RANDOM_SEED)

## Load tokenizer and model
model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

# NOTE(review): the Mamba mixers inside BiMambaWrapper come from mamba_ssm.
# On its fused fast path, Mamba.forward appears to pass x_proj.weight and
# out_proj.weight straight to mamba_inner_fn instead of calling those
# submodules. A LoRA adapter that PEFT wraps around them would then never
# run, and its lora_A/lora_B grads would stay None. This matches the
# symptom: only the embedding LoRA gets gradients. Turning the fast path off
# makes Mamba call x_proj(...) and out_proj(...) as modules, so the LoRA
# forward runs. Expect higher memory use and slower steps; confirm against
# the installed mamba_ssm version with check_loss_graph.
for module in model.modules():
    if hasattr(module, "use_fast_path"):
        module.use_fast_path = False

# Set the padding token id to the pad token id of the tokenizer so that
# the trainer passes this to the ignore_index of the loss function. By
# default, the padding_token_id for this model is None, and this throws
# an error in the Trainer.
# The token id -100 is used by the collator to mark non mask tokens. The
# loss function seems to run into problems here if it is not ignored.
model.config.pad_token_id = -100
## LoRA for pre-training
lora_config = LoraConfig(
    r=8,
    # NOTE(review): mamba_ssm's Mamba.forward appears to compute
    # `self.in_proj.weight @ ...` directly on every code path instead of
    # calling in_proj(...). If so, a LoRA-wrapped in_proj is bypassed even
    # with the fast path disabled. If check_loss_graph lists the in_proj
    # lora_A/lora_B weights as missing grads, drop "in_proj" here.
    target_modules=["embedding", "x_proj", "in_proj", "out_proj"],
    bias="none",
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

# Debugging call
# Check if the intended parameters are set to trainable
for n, p in peft_model.named_parameters():
    if "lora" in n.lower():
        assert p.requires_grad, f"Parameter has requires_grad==False: {n}"

# Debugging call
# Check if model is in training mode.
# from_pretrained() returns the base model in eval mode. Checking only
# peft_model.training and the newly created lora_* modules can pass while the
# wrapped base layers are still in eval mode. So switch the whole module tree
# to training mode explicitly and verify every submodule.
peft_model.train()
assert peft_model.training, "Model is not in training mode"
for name, module in peft_model.named_modules():
    assert module.training, f"Not in training mode: {name}"
tokenized_dataset, data_collator = load_dataset(1024, tokenizer)

# Timestamp for the run directories, formatted as ddMonYYYY_HHMMSS
# (e.g. "05Jan2025_143210"; %b is the locale's abbreviated month name).
now = datetime.datetime.now()
date = now.strftime("%d%b%Y_%H%M%S")

lora = "lora"
results_root = "results/caduceus_lora_debug_Jan_2025"
output_dir = f"{results_root}/.outputs/{lora}_{date}"
logging_dir = f"{results_root}/.logs/{lora}_{date}"

# Define training arguments. Kept under a separate name from the
# TrainingArguments object built below so `training_args` never changes type.
training_kwargs = {
    # max_steps > 0 overrides num_train_epochs, so training stops after 500
    # optimizer steps and the epoch count below has no effect.
    'max_steps': 500,
    'num_train_epochs': 3,
    'learning_rate': 8.0e-3,
    'weight_decay': 0.01,
    'per_device_train_batch_size': 1,
    'gradient_accumulation_steps': 8,
    'logging_strategy': "steps",
    'logging_steps': 100,
    'save_strategy': "steps",
    'save_steps': 100,
    'save_total_limit': 1,
    'eval_accumulation_steps': 10,
    # Correct name for transformers 4.38; later releases deprecate it in
    # favor of `eval_strategy`.
    'evaluation_strategy': "steps",
    'eval_steps': 100,
    'label_names': ["labels"],
    'lr_scheduler_type': "constant",
    'report_to': "tensorboard",
    'output_dir': output_dir,
    'logging_dir': logging_dir,
}
training_args = TrainingArguments(**training_kwargs)
# Build the Trainer; here it is only used for its collated training
# dataloader, which feeds one batch into the gradient sanity check.
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

# Debugging call: a single forward/backward on the first training batch.
train_loader = trainer.get_train_dataloader()
sample_batch = next(iter(train_loader))
check_loss_graph(peft_model, sample_batch)