---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---
|
|
|
# ESM-2 QLoRA |
|
|
|
These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them similarly to the LoRA models
(see the loading sketch below). This is the smallest model, `esm2_t6_8M_UR50D`, so the metrics aren't great;
scaling to larger models for better metrics is in progress. These checkpoints were trained using
[the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data). To replicate the training of QLoRA for ESM-2 models,
you can use the `conda-environment.yml` file. However, for the next week or two (as of 28/09/2023) you will need to uninstall transformers
and install the development version instead:

```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest version of transformers:
gradient checkpointing will be fully enabled, and QLoRA compatibility should be fully integrated into ESM-2 models.
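Once the adapter weights are available, loading a checkpoint with PEFT might look roughly like the sketch below. The base model name comes from this card; the adapter path is illustrative and should point at the specific checkpoint you want to use.

```python
# A minimal sketch of loading one of these QLoRA checkpoints for token classification.
# The adapter path below is a placeholder; point it at the checkpoint you want.
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel

base_model_name = "facebook/esm2_t6_8M_UR50D"
adapter_path = "AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0"  # or a specific checkpoint folder

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_name, num_labels=2)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Predict binding sites for a single sequence (label 1 = predicted binding residue).
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits            # shape: (1, seq_len, 2)
predictions = logits.argmax(dim=-1).squeeze()  # per-token 0/1 labels
```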
|
|
|
## Data Curation and Preprocessing |
|
|
|
To create your own datasets and perform the same data preprocessing as was used for this project, you will need to download a TSV file
from UniProt with the following columns (Protein families, Binding sites, Active sites, Protein sequence). You can then use
[this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0/blob/main/data_processing_v1.ipynb) to
separate out the test sequences by choosing random families (including all sequences in each chosen family, with no overlap with
the training data), filter out proteins with incomplete annotations, merge the binding and active sites, convert them to binary
labels (`0` for non-binding sites, `1` for binding sites), and split the sequences into non-overlapping chunks of 1000 residues or
less to accommodate the 1022-residue context window of ESM-2 models (a sketch of the chunking step is shown below). The notebook also
lets you reduce the size of your dataset at the end. Note that this downsampling step is not currently ideal: it selects proteins
uniformly at random from the train and test datasets, without accounting for the fact that proteins from small families are less likely
to be chosen, which biases the models towards larger families. Due to this shortcoming in our data preprocessing step, smaller models
trained on smaller datasets are likely biased towards larger families; an approach that is instead biased towards smaller families might be better.
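To make the chunking step concrete, here is a simplified sketch (illustrative only, not the notebook's exact code) of splitting a sequence and its per-residue labels into non-overlapping chunks of at most 1000 residues:

```python
# Simplified sketch of the chunking step: split a sequence and its per-residue
# binary labels into non-overlapping chunks of at most 1000 residues.
def chunk_sequence(sequence: str, labels: list, max_len: int = 1000):
    assert len(sequence) == len(labels)
    return [
        (sequence[i : i + max_len], labels[i : i + max_len])
        for i in range(0, len(sequence), max_len)
    ]

# Example: a 2500-residue protein yields chunks of 1000, 1000, and 500 residues,
# each short enough for ESM-2's 1022-token context window.
```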
|
|
|
## QLoRA Info |
|
|
|
Note that we are only training 0.58% of the parameters, with LoRA adapters applied only to the query, key, and value weight matrices.
|
|
|
```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
|
|
|
It was shown in the QLoRA paper that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter
to adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters
such as the scaling factor alpha did not seem to matter much. An important next step is therefore to check whether this finding
transfers to protein language models as well. A general pattern is also emerging in which overfitting is reduced by adding adapters for more of the
weight matrices, so more adapter layers seem to be better in that regard as well.
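For reference, a QLoRA setup along these lines, a 4-bit quantized base model with LoRA adapters on the query, key, and value matrices, can be written with bitsandbytes and PEFT roughly as in the sketch below; the rank, alpha, and dropout values are illustrative placeholders rather than the exact ones used for these checkpoints.

```python
# Rough sketch of a QLoRA setup for ESM-2 token classification.
# Hyperparameter values (r, lora_alpha, lora_dropout) are illustrative placeholders.
import torch
from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize the frozen base model to 4-bit
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=2, quantization_config=bnb_config
)
base_model = prepare_model_for_kbit_training(base_model)

lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["query", "key", "value"],  # adapters on the Q, K, V weight matrices only
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # prints the "trainable params || all params" line shown above
```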
|
|
|
## Testing for Overfitting |
|
|
|
### Checkpoint 1 |
|
|
|
Train/Test Split from 600K dataset: |
|
|
|
```python
Train metrics:
{'eval_loss': 0.31757092475891113,
 'eval_accuracy': 0.8666164527145709,
 'eval_precision': 0.12977997642311132,
 'eval_recall': 0.8907064653559833,
 'eval_f1': 0.2265505142278714,
 'eval_auc': 0.8783913689919987,
 'eval_mcc': 0.30996745466311043}

Test metrics:
{'eval_loss': 0.3398605287075043,
 'eval_accuracy': 0.8557050926566265,
 'eval_precision': 0.10792930844408741,
 'eval_recall': 0.7726298654561553,
 'eval_f1': 0.18940102955847055,
 'eval_auc': 0.8150939843855006,
 'eval_mcc': 0.2535956911257298}
```
|
|
|
Metrics for this checkpoint for [these datasets](https://github.com/hamzagamouh/pt-lm-gnn) can be
[found here](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0/blob/main/pdb_struct_metrics.txt).
|
|
|
### Checkpoint 4 |
|
|
|
```python
Train metrics:
{'eval_loss': 0.24070295691490173,
 'eval_accuracy': 0.9018779246397052,
 'eval_precision': 0.16624103834249204,
 'eval_recall': 0.8651772818812425,
 'eval_f1': 0.27889357183237473,
 'eval_auc': 0.8839390799308487,
 'eval_mcc': 0.3536803490333407}

Test metrics:
{'eval_loss': 0.26776671409606934,
 'eval_accuracy': 0.8902711124906878,
 'eval_precision': 0.13008662855482372,
 'eval_recall': 0.7084623832213568,
 'eval_f1': 0.219811797752809,
 'eval_auc': 0.8013943890942485,
 'eval_mcc': 0.2721459410994918}
```
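The per-residue metrics reported above (accuracy, precision, recall, F1, AUC, MCC) can be computed with scikit-learn roughly as in the following sketch; this is one plausible way to do it, not necessarily the exact evaluation code behind these numbers.

```python
# Sketch of computing token-level metrics like those reported above with scikit-learn.
# Illustrative only; not necessarily the exact evaluation code used for these checkpoints.
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, matthews_corrcoef,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred                    # (batch, seq_len, 2) and (batch, seq_len)
    preds = np.argmax(logits, axis=-1)
    mask = labels != -100                         # ignore special tokens and padding
    y_true, y_pred = labels[mask], preds[mask]
    # softmax probability of the positive (binding-site) class, used for AUC
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    y_score = probs[..., 1][mask]
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_score),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }
```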
|
|