---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---

# ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them in the same way as the LoRA models.

These adapters are built on the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great. Scaling to larger models for better metrics is in progress.

These checkpoints were trained using [the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data). To replicate the QLoRA training for ESM-2 models, you can use the `conda-environment.yml` file. However, as of 28/09/2023 (and for the following week or two), you will need to uninstall `transformers` and install it from source instead:

```
pip uninstall -y transformers
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

Once the required changes are included in a `transformers` release (expected within a couple of weeks), you should be able to simply use the latest version. Gradient checkpointing will then be fully enabled, and QLoRA compatibility should be fully integrated into ESM-2 models.

## QLoRA Info

Note that we are training only 0.58% of the parameters, with LoRA adapters applied only to the query, key, and value weight matrices.

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
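
As a quick check on the printout above, the trainable percentage is simply 100 × 23682 / 4075265. The sketch below also shows one breakdown of the 23,682 trainable parameters that is consistent with that count. It assumes LoRA rank r = 2 on the query, key, and value matrices of all 6 layers (hidden size 320), plus a fully trainable 2-label token-classification head. The rank and the head are assumptions for illustration. This card states neither of them.

```python
# Sanity-check the trainable-parameter printout above.
TRAINABLE = 23682
TOTAL = 4075265
print(f"trainable%: {100 * TRAINABLE / TOTAL}")

# Hypothetical decomposition (ASSUMPTIONS, not read from the training config).
NUM_LAYERS = 6   # esm2_t6_8M_UR50D has 6 transformer layers
HIDDEN = 320     # ...with hidden size 320
RANK = 2         # ASSUMED LoRA rank
NUM_LABELS = 2   # ASSUMED binding / non-binding classification head


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r); no bias terms by default."""
    return r * d_in + d_out * r


# query, key and value in every layer are 320 x 320 linear maps.
adapters = NUM_LAYERS * 3 * lora_params(HIDDEN, HIDDEN, RANK)
# Linear(320, NUM_LABELS) classifier: weights plus bias.
head = HIDDEN * NUM_LABELS + NUM_LABELS
print(adapters, head, adapters + head)  # 23040 642 23682
```

If the real configuration used a different rank or froze the head, the breakdown would differ. Only the total and the percentage come from the printout.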

The QLoRA paper showed that the most important hyperparameter for matching or exceeding full fine-tuning is the choice of weight matrices the LoRA adapters are applied to, and more matrices are better. The rank and other hyperparameters, such as the scaling factor alpha, did not seem to matter. An important next step is therefore to check whether this finding also transfers to protein language models.
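
To estimate the cost of applying adapters to more weight matrices, here is a rough parameter count for ESM-2 8M. The module names follow the Hugging Face ESM layout: `query`, `key`, and `value` in self-attention, plus `dense` layers in the attention-output, intermediate, and output blocks. The feed-forward width of 1280 and the ranks shown are illustrative assumptions.

```python
# Rough LoRA parameter counts for esm2_t6_8M_UR50D when adapting more matrices.
NUM_LAYERS = 6
HIDDEN = 320
INTERMEDIATE = 1280  # assumed feed-forward width (4 x hidden)

# (d_in, d_out) of each linear layer inside one ESM-2 transformer layer.
LINEAR_SHAPES = {
    "attention.self.query": (HIDDEN, HIDDEN),
    "attention.self.key": (HIDDEN, HIDDEN),
    "attention.self.value": (HIDDEN, HIDDEN),
    "attention.output.dense": (HIDDEN, HIDDEN),
    "intermediate.dense": (HIDDEN, INTERMEDIATE),
    "output.dense": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(targets, r):
    """Parameters added by LoRA (A: r x d_in, B: d_out x r) summed over all layers."""
    per_layer = sum(r * (LINEAR_SHAPES[t][0] + LINEAR_SHAPES[t][1]) for t in targets)
    return NUM_LAYERS * per_layer


qkv = ["attention.self.query", "attention.self.key", "attention.self.value"]
for r in (2, 8):
    print(f"r={r}: qkv only = {lora_param_count(qkv, r)}, "
          f"all linear layers = {lora_param_count(LINEAR_SHAPES, r)}")
```

With `peft`, adapting all of these layers should roughly correspond to `target_modules=["query", "key", "value", "dense"]` in a `LoraConfig`. This works because the attention-output, intermediate, and output layers are all named `dense` in the Hugging Face ESM implementation.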

## Testing for Overfitting

### Checkpoint 1

Train/test split from the 600K dataset:

| Metric | Train | Test |
|---|---|---|
| `eval_loss` | 0.31757092475891113 | 0.3398605287075043 |
| `eval_accuracy` | 0.8666164527145709 | 0.8557050926566265 |
| `eval_precision` | 0.12977997642311132 | 0.10792930844408741 |
| `eval_recall` | 0.8907064653559833 | 0.7726298654561553 |
| `eval_f1` | 0.2265505142278714 | 0.18940102955847055 |
| `eval_auc` | 0.8783913689919987 | 0.8150939843855006 |
| `eval_mcc` | 0.30996745466311043 | 0.2535956911257298 |

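
The metrics above combine high accuracy with low precision. That is the pattern you would expect when positive (binding) residues are rare. The sketch below computes the same metrics from a binary confusion matrix. The counts are made up for illustration and do not come from this model. The last lines check that F1 is the harmonic mean of precision and recall, using the Checkpoint 1 train values.

```python
import math


def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Accuracy, precision, recall, F1 and MCC from a binary confusion matrix."""
    total = tp + fp + fn + tn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"accuracy": (tp + tn) / total, "precision": precision,
            "recall": recall, "f1": f1, "mcc": mcc}


# Hypothetical counts (NOT from this model): 100 binding residues out of 10,000.
m = binary_metrics(tp=80, fp=900, fn=20, tn=9000)
print({k: round(v, 4) for k, v in m.items()})

# F1 is the harmonic mean of precision and recall, so the Checkpoint 1 train
# precision/recall should reproduce the reported eval_f1 up to float rounding.
p, r = 0.12977997642311132, 0.8907064653559833
f1_from_pr = 2 * p * r / (p + r)
print(f1_from_pr)
```

With only 1% positives, these hypothetical counts give accuracy of about 0.91 even though fewer than 1 in 10 predicted binding residues are correct. F1 stays below 0.15 and MCC is about 0.24, which is why F1, AUC, and MCC are the more informative columns here.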

Metrics for [these datasets](https://github.com/hamzagamouh/pt-lm-gnn), training files:

| Dataset | Rows | Accuracy | Precision | Recall | F1 Score | AUC | MCC |
|---|---|---|---|---|---|---|---|
| GTP_Training.txt | 54 | 0.8777 | 0.1488 | 0.5517 | 0.2344 | 0.7204 | 0.2407 |
| GDP_Training.txt | 82 | 0.8711 | 0.1768 | 0.6022 | 0.2733 | 0.7423 | 0.2768 |
| FE_Training.txt | 172 | 0.8424 | 0.0547 | 0.5452 | 0.0994 | 0.6962 | 0.1344 |
| AMP_Training.txt | 145 | 0.8191 | 0.0975 | 0.5078 | 0.1636 | 0.6691 | 0.1609 |
| HEME_Training.txt | 206 | 0.8561 | 0.2089 | 0.2795 | 0.2391 | 0.5932 | 0.1636 |
| ATP_Training.txt | 221 | 0.8631 | 0.1459 | 0.4975 | 0.2256 | 0.6879 | 0.2146 |
| DNA_Training.txt | 335 | 0.8387 | 0.1608 | 0.2233 | 0.1870 | 0.5589 | 0.1017 |
| ADP_Training.txt | 296 | 0.8653 | 0.1415 | 0.5142 | 0.2219 | 0.6966 | 0.2176 |
| MN_Training.txt | 334 | 0.8507 | 0.0488 | 0.5602 | 0.0898 | 0.7074 | 0.1320 |
| ZN_Training.txt | 1152 | 0.8418 | 0.0437 | 0.4674 | 0.0799 | 0.6574 | 0.1041 |
| MG_Training.txt | 1131 | 0.8454 | 0.0327 | 0.4617 | 0.0611 | 0.6556 | 0.0896 |
| CA_Training.txt | 961 | 0.8524 | 0.0251 | 0.2057 | 0.0447 | 0.5346 | 0.0258 |


Validation files:

| Dataset | Rows | Accuracy | Precision | Recall | F1 Score | AUC | MCC |
|---|---|---|---|---|---|---|---|
| HEME_Validation.txt | 27 | 0.8891 | 0.2125 | 0.2810 | 0.2420 | 0.6055 | 0.1855 |
| GTP_Validation.txt | 7 | 0.8012 | 0.1377 | 0.6404 | 0.2266 | 0.7247 | 0.2292 |
| GDP_Validation.txt | 14 | 0.7954 | 0.1456 | 0.7423 | 0.2434 | 0.7701 | 0.2658 |
| FE_Validation.txt | 26 | 0.8523 | 0.0571 | 0.6667 | 0.1052 | 0.7607 | 0.1646 |
| MN_Validation.txt | 58 | 0.8445 | 0.0458 | 0.5359 | 0.0844 | 0.6923 | 0.1216 |
| AMP_Validation.txt | 33 | 0.8116 | 0.1065 | 0.5638 | 0.1792 | 0.6924 | 0.1827 |
| DNA_Validation.txt | 52 | 0.8849 | 0.1306 | 0.1829 | 0.1524 | 0.5550 | 0.0940 |
| ATP_Validation.txt | 50 | 0.8497 | 0.1220 | 0.4869 | 0.1952 | 0.6753 | 0.1868 |
| ADP_Validation.txt | 47 | 0.8652 | 0.1279 | 0.5379 | 0.2067 | 0.7071 | 0.2139 |
| ZN_Validation.txt | 176 | 0.8486 | 0.0461 | 0.4516 | 0.0837 | 0.6532 | 0.1054 |
| CA_Validation.txt | 165 | 0.8577 | 0.0263 | 0.2471 | 0.0476 | 0.5568 | 0.0396 |
| MG_Validation.txt | 217 | 0.8572 | 0.0297 | 0.3533 | 0.0547 | 0.6082 | 0.0672 |

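
The Training and Validation results above can be compared ligand by ligand. As a rough sketch, the snippet below copies the MCC values from both sets of files. It reports the per-ligand difference and the unweighted mean of each set. Using an unweighted mean is a choice made here, and it ignores the very different numbers of rows per file.

```python
# MCC values copied from the Training and Validation results above (Checkpoint 1).
training_files_mcc = {
    "GTP": 0.2407, "GDP": 0.2768, "FE": 0.1344, "AMP": 0.1609,
    "HEME": 0.1636, "ATP": 0.2146, "DNA": 0.1017, "ADP": 0.2176,
    "MN": 0.1320, "ZN": 0.1041, "MG": 0.0896, "CA": 0.0258,
}
validation_files_mcc = {
    "HEME": 0.1855, "GTP": 0.2292, "GDP": 0.2658, "FE": 0.1646,
    "MN": 0.1216, "AMP": 0.1827, "DNA": 0.0940, "ATP": 0.1868,
    "ADP": 0.2139, "ZN": 0.1054, "CA": 0.0396, "MG": 0.0672,
}

# Positive gap: the model scores higher on that ligand's Training file.
gaps = {k: training_files_mcc[k] - validation_files_mcc[k] for k in sorted(training_files_mcc)}
for ligand, gap in gaps.items():
    print(f"{ligand:>5}: training - validation MCC = {gap:+.4f}")

mean_train = sum(training_files_mcc.values()) / len(training_files_mcc)
mean_val = sum(validation_files_mcc.values()) / len(validation_files_mcc)
print(f"mean MCC: training files = {mean_train:.4f}, validation files = {mean_val:.4f}")
```

For these numbers, the two means differ by less than 0.001. The per-ligand differences vary in sign, however. For example, FE scores higher on its Validation file, while ATP scores higher on its Training file.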

### Checkpoint 4

| Metric | Train | Test |
|---|---|---|
| `eval_loss` | 0.24070295691490173 | 0.26776671409606934 |
| `eval_accuracy` | 0.9018779246397052 | 0.8902711124906878 |
| `eval_precision` | 0.16624103834249204 | 0.13008662855482372 |
| `eval_recall` | 0.8651772818812425 | 0.7084623832213568 |
| `eval_f1` | 0.27889357183237473 | 0.219811797752809 |
| `eval_auc` | 0.8839390799308487 | 0.8013943890942485 |
| `eval_mcc` | 0.3536803490333407 | 0.2721459410994918 |

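
One simple way to compare the two checkpoints is the gap between their train and test metrics. The sketch below uses a subset of the reported values and defines the gap so that a positive number always means worse on held-out data. That sign convention is a choice made here and is not defined elsewhere in this card.

```python
# Train/test values copied from the Checkpoint 1 and Checkpoint 4 metrics above.
checkpoints = {
    "checkpoint 1": {
        "train": {"eval_loss": 0.31757092475891113, "eval_auc": 0.8783913689919987,
                  "eval_mcc": 0.30996745466311043},
        "test": {"eval_loss": 0.3398605287075043, "eval_auc": 0.8150939843855006,
                 "eval_mcc": 0.2535956911257298},
    },
    "checkpoint 4": {
        "train": {"eval_loss": 0.24070295691490173, "eval_auc": 0.8839390799308487,
                  "eval_mcc": 0.3536803490333407},
        "test": {"eval_loss": 0.26776671409606934, "eval_auc": 0.8013943890942485,
                 "eval_mcc": 0.2721459410994918},
    },
}


def generalization_gap(split_metrics):
    """Test minus train for loss (lower is better), train minus test otherwise."""
    train, test = split_metrics["train"], split_metrics["test"]
    return {k: (test[k] - train[k]) if k == "eval_loss" else (train[k] - test[k])
            for k in train}


for name, split_metrics in checkpoints.items():
    gap = generalization_gap(split_metrics)
    print(name, {k: round(v, 4) for k, v in gap.items()})
```

On these values, Checkpoint 4 has a higher test MCC than Checkpoint 1 (0.2721 vs 0.2536). It also has a somewhat wider train/test gap on all three metrics in the sketch.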