## Training Script

This folder holds the code for training the model (`train.py`), defining the model architecture (`model.py`), and defining utility functions, including masking rate schedulers and dataloaders (`utils.py`). There is also a script for running ESM-2 on the test data (`test_esm2.py`).

The weights and other necessary files for loading FusOn-pLM are stored in `checkpoints/best/ckpt`. Results on the test set are stored in `checkpoints/best/test_results.csv`. 

### Usage
#### Configs
The `config.py` script holds configurations for **training** and **plotting**. 

```python
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

### Masking parameters - use either a variable or a fixed masking rate
# var masking rate (choice 1)
VAR_MASK_RATE = True            # if True, the masking rate varies between MASK_LOW and MASK_HIGH
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine"       # scheduler type; options: "cosine", "loglinear", "stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15          # if VAR_MASK_RATE = False, code will use fixed masking rate

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True                       # Set to False if you want to finetune from a checkpoint  
PATH_TO_STARTING_CKPT = ''  # only set the path if FINETUNE_FROM_SCRATCH = False 

# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'     
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters
# Fill these in with your own WandB account info 
WANDB_PROJECT = '' 
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
```
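The actual scheduling logic lives in `utils.py` and is not shown here. As a rough illustration of what the three `MASK_SCHEDULER` options could mean, the sketch below maps a step index to a masking rate between `MASK_LOW` and `MASK_HIGH` over `MASK_STEPS` steps. The function name `mask_rate`, the direction of the schedule (decaying from `MASK_HIGH` to `MASK_LOW`), and the four equal plateaus used for `"stepwise"` are all assumptions, not taken from the repository.

```python
import math

# Default values copied from config.py above.
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20


def mask_rate(step: int, scheduler: str = "cosine",
              low: float = MASK_LOW, high: float = MASK_HIGH,
              steps: int = MASK_STEPS, n_plateaus: int = 4) -> float:
    """Hypothetical masking-rate schedule, ASSUMED to decay from `high` to `low`.

    Not the utils.py implementation; for illustration only.
    """
    if steps < 2:
        raise ValueError("steps must be at least 2")
    # Clamp the step and convert it to training progress t in [0, 1].
    t = min(max(step, 0), steps - 1) / (steps - 1)
    if scheduler == "cosine":
        # Half-cosine decay: flat near the start and the end.
        return low + 0.5 * (high - low) * (1 + math.cos(math.pi * t))
    if scheduler == "loglinear":
        # Linear interpolation in log space, i.e. geometric decay.
        return math.exp(math.log(high) + t * (math.log(low) - math.log(high)))
    if scheduler == "stepwise":
        # ASSUMPTION: n_plateaus equally spaced constant levels from high to low.
        level = min(int(t * n_plateaus), n_plateaus - 1)
        return high - level * (high - low) / (n_plateaus - 1)
    raise ValueError(f"unknown scheduler: {scheduler!r}")


if __name__ == "__main__":
    for name in ("cosine", "loglinear", "stepwise"):
        rates = [round(mask_rate(s, name), 3) for s in range(MASK_STEPS)]
        print(name, rates)
```

Under these assumptions every schedule starts at (approximately) `MASK_HIGH`, ends at `MASK_LOW` on step `MASK_STEPS - 1`, and stays there for later steps. Whether the step advances per epoch or per batch is not stated in this README.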

#### Training
The `train.py` script trains a fusion-aware ESM model according to the settings specified in `config.py`.
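`config.py` exposes `N_UNFROZEN_LAYERS` and the `UNFREEZE_QUERY`, `UNFREEZE_KEY`, and `UNFREEZE_VALUE` flags, but this README does not say exactly which weights they select. One plausible reading, sketched here purely as an assumption, is that only the query, key, and value projections of the last `N_UNFROZEN_LAYERS` transformer layers stay trainable. The helper `select_trainable` is hypothetical, and it assumes Hugging Face-style ESM parameter names such as `esm.encoder.layer.<i>.attention.self.query.weight`; `train.py` and `model.py` may do this differently.

```python
import re

# ASSUMPTION: Hugging Face-style ESM parameter names, e.g.
# "esm.encoder.layer.31.attention.self.query.weight".
QKV_PATTERN = re.compile(r"encoder\.layer\.(\d+)\.attention\.self\.(query|key|value)\.")


def select_trainable(param_names, n_layers, n_unfrozen=8,
                     unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    """Hypothetical helper: names of Q/K/V parameters in the last `n_unfrozen` layers.

    Everything not returned would be frozen (requires_grad = False).
    """
    wanted = {proj for proj, flag in (("query", unfreeze_query),
                                      ("key", unfreeze_key),
                                      ("value", unfreeze_value)) if flag}
    first_unfrozen = max(n_layers - n_unfrozen, 0)  # index of the first trainable layer
    trainable = []
    for name in param_names:
        match = QKV_PATTERN.search(name)
        if match is None:
            continue  # embeddings, LM head, layer norms, feed-forward blocks stay frozen
        layer_idx, proj = int(match.group(1)), match.group(2)
        if layer_idx >= first_unfrozen and proj in wanted:
            trainable.append(name)
    return trainable


if __name__ == "__main__":
    # Toy parameter list for a 12-layer encoder (no weights needed).
    names = [f"esm.encoder.layer.{i}.attention.self.{proj}.{kind}"
             for i in range(12)
             for proj in ("query", "key", "value")
             for kind in ("weight", "bias")]
    names.append("esm.embeddings.word_embeddings.weight")
    print(select_trainable(names, n_layers=12, n_unfrozen=2))
```

With a PyTorch model, the result could be applied by setting `p.requires_grad = name in trainable` for every `(name, p)` in `model.named_parameters()`. Working on names alone keeps the selection testable without loading any weights.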

To run, enter the following in a terminal:
```bash
python train.py
```
or, to run the (long) training process in the background:
```bash
nohup python train.py > train.out 2> train.err &
```