---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---

# ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them in the same way as the LoRA models. They are built on the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great. Scaling to larger models for better metrics is in progress. These checkpoints were trained using [the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data).

To replicate the training of QLoRA for ESM-2 models, you can use the `conda-environment.yml` file. However, for the next week or two (as of 28/09/2023), you will need to uninstall transformers and install it from source instead:

```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest release of transformers. Gradient checkpointing will then be fully enabled, and QLoRA compatibility should be fully integrated into ESM-2 models.

## QLoRA Info

Note that we are only training 0.58% of the parameters, because the LoRA adapters are applied only to the query, key, and value weight matrices.

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```

The QLoRA paper showed that, to match or beat full finetuning, the most important hyperparameter is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, such as the scaling factor alpha, did not seem to matter. An important next step is therefore to check whether this also transfers to protein language models.
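The sketch below is a rough sanity check on the parameter count above. It recomputes the count from the architecture of `esm2_t6_8M_UR50D`, which has 6 layers and a hidden size of 320. The calculation assumes two things that this card does not state:

- a LoRA rank of `r=2` on the three attention projections (in PEFT this corresponds to `target_modules=["query", "key", "value"]`), and
- a trainable two-label (binding / non-binding) token-classification head.

Both are inferred only because they reproduce the printed number. The helper name `lora_trainable_params` is hypothetical.

```python
# Hypothetical consistency check: rebuild the reported trainable-parameter count
# under ASSUMED hyperparameters (rank r and number of labels are inferred, not stated).


def lora_trainable_params(hidden_size, num_layers, num_targets, r, num_labels):
    """Count LoRA adapter parameters plus a trainable linear classification head.

    Each adapted square projection (hidden_size x hidden_size) gets two low-rank
    factors, A (r x hidden_size) and B (hidden_size x r), with no LoRA bias.
    """
    per_adapter = r * hidden_size + hidden_size * r
    lora_total = num_layers * num_targets * per_adapter
    # Token-classification head: weight (num_labels x hidden_size) + bias (num_labels)
    head = hidden_size * num_labels + num_labels
    return lora_total + head


# esm2_t6_8M_UR50D: 6 transformer layers, hidden size 320.
# Targets: query, key, value (3 per layer). Assumed: r=2 and 2 labels.
trainable = lora_trainable_params(hidden_size=320, num_layers=6, num_targets=3, r=2, num_labels=2)
total = 4075265  # "all params" as printed by the training run above

print(trainable)
print(f"trainable%: {100 * trainable / total:.4f}")
```

Under these assumptions, the 18 adapters contribute 23,040 parameters and the classifier head contributes 642, which together give the 23,682 reported above. A different rank would not reproduce that number with a two-label head. That makes `r=2` a plausible reading of the run, but it is still an inference.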
## Testing for Overfitting

### Checkpoint 1

Train/Test Split from 600K dataset:

```python
Train metrics: {'eval_loss': 0.31757092475891113, 'eval_accuracy': 0.8666164527145709, 'eval_precision': 0.12977997642311132, 'eval_recall': 0.8907064653559833, 'eval_f1': 0.2265505142278714, 'eval_auc': 0.8783913689919987, 'eval_mcc': 0.30996745466311043}
Test metrics: {'eval_loss': 0.3398605287075043, 'eval_accuracy': 0.8557050926566265, 'eval_precision': 0.10792930844408741, 'eval_recall': 0.7726298654561553, 'eval_f1': 0.18940102955847055, 'eval_auc': 0.8150939843855006, 'eval_mcc': 0.2535956911257298}
```

Metrics for [these datasets](https://github.com/hamzagamouh/pt-lm-gnn):

```python
--------------------------------------------------
Processing rows: 100%|██████████| 54/54 [00:04<00:00, 11.49it/s]
Dataset: GTP_Training.txt
Accuracy: 0.8777
Precision: 0.1488
Recall: 0.5517
F1 Score: 0.2344
AUC: 0.7204
MCC: 0.2407
--------------------------------------------------
Processing rows: 100%|██████████| 82/82 [00:02<00:00, 32.07it/s]
Dataset: GDP_Training.txt
Accuracy: 0.8711
Precision: 0.1768
Recall: 0.6022
F1 Score: 0.2733
AUC: 0.7423
MCC: 0.2768
--------------------------------------------------
Processing rows: 100%|██████████| 172/172 [00:06<00:00, 27.98it/s]
Dataset: FE_Training.txt
Accuracy: 0.8424
Precision: 0.0547
Recall: 0.5452
F1 Score: 0.0994
AUC: 0.6962
MCC: 0.1344
--------------------------------------------------
Processing rows: 100%|██████████| 145/145 [00:04<00:00, 33.86it/s]
Dataset: AMP_Training.txt
Accuracy: 0.8191
Precision: 0.0975
Recall: 0.5078
F1 Score: 0.1636
AUC: 0.6691
MCC: 0.1609
--------------------------------------------------
Processing rows: 100%|██████████| 206/206 [00:05<00:00, 34.97it/s]
Dataset: HEME_Training.txt
Accuracy: 0.8561
Precision: 0.2089
Recall: 0.2795
F1 Score: 0.2391
AUC: 0.5932
MCC: 0.1636
--------------------------------------------------
Processing rows: 100%|██████████| 221/221 [00:06<00:00, 31.64it/s]
Dataset: ATP_Training.txt
Accuracy: 0.8631
Precision: 0.1459
Recall: 0.4975
F1 Score: 0.2256
AUC: 0.6879
MCC: 0.2146
--------------------------------------------------
Processing rows: 100%|██████████| 335/335 [00:10<00:00, 33.49it/s]
Dataset: DNA_Training.txt
Accuracy: 0.8387
Precision: 0.1608
Recall: 0.2233
F1 Score: 0.1870
AUC: 0.5589
MCC: 0.1017
--------------------------------------------------
Processing rows: 100%|██████████| 296/296 [00:08<00:00, 32.99it/s]
Dataset: ADP_Training.txt
Accuracy: 0.8653
Precision: 0.1415
Recall: 0.5142
F1 Score: 0.2219
AUC: 0.6966
MCC: 0.2176
--------------------------------------------------
Processing rows: 100%|██████████| 334/334 [00:10<00:00, 31.30it/s]
Dataset: MN_Training.txt
Accuracy: 0.8507
Precision: 0.0488
Recall: 0.5602
F1 Score: 0.0898
AUC: 0.7074
MCC: 0.1320
--------------------------------------------------
Processing rows: 100%|██████████| 1152/1152 [00:36<00:00, 31.70it/s]
Dataset: ZN_Training.txt
Accuracy: 0.8418
Precision: 0.0437
Recall: 0.4674
F1 Score: 0.0799
AUC: 0.6574
MCC: 0.1041
--------------------------------------------------
Processing rows: 100%|██████████| 1131/1131 [00:35<00:00, 31.87it/s]
Dataset: MG_Training.txt
Accuracy: 0.8454
Precision: 0.0327
Recall: 0.4617
F1 Score: 0.0611
AUC: 0.6556
MCC: 0.0896
--------------------------------------------------
Processing rows: 100%|██████████| 961/961 [00:30<00:00, 31.67it/s]
Dataset: CA_Training.txt
Accuracy: 0.8524
Precision: 0.0251
Recall: 0.2057
F1 Score: 0.0447
AUC: 0.5346
MCC: 0.0258
```

```python
--------------------------------------------------
Processing rows: 100%|██████████| 27/27 [00:01<00:00, 26.47it/s]
Dataset: HEME_Validation.txt
Accuracy: 0.8891
Precision: 0.2125
Recall: 0.2810
F1 Score: 0.2420
AUC: 0.6055
MCC: 0.1855
--------------------------------------------------
Processing rows: 100%|██████████| 7/7 [00:00<00:00, 20.36it/s]
Dataset: GTP_Validation.txt
Accuracy: 0.8012
Precision: 0.1377
Recall: 0.6404
F1 Score: 0.2266
AUC: 0.7247
MCC: 0.2292
--------------------------------------------------
Processing rows: 100%|██████████| 14/14 [00:00<00:00, 17.96it/s]
Dataset: GDP_Validation.txt
Accuracy: 0.7954
Precision: 0.1456
Recall: 0.7423
F1 Score: 0.2434
AUC: 0.7701
MCC: 0.2658
--------------------------------------------------
Processing rows: 100%|██████████| 26/26 [00:00<00:00, 27.91it/s]
Dataset: FE_Validation.txt
Accuracy: 0.8523
Precision: 0.0571
Recall: 0.6667
F1 Score: 0.1052
AUC: 0.7607
MCC: 0.1646
--------------------------------------------------
Processing rows: 100%|██████████| 58/58 [00:01<00:00, 30.49it/s]
Dataset: MN_Validation.txt
Accuracy: 0.8445
Precision: 0.0458
Recall: 0.5359
F1 Score: 0.0844
AUC: 0.6923
MCC: 0.1216
--------------------------------------------------
Processing rows: 100%|██████████| 33/33 [00:00<00:00, 34.34it/s]
Dataset: AMP_Validation.txt
Accuracy: 0.8116
Precision: 0.1065
Recall: 0.5638
F1 Score: 0.1792
AUC: 0.6924
MCC: 0.1827
--------------------------------------------------
Processing rows: 100%|██████████| 52/52 [00:01<00:00, 32.70it/s]
Dataset: DNA_Validation.txt
Accuracy: 0.8849
Precision: 0.1306
Recall: 0.1829
F1 Score: 0.1524
AUC: 0.5550
MCC: 0.0940
--------------------------------------------------
Processing rows: 100%|██████████| 50/50 [00:01<00:00, 33.79it/s]
Dataset: ATP_Validation.txt
Accuracy: 0.8497
Precision: 0.1220
Recall: 0.4869
F1 Score: 0.1952
AUC: 0.6753
MCC: 0.1868
--------------------------------------------------
Processing rows: 100%|██████████| 47/47 [00:01<00:00, 31.43it/s]
Dataset: ADP_Validation.txt
Accuracy: 0.8652
Precision: 0.1279
Recall: 0.5379
F1 Score: 0.2067
AUC: 0.7071
MCC: 0.2139
--------------------------------------------------
Processing rows: 100%|██████████| 176/176 [00:05<00:00, 32.21it/s]
Dataset: ZN_Validation.txt
Accuracy: 0.8486
Precision: 0.0461
Recall: 0.4516
F1 Score: 0.0837
AUC: 0.6532
MCC: 0.1054
--------------------------------------------------
Processing rows: 100%|██████████| 165/165 [00:05<00:00, 32.32it/s]
Dataset: CA_Validation.txt
Accuracy: 0.8577
Precision: 0.0263
Recall: 0.2471
F1 Score: 0.0476
AUC: 0.5568
MCC: 0.0396
--------------------------------------------------
Processing rows: 100%|██████████| 217/217 [00:06<00:00, 33.25it/s]
Dataset: MG_Validation.txt
Accuracy: 0.8572
Precision: 0.0297
Recall: 0.3533
F1 Score: 0.0547
AUC: 0.6082
MCC: 0.0672
```

### Checkpoint 4

```python
Train metrics: {'eval_loss': 0.24070295691490173, 'eval_accuracy': 0.9018779246397052, 'eval_precision': 0.16624103834249204, 'eval_recall': 0.8651772818812425, 'eval_f1': 0.27889357183237473, 'eval_auc': 0.8839390799308487, 'eval_mcc': 0.3536803490333407}
Test metrics: {'eval_loss': 0.26776671409606934, 'eval_accuracy': 0.8902711124906878, 'eval_precision': 0.13008662855482372, 'eval_recall': 0.7084623832213568, 'eval_f1': 0.219811797752809, 'eval_auc': 0.8013943890942485, 'eval_mcc': 0.2721459410994918}
```
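One way to make the overfitting check concrete is to compare the train and test dictionaries directly. The sketch below copies the Checkpoint 1 and Checkpoint 4 numbers printed above. For each metric, it reports how much worse the test split is:

- For loss, where lower is better, the value is `test - train`.
- For every other metric, the value is `train - test`.

A positive value therefore always means "worse on held-out data." The `CHECKPOINTS` layout and the `degradation` helper are illustrative only. They are not part of the original training or evaluation code.

```python
# Illustrative sketch: per-metric train -> test degradation for the two checkpoints.
# All numbers are copied from the "Train metrics" / "Test metrics" logs above.

CHECKPOINTS = {
    "checkpoint-1": {
        "train": {"eval_loss": 0.31757092475891113, "eval_accuracy": 0.8666164527145709,
                  "eval_precision": 0.12977997642311132, "eval_recall": 0.8907064653559833,
                  "eval_f1": 0.2265505142278714, "eval_auc": 0.8783913689919987,
                  "eval_mcc": 0.30996745466311043},
        "test": {"eval_loss": 0.3398605287075043, "eval_accuracy": 0.8557050926566265,
                 "eval_precision": 0.10792930844408741, "eval_recall": 0.7726298654561553,
                 "eval_f1": 0.18940102955847055, "eval_auc": 0.8150939843855006,
                 "eval_mcc": 0.2535956911257298},
    },
    "checkpoint-4": {
        "train": {"eval_loss": 0.24070295691490173, "eval_accuracy": 0.9018779246397052,
                  "eval_precision": 0.16624103834249204, "eval_recall": 0.8651772818812425,
                  "eval_f1": 0.27889357183237473, "eval_auc": 0.8839390799308487,
                  "eval_mcc": 0.3536803490333407},
        "test": {"eval_loss": 0.26776671409606934, "eval_accuracy": 0.8902711124906878,
                 "eval_precision": 0.13008662855482372, "eval_recall": 0.7084623832213568,
                 "eval_f1": 0.219811797752809, "eval_auc": 0.8013943890942485,
                 "eval_mcc": 0.2721459410994918},
    },
}

# Metrics where a smaller value is better; all others are higher-is-better.
LOWER_IS_BETTER = {"eval_loss"}


def degradation(train, test):
    """Return, per metric, how much worse the test split is (positive = worse)."""
    out = {}
    for name, train_value in train.items():
        test_value = test[name]
        if name in LOWER_IS_BETTER:
            out[name] = test_value - train_value
        else:
            out[name] = train_value - test_value
    return out


for ckpt, splits in CHECKPOINTS.items():
    gaps = degradation(splits["train"], splits["test"])
    # Strip the "eval_" prefix for a more compact printout.
    print(ckpt, {k.replace("eval_", ""): round(v, 4) for k, v in gaps.items()})
```

Checkpoint 4 does better than Checkpoint 1 on the test split for loss, F1 and MCC. However, its train/test gap is wider on every metric, including AUC and recall, where its test scores are also lower. This may be an early sign of overfitting worth tracking as training continues.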