import random

from .data import get_data_loader, preload_and_process_data
from .imports import *  # noqa: F403 -- expected to supply np, torch, os, json, functools, and optuna used below
from .model import GeneformerMultiTask
from .train import objective, train_model
from .utils import save_model


def set_seed(seed):
    """Seed every RNG in use so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def run_manual_tuning(config):
    """Train once with user-specified hyperparameters and save the result."""
    set_seed(config["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    print("\nManual hyperparameters being used:")
    for key, value in config["manual_hyperparameters"].items():
        print(f"{key}: {value}")
    print()

    # Promote the manual hyperparameters to top-level config entries so that
    # train_model reads them the same way it would read tuned values.
    for key, value in config["manual_hyperparameters"].items():
        config[key] = value

    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )

    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(trained_model, model_save_directory)

    # Persist the hyperparameters next to the saved model for later reference.
    hyperparams_to_save = {
        **config["manual_hyperparameters"],
        "dropout_rate": config["dropout_rate"],
        "use_task_weights": config["use_task_weights"],
        "task_weights": config["task_weights"],
        "max_layers_to_freeze": config["max_layers_to_freeze"],
        "use_attention_pooling": config["use_attention_pooling"],
    }
    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w") as f:
        json.dump(hyperparams_to_save, f)
    print(f"Manual hyperparameters saved to {hyperparams_path}")

    return val_loss


def run_optuna_study(config):
    """Run an Optuna hyperparameter search (or a single manual run) and save the best model."""
    set_seed(config["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    if config["use_manual_hyperparameters"]:
        # Single training run with fixed hyperparameters; no search is performed.
        train_model(
            config,
            device,
            train_loader,
            val_loader,
            train_cell_id_mapping,
            val_cell_id_mapping,
            num_labels_list,
        )
    else:
        # Bind the data loaders and config so Optuna only needs to pass the trial.
        objective_with_config_and_data = functools.partial(
            objective,
            train_loader=train_loader,
            val_loader=val_loader,
            train_cell_id_mapping=train_cell_id_mapping,
            val_cell_id_mapping=val_cell_id_mapping,
            num_labels_list=num_labels_list,
            config=config,
            device=device,
        )

        study = optuna.create_study(
            direction="minimize",  # lower validation loss is better
            study_name=config["study_name"],
            load_if_exists=True,
        )
        study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

        best_params = study.best_trial.params
        best_task_weights = study.best_trial.user_attrs["task_weights"]
        print("Saving the best model and its hyperparameters...")

        best_model = GeneformerMultiTask(
            config["pretrained_path"],
            num_labels_list,
            dropout_rate=best_params["dropout_rate"],
            use_task_weights=config["use_task_weights"],
            task_weights=best_task_weights,
        )

        # Strip any DataParallel "module." prefix so the keys match the
        # single-device model before loading.
        best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
        best_model_state_dict = {
            k.replace("module.", ""): v for k, v in best_model_state_dict.items()
        }
        best_model.load_state_dict(best_model_state_dict, strict=False)

        model_save_directory = os.path.join(
            config["model_save_path"], "GeneformerMultiTask"
        )
        save_model(best_model, model_save_directory)

        hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
        with open(hyperparams_path, "w") as f:
            json.dump({**best_params, "task_weights": best_task_weights}, f)
        print(f"Best hyperparameters and task weights saved to {hyperparams_path}")