---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---

# ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them similarly to the LoRA models. 
These checkpoints use the smallest model, `esm2_t6_8M_UR50D`, so the metrics aren't great; 
scaling to larger models for better metrics is in progress. The checkpoints were trained on 
[the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data). To replicate the QLoRA training for ESM-2 models, 
you can use the `conda-environment.yml` file. However, for the next week or two (as of 28/09/2023) you will need to uninstall the 
released version of transformers and install from source instead:

```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest released version: 
gradient checkpointing will be fully enabled, and QLoRA compatibility will be fully integrated into the ESM-2 models. 
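
As a rough sketch of how a checkpoint might be loaded and used for per-residue binding site prediction (the adapter repository ID and the label count below are illustrative assumptions, not exact values from this repo):

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

# Illustrative names: the base ESM-2 model plus a QLoRA adapter checkpoint.
base_model_name = "facebook/esm2_t6_8M_UR50D"
adapter_path = "AmelieSchreiber/esm2_qlora_binding_sites"  # hypothetical adapter repo/path

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_name, num_labels=2)

# Attach the trained LoRA/QLoRA adapter weights to the base model.
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Predict a label for every residue of a protein sequence.
sequence = "MSTNPKPQRKTKRNTNRRPQDVKFPGG"
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predictions = logits.argmax(dim=-1).squeeze().tolist()
print(predictions)  # 0 = non-binding, 1 = binding site (per token, including special tokens)
```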

## QLoRA Info

Note that we train only 0.58% of the parameters, applying LoRA adapters to just the query, key, and value weight matrices. 

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
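
For reference, a minimal sketch of a QLoRA-style setup, with 4-bit quantization and adapters on the query, key, and value projections, might look like the following (the rank, alpha, and dropout values are illustrative assumptions, not the exact hyperparameters used for these checkpoints):

```python
import torch
from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the base model with 4-bit NF4 quantization (the "Q" in QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=2,
    quantization_config=bnb_config,
)
model = prepare_model_for_kbit_training(model)

# Apply LoRA adapters only to the attention query, key, and value projections.
lora_config = LoraConfig(
    r=8,                # illustrative rank
    lora_alpha=16,      # illustrative scaling factor
    lora_dropout=0.05,  # illustrative dropout
    bias="none",
    target_modules=["query", "key", "value"],
    task_type="TOKEN_CLS",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts
```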

The QLoRA paper showed that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter 
to adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, 
such as the scaling factor alpha, did not seem to matter much. An important next step is to check whether this finding 
transfers to protein language models as well. 
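
One way to probe this would be to widen `target_modules` and compare. A sketch under that assumption (module names follow the ESM-2 implementation in transformers; `"dense"` matches the feed-forward and attention-output projections):

```python
from peft import LoraConfig

# Illustrative: adapt the dense layers as well, not just query/key/value,
# to test whether "more matrices is better" also holds for protein language models.
wider_lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["query", "key", "value", "dense"],
    task_type="TOKEN_CLS",
)
```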

## Testing for Overfitting

### Checkpoint 1

Train/Test Split from 600K dataset:

```python
Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}

Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}
```
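
For context, metrics of this shape can be computed with standard scikit-learn functions. A minimal sketch of a `compute_metrics` function, assuming token-level labels where -100 marks padding and special tokens:

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, matthews_corrcoef,
)

def compute_metrics(eval_pred):
    """Token-level binding-site metrics; assumes labels of -100 mark padding."""
    logits, labels = eval_pred
    # Softmax over the label dimension to get P(binding site) for the AUC.
    exp = np.exp(logits - logits.max(-1, keepdims=True))
    probs = exp / exp.sum(-1, keepdims=True)
    preds = logits.argmax(-1)

    mask = labels != -100  # drop padding / special-token positions
    labels, preds, pos_probs = labels[mask], preds[mask], probs[..., 1][mask]

    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        "f1": f1_score(labels, preds),
        "auc": roc_auc_score(labels, pos_probs),
        "mcc": matthews_corrcoef(labels, preds),
    }
```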

Metrics for [these datasets](https://github.com/hamzagamouh/pt-lm-gnn):

```python
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 54/54 [00:04<00:00, 11.49it/s]
Dataset: GTP_Training.txt
Accuracy: 0.8777
Precision: 0.1488
Recall: 0.5517
F1 Score: 0.2344
AUC: 0.7204
MCC: 0.2407
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 82/82 [00:02<00:00, 32.07it/s]
Dataset: GDP_Training.txt
Accuracy: 0.8711
Precision: 0.1768
Recall: 0.6022
F1 Score: 0.2733
AUC: 0.7423
MCC: 0.2768
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 172/172 [00:06<00:00, 27.98it/s]
Dataset: FE_Training.txt
Accuracy: 0.8424
Precision: 0.0547
Recall: 0.5452
F1 Score: 0.0994
AUC: 0.6962
MCC: 0.1344
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 145/145 [00:04<00:00, 33.86it/s]
Dataset: AMP_Training.txt
Accuracy: 0.8191
Precision: 0.0975
Recall: 0.5078
F1 Score: 0.1636
AUC: 0.6691
MCC: 0.1609
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 206/206 [00:05<00:00, 34.97it/s]
Dataset: HEME_Training.txt
Accuracy: 0.8561
Precision: 0.2089
Recall: 0.2795
F1 Score: 0.2391
AUC: 0.5932
MCC: 0.1636
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 221/221 [00:06<00:00, 31.64it/s]
Dataset: ATP_Training.txt
Accuracy: 0.8631
Precision: 0.1459
Recall: 0.4975
F1 Score: 0.2256
AUC: 0.6879
MCC: 0.2146
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 335/335 [00:10<00:00, 33.49it/s]
Dataset: DNA_Training.txt
Accuracy: 0.8387
Precision: 0.1608
Recall: 0.2233
F1 Score: 0.1870
AUC: 0.5589
MCC: 0.1017
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 296/296 [00:08<00:00, 32.99it/s]
Dataset: ADP_Training.txt
Accuracy: 0.8653
Precision: 0.1415
Recall: 0.5142
F1 Score: 0.2219
AUC: 0.6966
MCC: 0.2176
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 334/334 [00:10<00:00, 31.30it/s]
Dataset: MN_Training.txt
Accuracy: 0.8507
Precision: 0.0488
Recall: 0.5602
F1 Score: 0.0898
AUC: 0.7074
MCC: 0.1320
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1152/1152 [00:36<00:00, 31.70it/s]
Dataset: ZN_Training.txt
Accuracy: 0.8418
Precision: 0.0437
Recall: 0.4674
F1 Score: 0.0799
AUC: 0.6574
MCC: 0.1041
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1131/1131 [00:35<00:00, 31.87it/s]
Dataset: MG_Training.txt
Accuracy: 0.8454
Precision: 0.0327
Recall: 0.4617
F1 Score: 0.0611
AUC: 0.6556
MCC: 0.0896
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 961/961 [00:30<00:00, 31.67it/s]
Dataset: CA_Training.txt
Accuracy: 0.8524
Precision: 0.0251
Recall: 0.2057
F1 Score: 0.0447
AUC: 0.5346
MCC: 0.0258
```

These are the metrics on the corresponding validation splits:

```python
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 27/27 [00:01<00:00, 26.47it/s]
Dataset: HEME_Validation.txt
Accuracy: 0.8891
Precision: 0.2125
Recall: 0.2810
F1 Score: 0.2420
AUC: 0.6055
MCC: 0.1855
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 20.36it/s]
Dataset: GTP_Validation.txt
Accuracy: 0.8012
Precision: 0.1377
Recall: 0.6404
F1 Score: 0.2266
AUC: 0.7247
MCC: 0.2292
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 14/14 [00:00<00:00, 17.96it/s]
Dataset: GDP_Validation.txt
Accuracy: 0.7954
Precision: 0.1456
Recall: 0.7423
F1 Score: 0.2434
AUC: 0.7701
MCC: 0.2658
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 26/26 [00:00<00:00, 27.91it/s]
Dataset: FE_Validation.txt
Accuracy: 0.8523
Precision: 0.0571
Recall: 0.6667
F1 Score: 0.1052
AUC: 0.7607
MCC: 0.1646
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 58/58 [00:01<00:00, 30.49it/s]
Dataset: MN_Validation.txt
Accuracy: 0.8445
Precision: 0.0458
Recall: 0.5359
F1 Score: 0.0844
AUC: 0.6923
MCC: 0.1216
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 33/33 [00:00<00:00, 34.34it/s]
Dataset: AMP_Validation.txt
Accuracy: 0.8116
Precision: 0.1065
Recall: 0.5638
F1 Score: 0.1792
AUC: 0.6924
MCC: 0.1827
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 52/52 [00:01<00:00, 32.70it/s]
Dataset: DNA_Validation.txt
Accuracy: 0.8849
Precision: 0.1306
Recall: 0.1829
F1 Score: 0.1524
AUC: 0.5550
MCC: 0.0940
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:01<00:00, 33.79it/s]
Dataset: ATP_Validation.txt
Accuracy: 0.8497
Precision: 0.1220
Recall: 0.4869
F1 Score: 0.1952
AUC: 0.6753
MCC: 0.1868
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 47/47 [00:01<00:00, 31.43it/s]
Dataset: ADP_Validation.txt
Accuracy: 0.8652
Precision: 0.1279
Recall: 0.5379
F1 Score: 0.2067
AUC: 0.7071
MCC: 0.2139
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 176/176 [00:05<00:00, 32.21it/s]
Dataset: ZN_Validation.txt
Accuracy: 0.8486
Precision: 0.0461
Recall: 0.4516
F1 Score: 0.0837
AUC: 0.6532
MCC: 0.1054
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 165/165 [00:05<00:00, 32.32it/s]
Dataset: CA_Validation.txt
Accuracy: 0.8577
Precision: 0.0263
Recall: 0.2471
F1 Score: 0.0476
AUC: 0.5568
MCC: 0.0396
--------------------------------------------------
Processing rows: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 217/217 [00:06<00:00, 33.25it/s]
Dataset: MG_Validation.txt
Accuracy: 0.8572
Precision: 0.0297
Recall: 0.3533
F1 Score: 0.0547
AUC: 0.6082
MCC: 0.0672
```

### Checkpoint 4

```python
Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}

Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
```