---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---
# ESM-2 QLoRA
These are the checkpoints for the first-ever QLoRA for ESM-2! You can load and use them similarly to the LoRA models.
These checkpoints use the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great; scaling to larger models for better metrics is in progress. The checkpoints were trained on the 600K dataset. To replicate the training of QLoRA for ESM-2 models, you can use the `conda-environment.yml` file. However, for the next week or two (as of 28/09/2023), you will need to uninstall `transformers` and install the development version instead:

```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

Once the `transformers` library is updated, you should be able to use the latest release directly; gradient checkpointing will then be fully enabled, and QLoRA compatibility fully integrated into ESM-2 models.
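For example, a checkpoint can be loaded along these lines. This is a minimal sketch: the adapter path is a placeholder, and `num_labels=2` (binary binding-site labels per residue) is an assumption; adjust both to the checkpoint you are using.

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_id = "facebook/esm2_t6_8M_UR50D"      # base model used here
adapter_path = "path/to/qlora-checkpoint"  # hypothetical: point at one of these checkpoints

# Load the frozen base model in 4-bit NF4, as in QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(base_id)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_id, num_labels=2, quantization_config=bnb_config
)

# Attach the trained LoRA adapters on top of the quantized base
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Per-residue binding-site predictions for a toy sequence
seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(seq, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
preds = logits.argmax(dim=-1).squeeze()  # 1 = predicted binding residue
```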
## QLoRA Info
Note that we are training only 0.58% of the parameters, applying LoRA adapters only to the query, key, and value weight matrices:

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
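For reference, a configuration along these lines reproduces that setup. This is a sketch, not the exact training script: the rank and alpha values are assumptions (although `r=2` does reproduce the trainable parameter count above for the 8M model), and the reported "all params" total is affected by 4-bit weight packing.

```python
import torch
from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=2, quantization_config=bnb_config
)
model = prepare_model_for_kbit_training(model)

# LoRA on the attention Q/K/V projections only (module names from the
# transformers ESM-2 implementation)
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=2,            # rank; r=2 matches the parameter count above (assumption)
    lora_alpha=8,   # scaling factor; illustrative value
    lora_dropout=0.05,
    target_modules=["query", "key", "value"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# -> trainable params: 23682 || all params: 4075265 || trainable%: 0.58...
```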
The QLoRA paper showed that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter that can be adjusted is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, such as the scaling factor alpha, did not seem to matter. An important thing to investigate next is whether this finding transfers to protein language models as well.
## Testing for Overfitting
### Checkpoint 1

Train metrics:
```python
{'eval_loss': 0.31757092475891113,
 'eval_accuracy': 0.8666164527145709,
 'eval_precision': 0.12977997642311132,
 'eval_recall': 0.8907064653559833,
 'eval_f1': 0.2265505142278714,
 'eval_auc': 0.8783913689919987,
 'eval_mcc': 0.30996745466311043}
```

Test metrics:
```python
{'eval_loss': 0.3398605287075043,
 'eval_accuracy': 0.8557050926566265,
 'eval_precision': 0.10792930844408741,
 'eval_recall': 0.7726298654561553,
 'eval_f1': 0.18940102955847055,
 'eval_auc': 0.8150939843855006,
 'eval_mcc': 0.2535956911257298}
```
### Checkpoint 4

Train metrics:
```python
{'eval_loss': 0.24070295691490173,
 'eval_accuracy': 0.9018779246397052,
 'eval_precision': 0.16624103834249204,
 'eval_recall': 0.8651772818812425,
 'eval_f1': 0.27889357183237473,
 'eval_auc': 0.8839390799308487,
 'eval_mcc': 0.3536803490333407}
```

Test metrics:
```python
{'eval_loss': 0.26776671409606934,
 'eval_accuracy': 0.8902711124906878,
 'eval_precision': 0.13008662855482372,
 'eval_recall': 0.7084623832213568,
 'eval_f1': 0.219811797752809,
 'eval_auc': 0.8013943890942485,
 'eval_mcc': 0.2721459410994918}
```
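Metric dictionaries like those above can be produced by a `compute_metrics` function passed to the Hugging Face `Trainer`. Here is a minimal sketch, assuming binary per-residue labels with padded positions marked as `-100`; this is illustrative, not necessarily the exact function used for these runs.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, matthews_corrcoef,
)

def compute_metrics(eval_pred):
    """Per-residue binary metrics, ignoring padded positions (label == -100)."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # Numerically stable softmax to get P(binding site) for the AUC
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    mask = labels != -100
    y_true, y_pred, y_score = labels[mask], preds[mask], probs[..., 1][mask]
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_score),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }
```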