|
|
|
|
|
|
|
|
|
|
|
|
|
import datetime |
|
|
|
|
|
import os |
|
|
|
# Process-wide environment settings. They are applied here, before torch and
# the other heavy libraries are imported below, so those libraries see them at
# import/initialisation time.
#   NCCL_DEBUG=INFO                  -> verbose NCCL logging for multi-GPU runs
#   OMPI_MCA_opal_cuda_support=true  -> Open MPI CUDA support flag
#   CONDA_OVERRIDE_GLIBC=2.56        -> NOTE(review): conda glibc override; confirm still needed
os.environ.update(
    NCCL_DEBUG="INFO",
    OMPI_MCA_opal_cuda_support="true",
    CONDA_OVERRIDE_GLIBC="2.56",
)
|
|
|
import pickle |
|
import random |
|
import subprocess |
|
|
|
import numpy as np |
|
import pytz |
|
import torch |
|
from datasets import load_from_disk |
|
from transformers import BertConfig, BertForMaskedLM, TrainingArguments |
|
|
|
from geneformer import GeneformerPretrainer |
|
|
|
# Fixed seeds for reproducibility. Python's `random` and NumPy share
# `seed_num`; torch (default CPU generator and every CUDA device) uses the
# separate `seed_val`.
seed_num, seed_val = 0, 42

random.seed(seed_num)
np.random.seed(seed_num)

torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
|
|
|
|
|
# Root under which checkpoints (models/) and logs (runs/) are written.
# NOTE(review): appears to be a placeholder path — set it before running.
rootdir = "/parent_ouput_directory"

# Run-name timestamps are taken in US Eastern time.
timezone = pytz.timezone("US/Eastern")
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model architecture (fed into BertConfig further down)
# ---------------------------------------------------------------------------
model_type = "bert"
max_input_size = 2 ** 11            # 2048 = maximum input length / position embeddings
num_layers = 6                      # transformer encoder layers
num_attn_heads = 4                  # attention heads per layer
num_embed_dim = 256                 # hidden / embedding width
intermed_size = 2 * num_embed_dim   # feed-forward width: twice the embedding width
activ_fn = "relu"                   # hidden activation
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02

# ---------------------------------------------------------------------------
# Training schedule (fed into TrainingArguments further down)
# ---------------------------------------------------------------------------
num_examples = 27_406_208           # corpus size; used to derive the checkpoint interval
num_gpus = 12                       # in this script only recorded in the run name
geneformer_batch_size = 12          # per-device train batch size
max_lr = 1e-3
lr_schedule_fn = "linear"
warmup_steps = 10_000
epochs = 3
optimizer = "adamw"                 # in this script only recorded in the run name
weight_decay = 0.001
|
|
|
|
|
|
|
# Build a unique, timestamped run name that encodes the key hyperparameters,
# then derive the checkpoint, logging and final-model directories from it.
current_date = datetime.datetime.now(tz=timezone)
# YYMMDD_HHMMSS. Explicit %H%M%S is used instead of the locale-dependent %X,
# which under a non-C LC_TIME may contain spaces or an AM/PM marker that would
# leak into the directory names.
datestamp = current_date.strftime("%y%m%d_%H%M%S")
run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
training_output_dir = f"{rootdir}/models/{run_name}/"
logging_dir = f"{rootdir}/runs/{run_name}/"
model_output_dir = os.path.join(training_output_dir, "models/")

# Refuse to overwrite a final model that was already saved for this run name.
# FileExistsError is an Exception subclass, so existing handlers still catch it.
model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
if os.path.isfile(model_output_file):
    raise FileExistsError("Model already saved to this directory.")

# Create the output directories, including missing parents such as
# {rootdir}/models. os.makedirs avoids interpolating paths into a shell
# command. It also raises on real failures such as permissions, whereas the
# ignored exit code of `mkdir` hid them.
os.makedirs(training_output_dir, exist_ok=True)
os.makedirs(model_output_dir, exist_ok=True)
|
|
|
|
|
|
|
# Load the token vocabulary (presumably token -> id; it is used below via
# .get("<pad>") and len()). The path is relative to the working directory.
# NOTE: pickle.load can execute arbitrary code — only load a trusted file.
with open("token_dictionary.pkl", mode="rb") as token_file:
    token_dictionary = pickle.load(token_file)
|
|
|
|
|
# Assemble the BERT configuration from the architecture hyperparameters.
# Vocabulary size and padding id come from the loaded token dictionary;
# pad_token_id is None if "<pad>" is absent.
config = BertConfig(
    hidden_size=num_embed_dim,
    num_hidden_layers=num_layers,
    initializer_range=initializer_range,
    layer_norm_eps=layer_norm_eps,
    attention_probs_dropout_prob=attention_probs_dropout_prob,
    hidden_dropout_prob=hidden_dropout_prob,
    intermediate_size=intermed_size,
    hidden_act=activ_fn,
    max_position_embeddings=max_input_size,
    model_type=model_type,
    num_attention_heads=num_attn_heads,
    pad_token_id=token_dictionary.get("<pad>"),
    vocab_size=len(token_dictionary),
)

# Freshly initialised masked-LM model, switched into training mode.
model = BertForMaskedLM(config).train()
|
|
|
|
|
# Hugging Face training configuration for pretraining (no evaluation pass).
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    # Batch examples of similar length together, reading lengths from the
    # dataset's "length" column.
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    # Checkpoint interval as a plain int. The previous np.floor(...) produced
    # an np.float64. Floor division yields the same value, and max(1, ...)
    # avoids a zero interval for tiny datasets.
    # NOTE(review): only the per-device batch size is divided out (num_gpus is
    # not), so the number of checkpoints per epoch depends on the GPU count —
    # confirm this cadence is intended.
    "save_steps": max(1, num_examples // (geneformer_batch_size * 8)),
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}
training_args = TrainingArguments(**training_args)
|
|
|
print("Starting training.")

# Pretraining corpus, plus a file of per-example lengths (presumably
# precomputed and sorted, per its name). Both paths are relative to the
# working directory.
corpus = load_from_disk("genecorpus_30M_2048.dataset")
trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=corpus,
    example_lengths_file="genecorpus_30M_2048_sorted_lengths.pkl",
    token_dictionary=token_dictionary,
)
trainer.train()

# Persist the final model into the dedicated model directory.
trainer.save_model(model_output_dir)
|
|