File size: 5,453 Bytes
f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import random
from .data import get_data_loader, preload_and_process_data
from .imports import *
from .model import GeneformerMultiTask
from .train import objective, train_model
from .utils import save_model
def set_seed(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices) with ``seed``, then configures cuDNN to use
    deterministic algorithms with benchmark autotuning turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def run_manual_tuning(config):
    """Train one model with user-supplied hyperparameters and save it.

    The entries of ``config["manual_hyperparameters"]`` are copied into
    ``config`` (mutating it in place) before anything else happens. Seeding,
    data preparation, the data loaders' batch size and training therefore all
    see the effective values. The trained model is saved with ``save_model``
    to ``<config["model_save_path"]>/GeneformerMultiTask``, and a
    ``hyperparameters.json`` describing the run is written next to it.

    Args:
        config: Run configuration dict. Keys read directly here are
            ``"manual_hyperparameters"``, ``"seed"``, ``"batch_size"``,
            ``"model_save_path"``, ``"dropout_rate"``, ``"use_task_weights"``,
            ``"task_weights"``, ``"max_layers_to_freeze"`` and
            ``"use_attention_pooling"``. The whole dict is also passed to
            ``preload_and_process_data`` and ``train_model``.

    Returns:
        The validation loss reported by ``train_model``.
    """
    manual_hyperparameters = config["manual_hyperparameters"]

    # Print the manual hyperparameters being used
    print("\nManual hyperparameters being used:")
    for key, value in manual_hyperparameters.items():
        print(f"{key}: {value}")
    print()  # Add an empty line for better readability

    # Apply the overrides before seeding and data loading. Applying them after
    # the loaders were built meant an overridden "batch_size" (or any key read
    # by preload_and_process_data) was ignored while still being reported in
    # hyperparameters.json.
    config.update(manual_hyperparameters)

    # Set seed for reproducibility
    set_seed(config["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    # Train the model
    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )
    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    # Save the trained model
    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(trained_model, model_save_directory)
    # Ensure the directory exists before writing the JSON next to the model;
    # a no-op when save_model already created it.
    os.makedirs(model_save_directory, exist_ok=True)

    # Save the hyperparameters
    hyperparams_to_save = {
        **manual_hyperparameters,
        "dropout_rate": config["dropout_rate"],
        "use_task_weights": config["use_task_weights"],
        "task_weights": config["task_weights"],
        "max_layers_to_freeze": config["max_layers_to_freeze"],
        "use_attention_pooling": config["use_attention_pooling"],
    }
    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w", encoding="utf-8") as f:
        json.dump(hyperparams_to_save, f)
    print(f"Manual hyperparameters saved to {hyperparams_path}")
    return val_loss
def run_optuna_study(config):
    """Run an Optuna hyperparameter search and save the best model found.

    When ``config["use_manual_hyperparameters"]`` is truthy, no search is
    run. The call is delegated to ``run_manual_tuning``, which applies
    ``config["manual_hyperparameters"]``, trains once and saves the result.

    Otherwise a study that minimizes validation loss runs for
    ``config["n_trials"]`` trials of ``objective``. The best trial's weights
    are loaded into a fresh ``GeneformerMultiTask``, which is saved to
    ``<config["model_save_path"]>/GeneformerMultiTask``. A
    ``hyperparameters.json`` holding the best parameters and task weights is
    written next to it.

    Args:
        config: Run configuration dict. Keys read directly here include
            ``"use_manual_hyperparameters"``, ``"seed"``, ``"batch_size"``,
            ``"study_name"``, ``"n_trials"``, ``"pretrained_path"``,
            ``"use_task_weights"`` and ``"model_save_path"``.

    Returns:
        The validation loss of the manual run, or the best trial's objective
        value. Callers that ignore the return value are unaffected.
    """
    if config["use_manual_hyperparameters"]:
        # The old inline branch trained without applying the manual
        # overrides and then discarded both the model and its loss;
        # run_manual_tuning applies them, saves the model and returns the loss.
        return run_manual_tuning(config)

    # Set seed for reproducibility
    set_seed(config["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    objective_with_config_and_data = functools.partial(
        objective,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cell_id_mapping=train_cell_id_mapping,
        val_cell_id_mapping=val_cell_id_mapping,
        num_labels_list=num_labels_list,
        config=config,
        device=device,
    )
    # NOTE: no ``storage`` is passed, so Optuna keeps the study in memory and
    # ``load_if_exists`` only matters once a persistent storage URL is given.
    study = optuna.create_study(
        direction="minimize",  # Minimize validation loss
        study_name=config["study_name"],
        load_if_exists=True,
    )
    study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

    best_trial = study.best_trial
    best_params = best_trial.params
    best_task_weights = best_trial.user_attrs["task_weights"]

    print("Saving the best model and its hyperparameters...")
    # NOTE(review): only dropout_rate / use_task_weights / task_weights are
    # forwarded. If the study also tunes architecture-affecting settings, the
    # rebuilt model may not match the trained one; the key report below will
    # reveal that. Confirm against GeneformerMultiTask's constructor.
    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_params["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=best_task_weights,
    )

    # Models wrapped in (Distributed)DataParallel store keys as
    # "module.<name>". Strip only that leading prefix: str.replace would also
    # mangle keys that merely contain "module." (e.g. "x.submodule.w").
    prefix = "module."
    best_model_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in best_trial.user_attrs["model_state_dict"].items()
    }
    # strict=False tolerates mismatched keys; report them rather than letting
    # weights be dropped silently.
    incompatible = best_model.load_state_dict(best_model_state_dict, strict=False)
    missing_keys = list(getattr(incompatible, "missing_keys", []) or [])
    unexpected_keys = list(getattr(incompatible, "unexpected_keys", []) or [])
    if missing_keys or unexpected_keys:
        print(
            "Warning: state dict mismatch while restoring the best model "
            f"(missing keys: {missing_keys}, unexpected keys: {unexpected_keys})"
        )

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(best_model, model_save_directory)
    # Ensure the directory exists before writing the JSON next to the model;
    # a no-op when save_model already created it.
    os.makedirs(model_save_directory, exist_ok=True)

    # Additionally, save the best hyperparameters and task weights
    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w", encoding="utf-8") as f:
        json.dump({**best_params, "task_weights": best_task_weights}, f)
    print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
    return best_trial.value
|