---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---
# ESM-2 QLoRA
These are the checkpoints for the first-ever QLoRA for ESM-2! You can load and use them similarly to the LoRA models (see the sketch below).
This is the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great; scaling to larger models for better metrics is in progress.
These checkpoints were trained on [the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data).
To replicate the training of QLoRA for ESM-2 models, you can use the `conda-environment.yml` file. However, for the next week or two
(as of 28/09/2023), QLoRA support for ESM-2 is only on the main branch of transformers, so you will need to uninstall the release version
and install from source instead:
```bash
pip install --upgrade git+https://github.com/huggingface/transformers.git
```
In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest release: gradient
checkpointing will be fully enabled, and QLoRA compatibility should be fully integrated into ESM-2 models.
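To load and use a checkpoint, the pattern is the same as for the LoRA models. Below is a minimal sketch, assuming a token-classification head with two labels (binding site vs. not) and a placeholder adapter path; point it at one of the checkpoint folders in this repo:
```python
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel

base_model_name = "facebook/esm2_t6_8M_UR50D"
adapter_path = "path/to/checkpoint"  # placeholder: one of the checkpoints in this repo

# Binary token classification: binding site vs. non-binding site.
base_model = AutoModelForTokenClassification.from_pretrained(base_model_name, num_labels=2)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
sequence = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK"
inputs = tokenizer(sequence, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, seq_len, 2)

# Per-residue predictions, dropping the special <cls> and <eos> tokens.
predictions = logits.argmax(dim=-1)[0, 1:-1].tolist()
print(predictions)
```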
## QLoRA Info
Note that we are training only 0.58% of the parameters, using only the query, key, and value weight matrices:
```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
The QLoRA paper showed that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter
to adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters,
such as the scaling factor alpha, did not seem to matter much. An important next step would be to check whether this finding
transfers to protein language models as well.
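For concreteness, here is a minimal sketch of this kind of 4-bit QLoRA setup using the standard `peft` and `bitsandbytes` APIs; the hyperparameter values shown are illustrative, not the exact training configuration:
```python
import torch
from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the base model quantized to 4-bit NF4, as in the QLoRA paper.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=2,
    quantization_config=bnb_config,
)
model = prepare_model_for_kbit_training(model)

# Adapters on the query, key, and value matrices only (~0.58% of parameters).
# Per the QLoRA paper, widening target_modules (e.g. adding the dense layers)
# may matter more than tuning the rank r or the scaling factor alpha.
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=8,            # illustrative
    lora_alpha=16,  # illustrative
    lora_dropout=0.05,
    bias="none",
    target_modules=["query", "key", "value"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```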
## Testing for Overfitting
### Checkpoint 1
Train/test split from the 600K dataset:
```python
Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}
Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}
```
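For reference, token-level metrics like these can be computed along the following lines with scikit-learn (a sketch; the 0.5 threshold and the `-100` padding-label convention are assumptions, not necessarily the exact evaluation code):
```python
import numpy as np
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef,
)

def compute_metrics(logits: np.ndarray, labels: np.ndarray) -> dict:
    # logits: (num_tokens, 2); labels: (num_tokens,), with -100 marking
    # padding/special-token positions that should be ignored.
    mask = labels != -100
    y_true = labels[mask]
    y_prob = softmax(logits[mask], axis=-1)[:, 1]  # P(binding site)
    y_pred = (y_prob > 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": roc_auc_score(y_true, y_prob),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }
```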
Metrics for [these datasets](https://github.com/hamzagamouh/pt-lm-gnn):
**Training splits** ("Rows" is the number of proteins processed per dataset):

| Dataset | Rows | Accuracy | Precision | Recall | F1 | AUC | MCC |
|---|---|---|---|---|---|---|---|
| GTP_Training.txt | 54 | 0.8777 | 0.1488 | 0.5517 | 0.2344 | 0.7204 | 0.2407 |
| GDP_Training.txt | 82 | 0.8711 | 0.1768 | 0.6022 | 0.2733 | 0.7423 | 0.2768 |
| FE_Training.txt | 172 | 0.8424 | 0.0547 | 0.5452 | 0.0994 | 0.6962 | 0.1344 |
| AMP_Training.txt | 145 | 0.8191 | 0.0975 | 0.5078 | 0.1636 | 0.6691 | 0.1609 |
| HEME_Training.txt | 206 | 0.8561 | 0.2089 | 0.2795 | 0.2391 | 0.5932 | 0.1636 |
| ATP_Training.txt | 221 | 0.8631 | 0.1459 | 0.4975 | 0.2256 | 0.6879 | 0.2146 |
| DNA_Training.txt | 335 | 0.8387 | 0.1608 | 0.2233 | 0.1870 | 0.5589 | 0.1017 |
| ADP_Training.txt | 296 | 0.8653 | 0.1415 | 0.5142 | 0.2219 | 0.6966 | 0.2176 |
| MN_Training.txt | 334 | 0.8507 | 0.0488 | 0.5602 | 0.0898 | 0.7074 | 0.1320 |
| ZN_Training.txt | 1152 | 0.8418 | 0.0437 | 0.4674 | 0.0799 | 0.6574 | 0.1041 |
| MG_Training.txt | 1131 | 0.8454 | 0.0327 | 0.4617 | 0.0611 | 0.6556 | 0.0896 |
| CA_Training.txt | 961 | 0.8524 | 0.0251 | 0.2057 | 0.0447 | 0.5346 | 0.0258 |
**Validation splits:**

| Dataset | Rows | Accuracy | Precision | Recall | F1 | AUC | MCC |
|---|---|---|---|---|---|---|---|
| HEME_Validation.txt | 27 | 0.8891 | 0.2125 | 0.2810 | 0.2420 | 0.6055 | 0.1855 |
| GTP_Validation.txt | 7 | 0.8012 | 0.1377 | 0.6404 | 0.2266 | 0.7247 | 0.2292 |
| GDP_Validation.txt | 14 | 0.7954 | 0.1456 | 0.7423 | 0.2434 | 0.7701 | 0.2658 |
| FE_Validation.txt | 26 | 0.8523 | 0.0571 | 0.6667 | 0.1052 | 0.7607 | 0.1646 |
| MN_Validation.txt | 58 | 0.8445 | 0.0458 | 0.5359 | 0.0844 | 0.6923 | 0.1216 |
| AMP_Validation.txt | 33 | 0.8116 | 0.1065 | 0.5638 | 0.1792 | 0.6924 | 0.1827 |
| DNA_Validation.txt | 52 | 0.8849 | 0.1306 | 0.1829 | 0.1524 | 0.5550 | 0.0940 |
| ATP_Validation.txt | 50 | 0.8497 | 0.1220 | 0.4869 | 0.1952 | 0.6753 | 0.1868 |
| ADP_Validation.txt | 47 | 0.8652 | 0.1279 | 0.5379 | 0.2067 | 0.7071 | 0.2139 |
| ZN_Validation.txt | 176 | 0.8486 | 0.0461 | 0.4516 | 0.0837 | 0.6532 | 0.1054 |
| CA_Validation.txt | 165 | 0.8577 | 0.0263 | 0.2471 | 0.0476 | 0.5568 | 0.0396 |
| MG_Validation.txt | 217 | 0.8572 | 0.0297 | 0.3533 | 0.0547 | 0.6082 | 0.0672 |
### Checkpoint 4
```python
Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}
Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
```