import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# The wildcard import is expected to supply the remaining names used in this
# module (nn, AdamW, get_linear_schedule_with_warmup,
# get_cosine_schedule_with_warmup, optuna).
from .imports import *
from .model import GeneformerMultiTask
from .utils import calculate_task_specific_metrics

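# Illustrative sketch of the config dict this module expects. The keys are
# exactly those accessed below; the values are placeholders, not
# recommendations.
#
# config = {
#     "seed": 42,
#     "pretrained_path": "/path/to/pretrained/Geneformer",
#     "task_names": ["task1", "task2"],
#     "epochs": 1,
#     "learning_rate": 5e-5,
#     "warmup_ratio": 0.05,
#     "weight_decay": 0.01,
#     "dropout_rate": 0.1,
#     "lr_scheduler_type": "cosine",  # or "linear"
#     "use_attention_pooling": True,
#     "use_task_weights": True,
#     "task_weights": [1.0, 1.0],
#     "max_layers_to_freeze": 2,  # or {"min": ..., "max": ...} when tuning
#     "use_data_parallel": False,
#     "gradient_clipping": True,
#     "max_grad_norm": 1.0,
#     "results_dir": "/path/to/results",
#     "tensorboard_log_dir": "/path/to/tb_logs",
#     "use_wandb": False,
#     "wandb_project": "my_project",
#     "hyperparameters": {...},  # search space consumed by objective()
# }
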
def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade some throughput for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def initialize_wandb(config):
    if config.get("use_wandb", False):
        import wandb

        wandb.init(project=config["wandb_project"], config=config)
        print("Weights & Biases (wandb) initialized and will be used for logging.")
    else:
        print(
            "Weights & Biases (wandb) is not enabled. Logging will use other methods."
        )

def create_model(config, num_labels_list, device):
    model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=config["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=config["task_weights"],
        max_layers_to_freeze=config["max_layers_to_freeze"],
        use_attention_pooling=config["use_attention_pooling"],
    )
    if config["use_data_parallel"]:
        model = nn.DataParallel(model)
    return model.to(device)

def setup_optimizer_and_scheduler(model, config, total_steps):
    optimizer = AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    warmup_steps = int(config["warmup_ratio"] * total_steps)

    if config["lr_scheduler_type"] == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )
    elif config["lr_scheduler_type"] == "cosine":
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            num_cycles=0.5,
        )
    else:
        # Fail fast instead of hitting an UnboundLocalError on `scheduler` below.
        raise ValueError(
            f"Unsupported lr_scheduler_type: {config['lr_scheduler_type']!r}. "
            "Expected 'linear' or 'cosine'."
        )

    return optimizer, scheduler

def train_epoch(
    model, train_loader, optimizer, scheduler, device, config, writer, epoch
):
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, batch in enumerate(progress_bar):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = [
            batch["labels"][task_name].to(device) for task_name in config["task_names"]
        ]

        loss, _, _ = model(input_ids, attention_mask, labels)
        loss.backward()

        if config["gradient_clipping"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])

        optimizer.step()
        scheduler.step()

        # Global step = batches seen across all epochs so far.
        writer.add_scalar(
            "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log({"Training Loss": loss.item()})

        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    # Note: this is the loss of the final batch, not an epoch average.
    return loss.item()

def validate_model(model, val_loader, device, config):
    model.eval()
    val_loss = 0.0
    task_true_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_probs = {task_name: [] for task_name in config["task_names"]}

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = [
                batch["labels"][task_name].to(device)
                for task_name in config["task_names"]
            ]
            # logits is a list with one per-class score tensor per task.
            loss, logits, _ = model(input_ids, attention_mask, labels)
            val_loss += loss.item()

            for sample_idx in range(len(batch["input_ids"])):
                for i, task_name in enumerate(config["task_names"]):
                    true_label = batch["labels"][task_name][sample_idx].item()
                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                    pred_prob = (
                        torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
                    )
                    task_true_labels[task_name].append(true_label)
                    task_pred_labels[task_name].append(pred_label)
                    task_pred_probs[task_name].append(pred_prob)

    val_loss /= len(val_loader)
    return val_loss, task_true_labels, task_pred_labels, task_pred_probs

def log_metrics(task_metrics, val_loss, config, writer, epochs):
    for task_name, metrics in task_metrics.items():
        print(
            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log(
                {
                    f"{task_name} Validation F1 Macro": metrics["f1"],
                    f"{task_name} Validation Accuracy": metrics["accuracy"],
                }
            )

    writer.add_scalar("Validation Loss", val_loss, epochs)
    for task_name, metrics in task_metrics.items():
        writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
        writer.add_scalar(
            f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
        )

def save_validation_predictions(
    val_cell_id_mapping,
    task_true_labels,
    task_pred_labels,
    task_pred_probs,
    config,
    trial_number=None,
):
    if trial_number is not None:
        trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
        os.makedirs(trial_results_dir, exist_ok=True)
        val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
    else:
        val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")

    rows = []
    for sample_idx in range(len(val_cell_id_mapping)):
        row = {"Cell ID": val_cell_id_mapping[sample_idx]}
        for task_name in config["task_names"]:
            row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
            row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(val_preds_file, index=False)
    print(f"Validation predictions saved to {val_preds_file}")

def train_model(
    config,
    device,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
):
    set_seed(config["seed"])
    initialize_wandb(config)

    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
    writer = SummaryWriter(log_dir=log_dir)

    epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
    for epoch in epoch_progress:
        last_loss = train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )
        epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})

    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
        model, val_loader, device, config
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)

    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
    writer.close()

    save_validation_predictions(
        val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
    )

    if config.get("use_wandb", False):
        import wandb

        wandb.finish()

    print(f"\nFinal Validation Loss: {val_loss:.4f}")
    return val_loss, model

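# A minimal usage sketch for train_model (assumes DataLoaders that yield dicts
# with "input_ids", "attention_mask", and per-task "labels", as consumed above;
# all names here are illustrative):
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# val_loss, model = train_model(
#     config,
#     device,
#     train_loader,
#     val_loader,
#     train_cell_id_mapping,
#     val_cell_id_mapping,
#     num_labels_list,
# )
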
def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    set_seed(config["seed"])
    initialize_wandb(config)

    # Sample this trial's hyperparameters from the configured search space.
    config["learning_rate"] = trial.suggest_float(
        "learning_rate",
        config["hyperparameters"]["learning_rate"]["low"],
        config["hyperparameters"]["learning_rate"]["high"],
        log=config["hyperparameters"]["learning_rate"]["log"],
    )
    config["warmup_ratio"] = trial.suggest_float(
        "warmup_ratio",
        config["hyperparameters"]["warmup_ratio"]["low"],
        config["hyperparameters"]["warmup_ratio"]["high"],
    )
    config["weight_decay"] = trial.suggest_float(
        "weight_decay",
        config["hyperparameters"]["weight_decay"]["low"],
        config["hyperparameters"]["weight_decay"]["high"],
    )
    config["dropout_rate"] = trial.suggest_float(
        "dropout_rate",
        config["hyperparameters"]["dropout_rate"]["low"],
        config["hyperparameters"]["dropout_rate"]["high"],
    )
    config["lr_scheduler_type"] = trial.suggest_categorical(
        "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
    )
    config["use_attention_pooling"] = trial.suggest_categorical(
        "use_attention_pooling", [True, False]
    )

    if config["use_task_weights"]:
        # Sample one weight per task, then normalize so the weights sum to 1.
        config["task_weights"] = [
            trial.suggest_float(
                f"task_weight_{i}",
                config["hyperparameters"]["task_weights"]["low"],
                config["hyperparameters"]["task_weights"]["high"],
            )
            for i in range(len(num_labels_list))
        ]
        weight_sum = sum(config["task_weights"])
        config["task_weights"] = [
            weight / weight_sum for weight in config["task_weights"]
        ]
    else:
        config["task_weights"] = None

    if isinstance(config["max_layers_to_freeze"], dict):
        config["max_layers_to_freeze"] = trial.suggest_int(
            "max_layers_to_freeze",
            config["max_layers_to_freeze"]["min"],
            config["max_layers_to_freeze"]["max"],
        )
    elif isinstance(config["max_layers_to_freeze"], int):
        # A fixed value was supplied; nothing to sample.
        pass
    else:
        raise ValueError("Invalid type for max_layers_to_freeze. Expected dict or int.")

    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
    writer = SummaryWriter(log_dir=log_dir)

    for epoch in range(config["epochs"]):
        train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )

    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
        model, val_loader, device, config
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)

    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
    writer.close()

    save_validation_predictions(
        val_cell_id_mapping,
        task_true_labels,
        task_pred_labels,
        task_pred_probs,
        config,
        trial.number,
    )

    # Stash the trained weights on the trial so the best model can be
    # recovered after the study finishes.
    trial.set_user_attr("model_state_dict", model.state_dict())
    trial.set_user_attr("task_weights", config["task_weights"])

    trial.report(val_loss, config["epochs"])

    if trial.should_prune():
        raise optuna.TrialPruned()

    if config.get("use_wandb", False):
        import wandb

        wandb.log(
            {
                "trial_number": trial.number,
                "val_loss": val_loss,
                **{
                    f"{task_name}_f1": metrics["f1"]
                    for task_name, metrics in task_metrics.items()
                },
                **{
                    f"{task_name}_accuracy": metrics["accuracy"]
                    for task_name, metrics in task_metrics.items()
                },
                **{
                    k: v
                    for k, v in config.items()
                    if k
                    in [
                        "learning_rate",
                        "warmup_ratio",
                        "weight_decay",
                        "dropout_rate",
                        "lr_scheduler_type",
                        "use_attention_pooling",
                        "max_layers_to_freeze",
                    ]
                },
            }
        )
        wandb.finish()

    return val_loss
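
# A minimal sketch of driving the Optuna search with objective() (the study
# settings and n_trials are illustrative, not taken from this module):
#
# from functools import partial
#
# study = optuna.create_study(direction="minimize")
# study.optimize(
#     partial(
#         objective,
#         train_loader=train_loader,
#         val_loader=val_loader,
#         train_cell_id_mapping=train_cell_id_mapping,
#         val_cell_id_mapping=val_cell_id_mapping,
#         num_labels_list=num_labels_list,
#         config=config,
#         device=device,
#     ),
#     n_trials=20,
# )
# best_trial = study.best_trial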