Commit 9a73cb0 by svincoff: uploaded training code and model weights

Training Script

This folder holds the code for training the model (train.py), defining the model architecture (model.py), and defining utility functions, including masking-rate schedulers and dataloaders (utils.py). There is also a script for running ESM-2 on the test data (test_esm2.py).
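The masking-rate schedulers live in utils.py, but their exact formulas are not given in this README. The following is only a minimal sketch under stated assumptions: the rate rises from MASK_LOW to MASK_HIGH over MASK_STEPS discrete schedule points (see the Configs section), "cosine" means a half-cosine ramp, "loglinear" a geometric ramp, and "stepwise" equal linear increments. The helpers `mask_rate`, `schedule_index`, and `mask_tokens` are hypothetical, and `mask_tokens` assumes plain masked-language-model masking (every selected position replaced by the mask token), which may differ from what utils.py actually does.

```python
import math
import random

# Hypothetical sketch: rates rise from `low` to `high` over `n_steps` discrete schedule points.
def mask_rate(step, low=0.15, high=0.40, n_steps=20, scheduler="cosine"):
    """Masking rate at schedule point `step` (0 .. n_steps - 1)."""
    if n_steps < 2:
        return low
    step = max(0, min(step, n_steps - 1))  # clamp into the valid range
    t = step / (n_steps - 1)               # progress through the schedule, in [0, 1]
    if scheduler == "cosine":
        # half-cosine ramp: slow at both ends, fastest in the middle
        return low + (high - low) * (1 - math.cos(math.pi * t)) / 2
    if scheduler == "loglinear":
        # linear in log space, i.e. a geometric progression from low to high
        return math.exp(math.log(low) + t * (math.log(high) - math.log(low)))
    if scheduler == "stepwise":
        # equal increments between consecutive schedule points
        return low + (high - low) * t
    raise ValueError(f"unknown scheduler: {scheduler!r}")


def schedule_index(iteration, total_iterations, n_steps=20):
    """Map a training iteration onto one of `n_steps` equal-length schedule stages."""
    return min(n_steps - 1, iteration * n_steps // total_iterations)


def mask_tokens(token_ids, rate, mask_id, special_ids=frozenset(), rng=random):
    """Plain MLM masking: replace round(rate * n) non-special tokens with `mask_id`.

    Labels keep the original id at masked positions and -100 elsewhere,
    the default ignore_index of torch.nn.CrossEntropyLoss.
    """
    maskable = [i for i, t in enumerate(token_ids) if t not in special_ids]
    chosen = set(rng.sample(maskable, round(rate * len(maskable))))
    masked = [mask_id if i in chosen else t for i, t in enumerate(token_ids)]
    labels = [t if i in chosen else -100 for i, t in enumerate(token_ids)]
    return masked, labels
```

Under these assumptions, `mask_rate(schedule_index(it, total_its))` with the default config values gives 0.15 at the first iteration and 0.40 during the last twentieth of training.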

The weights and other necessary files for loading FusOn-pLM are stored in checkpoints/best/ckpt. Results on the test set are stored in checkpoints/best/test_results.csv.

Usage

Configs

The config.py script holds configurations for training and plotting.

# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

### Masking parameters - must use either variable or fixed masking rate
# var masking rate (choice 1)
VAR_MASK_RATE = True            # if True, use the variable masking rate defined below; if False, use MASK_PERCENTAGE
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine"       # specify the type of scheduler to use. options are: "cosine","loglinear","stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15          # if VAR_MASK_RATE = False, code will use fixed masking rate

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True                       # set to False to continue finetuning from a checkpoint
PATH_TO_STARTING_CKPT = ''                         # only set the path if FINETUNE_FROM_SCRATCH = False

# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters
# Fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
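The comments in config.py imply a few constraints between parameters. Below is a minimal sketch of a pre-flight check; `check_config` is a hypothetical helper, not part of this repository, that takes the config module's namespace as a dict (for example `check_config(vars(config))`) and returns a list of problems. The valid scheduler names come from the MASK_SCHEDULER comment; the numeric range checks are assumptions.

```python
# Hypothetical pre-flight check for config.py; the rules mirror the comments in the config file.
VALID_SCHEDULERS = {"cosine", "loglinear", "stepwise"}


def check_config(cfg):
    """Return a list of human-readable problems found in a config namespace (a dict)."""
    problems = []
    if cfg["VAR_MASK_RATE"]:
        # variable masking rate (choice 1): bounds, step count, and scheduler must be usable
        if not 0 < cfg["MASK_LOW"] < cfg["MASK_HIGH"] < 1:
            problems.append("expected 0 < MASK_LOW < MASK_HIGH < 1")
        if not (isinstance(cfg["MASK_STEPS"], int) and cfg["MASK_STEPS"] >= 1):
            problems.append("MASK_STEPS should be a positive integer")
        if cfg["MASK_SCHEDULER"] not in VALID_SCHEDULERS:
            problems.append(f"MASK_SCHEDULER must be one of {sorted(VALID_SCHEDULERS)}")
    elif not 0 < cfg["MASK_PERCENTAGE"] < 1:
        # fixed masking rate (choice 2)
        problems.append("expected 0 < MASK_PERCENTAGE < 1")
    # continuing from a checkpoint needs a path; starting from scratch should leave it empty
    if not cfg["FINETUNE_FROM_SCRATCH"] and not cfg["PATH_TO_STARTING_CKPT"]:
        problems.append("set PATH_TO_STARTING_CKPT when FINETUNE_FROM_SCRATCH = False")
    if cfg["FINETUNE_FROM_SCRATCH"] and cfg["PATH_TO_STARTING_CKPT"]:
        problems.append("only set PATH_TO_STARTING_CKPT when FINETUNE_FROM_SCRATCH = False")
    # the WandB fields are blank in the shipped config and must be filled in
    missing = [k for k in ("WANDB_PROJECT", "WANDB_ENTITY", "WANDB_API_KEY") if not cfg[k]]
    if missing:
        problems.append("fill in WandB settings: " + ", ".join(missing))
    return problems
```

Run against config.py exactly as shown above, this sketch would report a single problem: the empty WandB fields.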

Training

The train.py script trains a fusion-aware ESM model according to the settings specified in config.py.

To run, enter the following in a terminal:

python train.py

or, to run the (long) training process in the background, with standard output written to train.out and errors to train.err:

nohup python train.py > train.out 2> train.err &
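train.py reads N_UNFROZEN_LAYERS and the UNFREEZE_QUERY, UNFREEZE_KEY, and UNFREEZE_VALUE flags from config.py, but this README does not show how they are applied. One plausible reading, sketched below purely as an assumption, is that only the selected query, key, and value projections of the last N_UNFROZEN_LAYERS transformer layers are trained. The parameter-name pattern assumes Hugging Face `transformers` ESM naming (e.g. `esm.encoder.layer.32.attention.self.query.weight`), and `is_trainable` is a hypothetical helper.

```python
import re

# Assumed Hugging Face ESM parameter naming, e.g.
# "esm.encoder.layer.32.attention.self.query.weight"; check model.named_parameters() for yours.
QKV_PATTERN = re.compile(r"encoder\.layer\.(\d+)\.attention\.self\.(query|key|value)\.")


def is_trainable(name, n_layers, n_unfrozen=8,
                 unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    """Hypothetical rule: train only the selected Q/K/V projections of the last `n_unfrozen` layers."""
    match = QKV_PATTERN.search(name)
    if match is None:
        return False  # embeddings, feed-forward blocks, layer norms, and heads stay frozen
    layer, proj = int(match.group(1)), match.group(2)
    if layer < n_layers - n_unfrozen:
        return False  # earlier layers stay frozen
    return {"query": unfreeze_query, "key": unfreeze_key, "value": unfreeze_value}[proj]
```

With a PyTorch model this could be applied as `for name, p in model.named_parameters(): p.requires_grad = is_trainable(name, model.config.num_hidden_layers)`. Whether train.py also unfreezes other weights, such as the language-model head, is not stated here.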