license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
ESM-2 QLoRA
These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them in the same way as the LoRA models. This is the smallest model, esm2_t6_8M_UR50D, so the metrics aren't great. Scaling to larger models for better metrics is in progress. These checkpoints were trained using the 600K dataset. To replicate the training of QLoRA for ESM-2 models, you can use the conda-environment.yml file. However, for the next week or two (as of 28/09/2023) you will need to uninstall transformers and install it from source instead:
pip install --upgrade git+https://github.com/huggingface/transformers.git
In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest release of transformers, with gradient checkpointing fully enabled and QLoRA compatibility fully integrated for ESM-2 models.
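Since the checkpoints load like ordinary LoRA adapters, a minimal loading sketch might look like the following. It assumes the base model is the public facebook/esm2_t6_8M_UR50D checkpoint, that the task is per-residue binary classification (num_labels=2), and that the adapter lives at a path or repo ID you supply; the function name and the example path are hypothetical. It loads the base model in full precision rather than 4-bit, which is enough for inference with the adapter.

```python
def load_qlora_checkpoint(adapter_path: str,
                          base_id: str = "facebook/esm2_t6_8M_UR50D"):
    """Attach a saved LoRA/QLoRA adapter to the ESM-2 base model.

    adapter_path is a local directory or Hub repo ID holding the adapter
    weights (hypothetical; substitute the checkpoint you want to use).
    """
    # Imported lazily so the helper can be defined without the heavy
    # dependencies being installed.
    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(base_id)
    # Assumption: two labels per residue (binding site / not a binding site).
    base_model = AutoModelForTokenClassification.from_pretrained(
        base_id, num_labels=2
    )
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model.eval()  # inference mode: disables dropout
    return tokenizer, model


# Example call (hypothetical path):
# tokenizer, model = load_qlora_checkpoint("path/to/checkpoint")
```

The returned model can then be called on tokenized protein sequences like any other token classification model.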
QLoRA Info
Note, we are only training 0.58% of the parameters, using only the query, key, and value weight matrices.
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
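The percentage above can be checked by hand, and the trainable count can be broken down as a rough sanity check. The sketch below assumes a LoRA rank of 2 and a trainable 2-label classification head; neither is stated here, but together they reproduce the reported count exactly for a model with 6 layers and hidden size 320.

```python
# Counts taken from the "trainable params" line above.
trainable_params = 23_682
all_params = 4_075_265
trainable_pct = 100 * trainable_params / all_params
print(f"trainable%: {trainable_pct:.4f}")

# esm2_t6_8M_UR50D architecture: 6 transformer layers, hidden size 320.
num_layers = 6
hidden_size = 320
targets_per_layer = 3  # query, key, value

# Assumptions, not stated in the source:
rank = 2        # LoRA rank
num_labels = 2  # classification head kept trainable

# Each adapted matrix gets A (rank x in) and B (out x rank).
lora_params = num_layers * targets_per_layer * rank * (hidden_size + hidden_size)
# Linear head: weight (hidden_size x num_labels) plus bias (num_labels).
head_params = hidden_size * num_labels + num_labels

print(lora_params, head_params, lora_params + head_params)
```

Under these assumptions the adapters account for almost all trainable parameters, with the classification head contributing the small remainder.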
It was shown in the QLoRA paper that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter that can be adjusted is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, such as the scaling factor alpha, did not seem to matter. So an important next step is to check whether this transfers to protein language models as well.
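To get a feel for what applying adapters to more weight matrices costs, the following sketch compares adapter parameter counts for query/key/value only against all linear layers in each transformer block. The layer shapes are those of esm2_t6_8M_UR50D (hidden size 320, feed-forward size 1280, 6 layers); the dictionary keys are descriptive labels rather than exact PEFT target patterns, and the rank of 2 is an assumption.

```python
# (in_features, out_features) of the linear layers in one transformer block.
BLOCK_LINEARS = {
    "query": (320, 320),
    "key": (320, 320),
    "value": (320, 320),
    "attention.output.dense": (320, 320),
    "intermediate.dense": (320, 1280),
    "output.dense": (1280, 320),
}
NUM_LAYERS = 6


def lora_param_count(targets, rank):
    """Adapter parameters when each target gets A (rank x in) and B (out x rank)."""
    per_block = sum(
        rank * (n_in + n_out)
        for name, (n_in, n_out) in BLOCK_LINEARS.items()
        if name in targets
    )
    return NUM_LAYERS * per_block


RANK = 2  # assumed rank, for illustration only
qkv_only = lora_param_count({"query", "key", "value"}, RANK)
all_linear = lora_param_count(set(BLOCK_LINEARS), RANK)
print(f"query/key/value only: {qkv_only}")
print(f"all block linears:    {all_linear}")
```

Even with every linear layer adapted, the adapter count stays small relative to the base model, so widening the set of target matrices is a cheap experiment to run.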
Testing for Overfitting
Checkpoint 1
Train/Test Split from 600K dataset:
Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}
Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}
Metrics for these datasets:
--------------------------------------------------
Processing rows: 100%|██████████| 54/54 [00:04<00:00, 11.49it/s]
Dataset: GTP_Training.txt
Accuracy: 0.8777
Precision: 0.1488
Recall: 0.5517
F1 Score: 0.2344
AUC: 0.7204
MCC: 0.2407
--------------------------------------------------
Processing rows: 100%|██████████| 82/82 [00:02<00:00, 32.07it/s]
Dataset: GDP_Training.txt
Accuracy: 0.8711
Precision: 0.1768
Recall: 0.6022
F1 Score: 0.2733
AUC: 0.7423
MCC: 0.2768
--------------------------------------------------
Processing rows: 100%|██████████| 172/172 [00:06<00:00, 27.98it/s]
Dataset: FE_Training.txt
Accuracy: 0.8424
Precision: 0.0547
Recall: 0.5452
F1 Score: 0.0994
AUC: 0.6962
MCC: 0.1344
--------------------------------------------------
Processing rows: 100%|██████████| 145/145 [00:04<00:00, 33.86it/s]
Dataset: AMP_Training.txt
Accuracy: 0.8191
Precision: 0.0975
Recall: 0.5078
F1 Score: 0.1636
AUC: 0.6691
MCC: 0.1609
--------------------------------------------------
Processing rows: 100%|██████████| 206/206 [00:05<00:00, 34.97it/s]
Dataset: HEME_Training.txt
Accuracy: 0.8561
Precision: 0.2089
Recall: 0.2795
F1 Score: 0.2391
AUC: 0.5932
MCC: 0.1636
--------------------------------------------------
Processing rows: 100%|██████████| 221/221 [00:06<00:00, 31.64it/s]
Dataset: ATP_Training.txt
Accuracy: 0.8631
Precision: 0.1459
Recall: 0.4975
F1 Score: 0.2256
AUC: 0.6879
MCC: 0.2146
--------------------------------------------------
Processing rows: 100%|██████████| 335/335 [00:10<00:00, 33.49it/s]
Dataset: DNA_Training.txt
Accuracy: 0.8387
Precision: 0.1608
Recall: 0.2233
F1 Score: 0.1870
AUC: 0.5589
MCC: 0.1017
--------------------------------------------------
Processing rows: 100%|██████████| 296/296 [00:08<00:00, 32.99it/s]
Dataset: ADP_Training.txt
Accuracy: 0.8653
Precision: 0.1415
Recall: 0.5142
F1 Score: 0.2219
AUC: 0.6966
MCC: 0.2176
--------------------------------------------------
Processing rows: 100%|██████████| 334/334 [00:10<00:00, 31.30it/s]
Dataset: MN_Training.txt
Accuracy: 0.8507
Precision: 0.0488
Recall: 0.5602
F1 Score: 0.0898
AUC: 0.7074
MCC: 0.1320
--------------------------------------------------
Processing rows: 100%|██████████| 1152/1152 [00:36<00:00, 31.70it/s]
Dataset: ZN_Training.txt
Accuracy: 0.8418
Precision: 0.0437
Recall: 0.4674
F1 Score: 0.0799
AUC: 0.6574
MCC: 0.1041
--------------------------------------------------
Processing rows: 100%|██████████| 1131/1131 [00:35<00:00, 31.87it/s]
Dataset: MG_Training.txt
Accuracy: 0.8454
Precision: 0.0327
Recall: 0.4617
F1 Score: 0.0611
AUC: 0.6556
MCC: 0.0896
--------------------------------------------------
Processing rows: 100%|██████████| 961/961 [00:30<00:00, 31.67it/s]
Dataset: CA_Training.txt
Accuracy: 0.8524
Precision: 0.0251
Recall: 0.2057
F1 Score: 0.0447
AUC: 0.5346
MCC: 0.0258
--------------------------------------------------
Processing rows: 100%|██████████| 27/27 [00:01<00:00, 26.47it/s]
Dataset: HEME_Validation.txt
Accuracy: 0.8891
Precision: 0.2125
Recall: 0.2810
F1 Score: 0.2420
AUC: 0.6055
MCC: 0.1855
--------------------------------------------------
Processing rows: 100%|██████████| 7/7 [00:00<00:00, 20.36it/s]
Dataset: GTP_Validation.txt
Accuracy: 0.8012
Precision: 0.1377
Recall: 0.6404
F1 Score: 0.2266
AUC: 0.7247
MCC: 0.2292
--------------------------------------------------
Processing rows: 100%|██████████| 14/14 [00:00<00:00, 17.96it/s]
Dataset: GDP_Validation.txt
Accuracy: 0.7954
Precision: 0.1456
Recall: 0.7423
F1 Score: 0.2434
AUC: 0.7701
MCC: 0.2658
--------------------------------------------------
Processing rows: 100%|██████████| 26/26 [00:00<00:00, 27.91it/s]
Dataset: FE_Validation.txt
Accuracy: 0.8523
Precision: 0.0571
Recall: 0.6667
F1 Score: 0.1052
AUC: 0.7607
MCC: 0.1646
--------------------------------------------------
Processing rows: 100%|██████████| 58/58 [00:01<00:00, 30.49it/s]
Dataset: MN_Validation.txt
Accuracy: 0.8445
Precision: 0.0458
Recall: 0.5359
F1 Score: 0.0844
AUC: 0.6923
MCC: 0.1216
--------------------------------------------------
Processing rows: 100%|██████████| 33/33 [00:00<00:00, 34.34it/s]
Dataset: AMP_Validation.txt
Accuracy: 0.8116
Precision: 0.1065
Recall: 0.5638
F1 Score: 0.1792
AUC: 0.6924
MCC: 0.1827
--------------------------------------------------
Processing rows: 100%|██████████| 52/52 [00:01<00:00, 32.70it/s]
Dataset: DNA_Validation.txt
Accuracy: 0.8849
Precision: 0.1306
Recall: 0.1829
F1 Score: 0.1524
AUC: 0.5550
MCC: 0.0940
--------------------------------------------------
Processing rows: 100%|██████████| 50/50 [00:01<00:00, 33.79it/s]
Dataset: ATP_Validation.txt
Accuracy: 0.8497
Precision: 0.1220
Recall: 0.4869
F1 Score: 0.1952
AUC: 0.6753
MCC: 0.1868
--------------------------------------------------
Processing rows: 100%|██████████| 47/47 [00:01<00:00, 31.43it/s]
Dataset: ADP_Validation.txt
Accuracy: 0.8652
Precision: 0.1279
Recall: 0.5379
F1 Score: 0.2067
AUC: 0.7071
MCC: 0.2139
--------------------------------------------------
Processing rows: 100%|██████████| 176/176 [00:05<00:00, 32.21it/s]
Dataset: ZN_Validation.txt
Accuracy: 0.8486
Precision: 0.0461
Recall: 0.4516
F1 Score: 0.0837
AUC: 0.6532
MCC: 0.1054
--------------------------------------------------
Processing rows: 100%|██████████| 165/165 [00:05<00:00, 32.32it/s]
Dataset: CA_Validation.txt
Accuracy: 0.8577
Precision: 0.0263
Recall: 0.2471
F1 Score: 0.0476
AUC: 0.5568
MCC: 0.0396
--------------------------------------------------
Processing rows: 100%|██████████| 217/217 [00:06<00:00, 33.25it/s]
Dataset: MG_Validation.txt
Accuracy: 0.8572
Precision: 0.0297
Recall: 0.3533
F1 Score: 0.0547
AUC: 0.6082
MCC: 0.0672
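The low F1 scores in the log above follow directly from the low precision: F1 is the harmonic mean of precision and recall, so it stays close to the smaller of the two. As a quick consistency check, the sketch below recomputes F1 for three of the datasets from their reported precision and recall. The helper name is hypothetical, and small differences in the last digit are expected because the inputs are already rounded to four decimals.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from the log above.
reported = {
    "GTP_Training": (0.1488, 0.5517, 0.2344),
    "HEME_Training": (0.2089, 0.2795, 0.2391),
    "CA_Validation": (0.0263, 0.2471, 0.0476),
}

recomputed = {name: f1_from(p, r) for name, (p, r, _) in reported.items()}
for name, (_, _, f1) in reported.items():
    print(f"{name}: recomputed F1 = {recomputed[name]:.4f}, reported = {f1:.4f}")
```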
Checkpoint 4
Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}
Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
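One simple way to read these numbers for overfitting is to look at the train minus test gap for each metric and see how it changes between checkpoints. The sketch below does this for recall, F1, AUC, and MCC, using the values reported above rounded to four decimals; the helper name is hypothetical.

```python
# Train and test metrics from the two checkpoints, rounded to four decimals.
REPORTED = {
    "checkpoint 1": {
        "train": {"recall": 0.8907, "f1": 0.2266, "auc": 0.8784, "mcc": 0.3100},
        "test": {"recall": 0.7726, "f1": 0.1894, "auc": 0.8151, "mcc": 0.2536},
    },
    "checkpoint 4": {
        "train": {"recall": 0.8652, "f1": 0.2789, "auc": 0.8839, "mcc": 0.3537},
        "test": {"recall": 0.7085, "f1": 0.2198, "auc": 0.8014, "mcc": 0.2721},
    },
}


def generalization_gaps(splits):
    """Train minus test for every metric; larger values suggest more overfitting."""
    return {m: splits["train"][m] - splits["test"][m] for m in splits["train"]}


gaps = {name: generalization_gaps(splits) for name, splits in REPORTED.items()}
for name, gap in gaps.items():
    print(name, ", ".join(f"{m}: {v:+.4f}" for m, v in gap.items()))
```

On these figures every gap is positive and each one is wider at checkpoint 4 than at checkpoint 1. Test F1 and MCC still improve between the two checkpoints, while test recall and AUC drop slightly.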