#!/usr/bin/env python
# coding: utf-8
# hyperparameter optimization with raytune for disease classification
# imports
import os
GPU_NUMBER = [0,1,2,3]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
# initiate runtime environment for raytune
import pyarrow  # must occur prior to ray import
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
from ray.tune.suggest.hyperopt import HyperOptSearch  # in ray>=2.0 this moved to ray.tune.search.hyperopt
# "conda": "base" assumes the ray workers reuse the driver's base conda environment
runtime_env = {"conda": "base",
               "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
ray.init(runtime_env=runtime_env)
import datetime
import numpy as np
import pandas as pd
import random
import seaborn as sns; sns.set()
from collections import Counter
from datasets import load_from_disk
from scipy.stats import ranksums
from sklearn.metrics import accuracy_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
# number of CPU cores
num_proc = 30
# load train dataset with columns:
# cell_type (annotation of each cell's type)
# disease (healthy or disease state)
# individual (unique ID for each patient)
# length (length of that cell's rank value encoding)
train_dataset = load_from_disk("/path/to/disease_train_data.dataset")
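# optional sanity check (illustrative): confirm the columns described above are
# actually present before proceeding
print(train_dataset.column_names)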
# filter dataset for given cell_type
def if_cell_type(example):
    return example["cell_type"].startswith("Cardiomyocyte")
trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
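# optional (illustrative): inspect the disease-state distribution after filtering,
# using the Counter imported above, to spot severe class imbalance early
print(Counter(trainset_v2["disease"]))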
# create dictionary of disease states : label ids
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = dict(zip(target_names, range(len(target_names))))
trainset_v3 = trainset_v2.rename_column("disease", "label")
# change labels to numerical ids
def classes_to_ids(example):
    example["label"] = target_name_id_dict[example["label"]]
    return example
trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
# separate into train, validation, test sets by individual so no patient's cells
# leak across splits
indiv_set = set(trainset_v4["individual"])
random.seed(42)
# sort the set before sampling and iterating: random.sample requires a sequence
# (sampling from a set is removed in python 3.11), and set iteration order varies
# across runs, which would make the split non-reproducible despite the seed
train_indiv = random.sample(sorted(indiv_set), round(0.7*len(indiv_set)))
eval_indiv = [indiv for indiv in sorted(indiv_set) if indiv not in train_indiv]
valid_indiv = random.sample(eval_indiv, round(0.5*len(eval_indiv)))
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
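# optional sanity check (illustrative): the three groups should be disjoint and
# together cover every individual
assert set(train_indiv).isdisjoint(valid_indiv)
assert set(train_indiv).isdisjoint(test_indiv)
assert set(valid_indiv).isdisjoint(test_indiv)
assert set(train_indiv) | set(valid_indiv) | set(test_indiv) == indiv_set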
def if_train(example):
    return example["individual"] in train_indiv
classifier_trainset = trainset_v4.filter(if_train, num_proc=num_proc).shuffle(seed=42)
def if_valid(example):
    return example["individual"] in valid_indiv
classifier_validset = trainset_v4.filter(if_valid, num_proc=num_proc).shuffle(seed=42)
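# the test individuals are held out above but not used in this script; for final
# evaluation after tuning, the test set can be filtered the same way (illustrative,
# not consumed below)
def if_test(example):
    return example["individual"] in test_indiv
classifier_testset = trainset_v4.filter(if_test, num_proc=num_proc).shuffle(seed=42)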
# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test):
    raise Exception("Model already saved to this directory.")
# make output directory (including parents; no-op if it already exists)
os.makedirs(output_dir, exist_ok=True)
# set training parameters
# how many pretrained layers to freeze
freeze_layers = 2
# batch size for training and eval
geneformer_batch_size = 12
# number of epochs
epochs = 1
# logging steps (~10 logs per epoch; guard against 0 for very small datasets)
logging_steps = max(1, round(len(classifier_trainset)/geneformer_batch_size/10))
# define function to initiate model; Trainer calls this once per raytune trial so
# every run starts fresh from the pretrained weights
def model_init():
    model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
                                                          num_labels=len(target_names),
                                                          output_attentions=False,
                                                          output_hidden_states=False)
    # freeze the first freeze_layers encoder layers to retain pretrained features
    if freeze_layers is not None:
        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False
    model = model.to("cuda:0")
    return model
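# optional (illustrative): confirm the freeze took effect by comparing trainable vs.
# total parameter counts on a throwaway instance; commented out because it loads the
# full model onto the GPU
# _m = model_init()
# trainable = sum(p.numel() for p in _m.parameters() if p.requires_grad)
# total = sum(p.numel() for p in _m.parameters())
# print(f"trainable params: {trainable:,} of {total:,}")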
# define metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy using sklearn's function
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
    }
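# a variant (illustrative, not used below) that also reports macro-F1, which is more
# informative than accuracy when the disease classes are imbalanced
from sklearn.metrics import f1_score
def compute_metrics_with_f1(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(labels, preds),
        'macro_f1': f1_score(labels, preds, average="macro"),
    }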
# set training arguments
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True,  # memory tracker causes errors in raytune
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}
training_args_init = TrainingArguments(**training_args)
# create the trainer
trainer = Trainer(
    model_init=model_init,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    compute_metrics=compute_metrics,
)
# specify raytune hyperparameter search space
ray_config = {
    "num_train_epochs": tune.choice([epochs]),
    "learning_rate": tune.loguniform(1e-6, 1e-3),
    "weight_decay": tune.uniform(0.0, 0.3),
    "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
    "warmup_steps": tune.uniform(100, 2000),
    "seed": tune.randint(0, 100),  # seed must be an integer; sampling floats here can break transformers' set_seed
    "per_device_train_batch_size": tune.choice([geneformer_batch_size])
}
hyperopt_search = HyperOptSearch(
    metric="eval_accuracy", mode="max")
# optimize hyperparameters; capture the returned BestRun so the winning
# configuration is not discarded
best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu":8,"gpu":1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,  # number of trials
    progress_reporter=tune.CLIReporter(max_report_frequency=600,
                                       sort_by_metric=True,
                                       max_progress_rows=100,
                                       mode="max",
                                       metric="eval_accuracy",
                                       metric_columns=["loss", "eval_loss", "eval_accuracy"])
)
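# report the best trial (illustrative); transformers' BestRun exposes run_id,
# objective, and the sampled hyperparameters
print(f"best run id: {best_run.run_id}")
print(f"best objective: {best_run.objective}")
print(f"best hyperparameters: {best_run.hyperparameters}")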