File size: 6,775 Bytes
79a0c41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b9d69
79a0c41
45b9d69
 
79a0c41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python
# coding: utf-8

# hyperparameter optimization with raytune for disease classification

# imports
import os
import subprocess
GPU_NUMBER = [0, 1, 2, 3]
# Process-wide environment overrides, applied in this order before ray/torch are
# imported: expose GPUs 0-3, turn on verbose NCCL logging, and set the conda
# glibc override and shared-library search path.
_env_overrides = {
    "CUDA_VISIBLE_DEVICES": ",".join(map(str, GPU_NUMBER)),
    "NCCL_DEBUG": "INFO",
    "CONDA_OVERRIDE_GLIBC": "2.56",
    "LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib",
}
os.environ.update(_env_overrides)

# initiate runtime environment for raytune
import pyarrow # must occur prior to ray import
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
from ray.tune.suggest.hyperopt import HyperOptSearch
# Ray runtime environment: workers use the "base" conda env and receive the same
# LD_LIBRARY_PATH string that this driver process sets for itself (presumably so
# workers resolve the same shared libraries as the driver).
_ray_env_vars = {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}
runtime_env = {"conda": "base", "env_vars": _ray_env_vars}
ray.init(runtime_env=runtime_env)

import datetime
import numpy as np
import pandas as pd
import random
import seaborn as sns; sns.set()
from collections import Counter
from datasets import load_from_disk
from scipy.stats import ranksums
from sklearn.metrics import accuracy_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

# number of worker processes used by the datasets .filter()/.map() calls below
num_proc = 30

# Load the train dataset. Expected columns:
#   cell_type  - annotation of each cell's type
#   disease    - healthy or disease state
#   individual - unique ID for each patient
#   length     - length of that cell's rank value encoding
train_dataset = load_from_disk("/path/to/disease_train_data.dataset")


def if_cell_type(example):
    """Return True when the example's ``cell_type`` begins with "Cardiomyocyte"."""
    cell_type = example["cell_type"]
    return cell_type.startswith("Cardiomyocyte")


# keep only cells whose type passes if_cell_type
trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)

# map each disease state to an integer label id (its index in target_names)
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = {name: label_id for label_id, name in enumerate(target_names)}

# the disease column becomes the "label" column used for classification
trainset_v3 = trainset_v2.rename_column("disease", "label")


# change labels to numerical ids
def classes_to_ids(example):
    """Replace the example's string disease label with its integer id.

    Args:
        example: one dataset row; its "label" value must be in target_names.

    Returns:
        The same row with "label" replaced by the corresponding integer id.

    Raises:
        KeyError: if the label is not one of target_names.
    """
    label = example["label"]
    try:
        example["label"] = target_name_id_dict[label]
    except KeyError:
        # Name the offending value so a mislabeled dataset is easy to diagnose.
        raise KeyError(
            f"Unexpected disease label {label!r}; expected one of {target_names}"
        ) from None
    return example


trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)

# Separate into train / validation / test sets by individual, so the splits
# never share a patient (70% train; the remaining 30% is halved into valid/test).
# The IDs are sorted before sampling for two reasons:
#   - random.sample() does not accept a set (deprecated in Python 3.9,
#     TypeError since 3.11).
#   - Set iteration order of string IDs changes between interpreter runs, so
#     sampling from an unsorted set is not reproducible even with a fixed seed.
indiv_set = set(trainset_v4["individual"])
indiv_list = sorted(indiv_set)
random.seed(42)
train_indiv = random.sample(indiv_list, round(0.7 * len(indiv_list)))
_train_indiv_set = set(train_indiv)
eval_indiv = [indiv for indiv in indiv_list if indiv not in _train_indiv_set]
valid_indiv = random.sample(eval_indiv, round(0.5 * len(eval_indiv)))
_valid_indiv_set = set(valid_indiv)
# held-out individuals; not used further in this script
test_indiv = [indiv for indiv in eval_indiv if indiv not in _valid_indiv_set]


def if_train(example):
    """Return True when the example's individual is in the training split."""
    # set lookup: this predicate runs once per cell
    return example["individual"] in _train_indiv_set


classifier_trainset = trainset_v4.filter(if_train, num_proc=num_proc).shuffle(seed=42)


def if_valid(example):
    """Return True when the example's individual is in the validation split."""
    return example["individual"] in _valid_indiv_set


classifier_validset = trainset_v4.filter(if_valid, num_proc=num_proc).shuffle(seed=42)

# define output directory path, date-stamped as YYMMDD
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"

# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test):
    raise Exception("Model already saved to this directory.")

# Make the output directory and any missing parents.
# os.makedirs avoids spawning a shell with an interpolated path, and unlike a
# bare `mkdir` whose exit status was ignored, it fails loudly on real errors.
# exist_ok=True keeps the previous tolerance of an already-existing directory.
os.makedirs(output_dir, exist_ok=True)

# set training parameters
# how many pretrained layers to freeze
freeze_layers = 2
# batch size for training and eval
geneformer_batch_size = 12
# number of epochs
epochs = 1
# Logging/eval interval: about 10 times per epoch.
# Clamp to >= 1 because small datasets would otherwise round to 0, an invalid
# interval for the "steps" evaluation/logging strategy.
logging_steps = max(1, round(len(classifier_trainset) / geneformer_batch_size / 10))

# define function to initiate model
def model_init():
    """Create the classification model handed to the Trainer via ``model_init``.

    Loads the pretrained checkpoint with one output per entry in target_names,
    freezes the parameters of the first ``freeze_layers`` encoder layers (when
    ``freeze_layers`` is not None), and moves the model to "cuda:0".

    Returns:
        The initialized BertForSequenceClassification model.
    """
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/pretrained_model/",
        num_labels=len(target_names),
        output_attentions=False,
        output_hidden_states=False,
    )
    if freeze_layers is not None:
        frozen_params = (
            param
            for layer in model.bert.encoder.layer[:freeze_layers]
            for param in layer.parameters()
        )
        for param in frozen_params:
            param.requires_grad = False

    return model.to("cuda:0")

# define metrics
def compute_metrics(pred):
    """Return ``{"accuracy": ...}`` for an evaluation prediction object.

    The predicted class is the argmax of ``pred.predictions`` over its last
    axis. It is scored against ``pred.label_ids`` with sklearn's accuracy_score.
    """
    true_ids = pred.label_ids
    predicted_ids = pred.predictions.argmax(-1)
    return {"accuracy": accuracy_score(true_ids, predicted_ids)}

# set training arguments
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    "logging_steps": logging_steps,
    # load_best_model_at_end requires checkpoints on the same schedule as
    # evaluation (save_steps must be a round multiple of eval_steps). The
    # default save_steps=500 only satisfies that by coincidence, so
    # TrainingArguments would otherwise usually raise ValueError here.
    "save_strategy": "steps",
    "save_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True, # memory tracker causes errors in raytune
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# Create the trainer. A model_init factory (not a model instance) is passed so
# that hyperparameter_search can build the model itself for each trial.
data_collator = DataCollatorForCellClassification()
trainer = Trainer(
    model_init=model_init,
    args=training_args_init,
    data_collator=data_collator,
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    compute_metrics=compute_metrics,
)

# Ray Tune search space: epochs and batch size are pinned to single values; the
# other hyperparameters are sampled per trial.
# NOTE(review): seed and warmup_steps are sampled as floats. This relies on the
# Trainer coercing them to the int type of the matching TrainingArguments
# field; confirm for the installed transformers version.
ray_config = dict(
    num_train_epochs=tune.choice([epochs]),
    learning_rate=tune.loguniform(1e-6, 1e-3),
    weight_decay=tune.uniform(0.0, 0.3),
    lr_scheduler_type=tune.choice(["linear", "cosine", "polynomial"]),
    warmup_steps=tune.uniform(100, 2000),
    seed=tune.uniform(0, 100),
    per_device_train_batch_size=tune.choice([geneformer_batch_size]),
)

# HyperOpt search that maximizes validation accuracy
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")

# console progress table: at most every 600 s, up to 100 rows, best accuracy first
progress_reporter = tune.CLIReporter(
    max_report_frequency=600,
    sort_by_metric=True,
    max_progress_rows=100,
    mode="max",
    metric="eval_accuracy",
    metric_columns=["loss", "eval_loss", "eval_accuracy"],
)


def _ray_hp_space(_trial):
    """Return the fixed Ray Tune search space regardless of the trial passed in."""
    return ray_config


# optimize hyperparameters: 100 trials, each reserving 8 CPUs and 1 GPU
trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu": 8, "gpu": 1},
    hp_space=_ray_hp_space,
    search_alg=hyperopt_search,
    n_trials=100,
    progress_reporter=progress_reporter,
)