|
import json |
|
import pickle |
|
import os |
|
from types import SimpleNamespace as sn |
|
import time |
|
from os.path import join |
|
import copy |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torch.distributed as dist |
|
from torch.distributed.algorithms.join import Join |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.optim import AdamW |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
from torch.utils.data import Subset |
|
from torch.utils.tensorboard import SummaryWriter |
|
from torch_geometric.loader import DataLoader |
|
from torch.utils.data import DataLoader as TorchDataLoader |
|
import loralib as lora |
|
import gpytorch |
|
import data |
|
import utils.configs |
|
from model.module.utils import loss_fn_mapping |
|
from model.model import create_model, create_model_and_load |
|
from torch import _dynamo |
|
_dynamo.config.suppress_errors = True |
|
|
|
class PreMode_trainer(object): |
|
""" |
|
A wrapper for dataloader, summary writer, optimizer, scheduler |
|
""" |
|
|
|
def __init__(self, hparams, model, stage: str = "train", dataset=None, device_id=None): |
|
super(PreMode_trainer, self).__init__() |
|
if isinstance(hparams, dict): |
|
hparams = sn(**hparams) |
|
self.hparams = hparams |
|
|
|
|
|
self.device_id = device_id |
|
if device_id is not None and torch.cuda.is_available(): |
|
self.device = f"cuda:{device_id}" |
|
else: |
|
self.device = "cpu" |
|
|
|
self.model = model.to(self.device) |
|
|
|
|
|
self.dataset = dataset |
|
self.train_dataset = None |
|
self.val_dataset = None |
|
self.test_dataset = None |
|
self.train_dataloader = None |
|
self.val_dataloader = None |
|
self.test_dataloader = None |
|
self.split_fn = self.hparams.data_split_fn |
|
self.setup_dataloaders(stage, self.split_fn) |
|
        print(f'Finished setting up dataloaders for rank {self.device_id}')
|
if self.train_dataloader is not None: |
|
self.batchs_per_epoch = len(self.train_dataloader) |
|
self.num_data = len(self.train_dataloader.dataset) |
|
else: |
|
self.batchs_per_epoch = 0 |
|
self.num_data = len(self.test_dataloader.dataset) |
|
self.reset_train_dataloader_each_epoch = self.hparams.reset_train_dataloader_each_epoch and hparams.data_split_fn != "_by_anno" |
|
self.reset_train_dataloader_each_epoch_seed = self.hparams.reset_train_dataloader_each_epoch_seed |
|
self.train_iterator = None |
|
self.val_iterator = None |
|
self.test_iterator = None |
|
|
|
|
|
if self.hparams.loss_fn == "weighted_combined_loss" or "weighted_loss" in self.hparams.loss_fn: |
|
label_counts = self.dataset.get_label_counts() |
|
if len(label_counts) == 4: |
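                # Four label classes: build inverse-frequency class weights for two
                # binary sub-tasks (class 1 vs. the rest, and class 0 vs. class 2),
                # plus a task weight derived from how often classes 0 and 2 occur.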
|
|
|
|
|
total_count_1 = label_counts.sum() |
|
task_weight = total_count_1 / (label_counts[0] + label_counts[2]) |
|
total_count_2 = total_count_1 - label_counts[3] - label_counts[0] |
|
if label_counts[1] != 0: |
|
weight_1 = torch.tensor([total_count_1 / label_counts[1] / 2, |
|
total_count_1 / (total_count_1 - label_counts[1]) / 2], |
|
dtype=torch.float32, device=self.device) |
|
weight_2 = torch.tensor([total_count_2 / label_counts[0] / 2, |
|
total_count_2 / label_counts[2] / 2], |
|
dtype=torch.float32, device=self.device) |
|
else: |
|
weight_1 = torch.ones(2, dtype=torch.float32, device=self.device) |
|
weight_2 = torch.tensor([total_count_2 / label_counts[0] / 2, |
|
total_count_2 / label_counts[2] / 2], |
|
dtype=torch.float32, device=self.device) |
|
elif len(label_counts) == 2: |
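                # Two label classes: a single binary task with inverse-frequency
                # class weights; the second weight vector is zeroed out.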
|
|
|
task_weight = 0 |
|
total_count_1 = label_counts.sum() |
|
if label_counts[0] != 0: |
|
weight_1 = torch.tensor([total_count_1 / label_counts[0] / 2, |
|
total_count_1 / label_counts[1] / 2], |
|
dtype=torch.float32, device=self.device) |
|
weight_2 = torch.zeros(2, dtype=torch.float32, device=self.device) |
|
else: |
|
weight_1 = torch.ones(2, dtype=torch.float32, device=self.device) |
|
weight_2 = torch.zeros(2, dtype=torch.float32, device=self.device) |
|
else: |
|
raise ValueError("The number of labels should be 2 or 4.") |
|
            weight = torch.cat([weight_1, weight_2])
|
print(f"set up weighted loss function with weight: {weight}") |
|
self.loss_fn = loss_fn_mapping[self.hparams.loss_fn](weight=weight, task_weight=task_weight) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif self.hparams.loss_fn == "GP_loss": |
|
self.loss_fn = gpytorch.mlls.VariationalELBO(self.model.output_model.likelihood, |
|
self.model.output_model.output_network, |
|
num_data=self.num_data) |
|
self.hparams.y_weight = -1 |
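            # VariationalELBO is a log-likelihood to be maximized, so y_weight is set
            # to -1: minimizing loss_y * y_weight then amounts to ELBO ascent.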
|
else: |
|
self.loss_fn = loss_fn_mapping[self.hparams.loss_fn] |
|
|
|
|
|
if self.hparams.freeze_representation: |
|
for param in self.model.representation_model.parameters(): |
|
param.requires_grad = False |
|
|
|
self.model.representation_model.eval() |
|
if self.hparams.freeze_representation_but_attention: |
|
for param in self.model.representation_model.parameters(): |
|
param.requires_grad = False |
|
|
|
self.model.representation_model.eval() |
|
for param in self.model.representation_model.attention_layers.parameters(): |
|
param.requires_grad = True |
|
if self.hparams.freeze_representation_but_gru: |
|
for param in self.model.representation_model.parameters(): |
|
param.requires_grad = False |
|
|
|
self.model.representation_model.eval() |
|
for layer in self.model.representation_model.attention_layers: |
|
assert layer.gru is not None |
|
for param in layer.gru.parameters(): |
|
param.requires_grad = True |
|
if self.hparams.use_lora is not None: |
|
self.model.eval() |
|
lora.mark_only_lora_as_trainable(model) |
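            # mark_only_lora_as_trainable freezes everything except the LoRA adapters;
            # re-enable gradients for the task-specific output head below.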
|
|
|
if isinstance(self.model, DDP): |
|
if self.hparams.loss_fn == "weighted_combined_loss" or self.hparams.loss_fn == "combined_loss": |
|
self.model.module.output_model.output_network.requires_grad_(True) |
|
elif self.hparams.loss_fn == "weighted_loss": |
|
self.model.module.output_model.requires_grad_(True) |
|
elif self.hparams.model == "lora-esm": |
|
self.model.module.output_model.requires_grad_(True) |
|
else: |
|
if self.hparams.loss_fn == "weighted_combined_loss" or self.hparams.loss_fn == "combined_loss": |
|
self.model.output_model.output_network.requires_grad_(True) |
|
elif self.hparams.loss_fn == "weighted_loss": |
|
self.model.output_model.requires_grad_(True) |
|
elif self.hparams.model == "lora-esm": |
|
self.model.output_model.requires_grad_(True) |
|
self.use_lora = True |
|
else: |
|
self.use_lora = False |
|
|
|
|
|
|
|
self.losses = None |
|
self._reset_losses_dict() |
|
|
|
|
|
self.predictions = None |
|
self._reset_predictions_dict() |
|
|
|
|
|
self.global_step = 0 |
|
self.current_epoch = 0 |
|
|
|
|
|
self.updated = True |
|
self.optimizer = None |
|
self.scheduler = None |
|
self.lr_scheduler = None |
|
self.configure_optimizers() |
|
|
|
|
|
self.contrastive_loss = loss_fn_mapping[self.hparams.contrastive_loss_fn] if self.hparams.contrastive_loss_fn is not None else None |
|
|
|
|
|
if stage == "train": |
|
self.writer = SummaryWriter(log_dir=f'{self.hparams.log_dir}/log/') |
|
|
|
def setup_dataloaders(self, stage: str = 'train', split_fn="_by_uniprot_id"): |
|
if self.dataset is None: |
|
            self.dataset = getattr(data, self.hparams.dataset)(
|
data_file=self.hparams.data_file_train, |
|
data_type=self.hparams.data_type, |
|
radius=self.hparams.radius, |
|
max_neighbors=self.hparams.max_num_neighbors, |
|
loop=self.hparams.loop, |
|
) |
|
if self.hparams.dataset.startswith("FullGraph"): |
|
data_loader_fn = TorchDataLoader |
|
else: |
|
data_loader_fn = DataLoader |
|
if stage == 'train': |
|
|
|
if self.hparams.val_size > 0: |
|
idx_train, idx_val = getattr(utils.configs, "make_splits_train_val" + split_fn)( |
|
self.dataset, |
|
self.hparams.train_size, |
|
self.hparams.val_size, |
|
self.hparams.seed, |
|
self.hparams.batch_size, |
|
join(self.hparams.log_dir, f"splits.{self.device_id}.npz"), |
|
) |
|
print(f"train {len(idx_train)}, val {len(idx_val)}") |
|
if split_fn == "_by_anno": |
|
self.val_dataset = copy.deepcopy(self.dataset).subset(idx_val) |
|
self.train_dataset = self.dataset.subset(idx_train) |
|
else: |
|
self.val_dataset = Subset(self.dataset, idx_val) |
|
self.train_dataset = Subset(self.dataset, idx_train) |
|
self.idx_val = idx_val |
|
self.idx_train = idx_train |
|
else: |
|
self.train_dataset = self.dataset |
|
self.val_dataset = None |
|
self.idx_train = np.arange(len(self.dataset)) |
|
self.idx_val = None |
|
dataloader_args = { |
|
"batch_size": self.hparams.batch_size, |
|
"num_workers": min(20, self.hparams.num_workers), |
|
"pin_memory": True, |
|
"shuffle": split_fn=='_by_anno' |
|
} |
|
if self.hparams.num_workers == 0: |
|
dataloader_args['pin_memory_device'] = 'cpu' |
|
self.train_dataloader = data_loader_fn( |
|
dataset=self.train_dataset, |
|
**dataloader_args, |
|
) |
|
if self.val_dataset is not None: |
|
dataloader_args['shuffle'] = False |
|
dataloader_args["num_workers"] = 0 |
|
dataloader_args["pin_memory"] = False |
|
self.val_dataloader = data_loader_fn( |
|
dataset=self.val_dataset, |
|
**dataloader_args, |
|
) |
|
else: |
|
self.val_dataloader = None |
|
elif stage == 'test': |
|
|
|
self.test_dataset = self.dataset |
|
dataloader_args = { |
|
"batch_size": self.hparams.batch_size, |
|
"num_workers": 0, |
|
"pin_memory": False, |
|
"shuffle": False |
|
} |
|
self.test_dataloader = data_loader_fn( |
|
dataset=self.test_dataset, |
|
**dataloader_args, |
|
) |
|
elif stage == 'all': |
|
|
|
idx_train, idx_val, idx_test = getattr(utils.configs, "make_splits_train_val_test" + split_fn)( |
|
self.dataset, |
|
self.hparams.train_size, |
|
self.hparams.val_size, |
|
self.hparams.test_size, |
|
0, |
|
self.hparams.batch_size * self.hparams.num_workers, |
|
join(self.hparams.log_dir, "splits.npz"), |
|
self.hparams.splits, |
|
) |
|
print(f"train {len(idx_train)}, val {len(idx_val)}, test {len(idx_test)}") |
|
|
|
self.val_dataset = copy.deepcopy(self.dataset).subset(idx_val) |
|
self.idx_val = idx_val |
|
self.test_dataset = copy.deepcopy(self.dataset).subset(idx_test) |
|
self.idx_test = idx_test |
|
self.train_dataset = self.dataset.subset(idx_train) |
|
self.idx_train = idx_train |
|
|
|
self.train_dataloader = data_loader_fn( |
|
dataset=self.train_dataset, |
|
batch_size=self.hparams.batch_size, |
|
num_workers=0, |
|
pin_memory=True, |
|
pin_memory_device='cpu', |
|
shuffle=False, |
|
) |
|
self.val_dataloader = data_loader_fn( |
|
dataset=self.val_dataset, |
|
batch_size=self.hparams.batch_size, |
|
num_workers=0, |
|
pin_memory=True, |
|
pin_memory_device='cpu', |
|
shuffle=False, |
|
) |
|
self.test_dataloader = data_loader_fn( |
|
dataset=self.test_dataset, |
|
batch_size=self.hparams.batch_size, |
|
num_workers=0, |
|
pin_memory=True, |
|
pin_memory_device='cpu', |
|
shuffle=False, |
|
) |
|
else: |
|
raise ValueError(f"stage {stage} not supported") |
|
|
|
def configure_optimizers(self): |
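        """Set up AdamW over the trainable parameters and a ReduceLROnPlateau scheduler."""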
|
|
|
self.optimizer = AdamW( |
|
filter(lambda p: p.requires_grad, self.model.parameters()), |
|
lr=float(self.hparams.lr), |
|
weight_decay=self.hparams.weight_decay, |
|
) |
|
self.scheduler = ReduceLROnPlateau( |
|
self.optimizer, |
|
"min", |
|
factor=self.hparams.lr_factor, |
|
patience=self.hparams.lr_patience, |
|
min_lr=float(self.hparams.lr_min), |
|
) |
|
self.lr_scheduler = { |
|
"scheduler": self.scheduler, |
|
"monitor": getattr(self.hparams, "lr_metric", "val_loss"), |
|
"interval": "epoch", |
|
"frequency": 1, |
|
} |
|
|
|
def forward(self, x, x_mask, x_alt, pos, batch=None, |
|
edge_index=None, edge_attr=None, |
|
edge_index_star=None, edge_attr_star=None, |
|
node_vec_attr=None, |
|
extra_args=None, |
|
return_attn=False): |
|
return self.model(x=x, |
|
x_mask=x_mask, |
|
x_alt=x_alt, |
|
pos=pos, |
|
batch=batch, |
|
edge_index=edge_index, |
|
edge_attr=edge_attr, |
|
edge_index_star=edge_index_star, |
|
edge_attr_star=edge_attr_star, |
|
node_vec_attr=node_vec_attr, |
|
extra_args=extra_args, |
|
return_attn=return_attn) |
|
|
|
def training_step(self): |
|
if self.train_iterator is None: |
|
raise ValueError("train_iterator is None, please call training_epoch_begin() first") |
|
batch = next(self.train_iterator) |
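        # Scale the loss for gradient accumulation: gradients from num_steps_update
        # batches are summed before optimizer_step() applies them.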
|
loss = self.step(batch, "train") / self.hparams.num_steps_update |
|
loss.backward() |
|
self.write_loss_log("train", loss) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.updated = False |
|
self.global_step += 1 |
|
return loss |
|
|
|
def validation_step(self): |
|
if self.val_iterator is None: |
|
raise ValueError("val_iterator is None, please call validation_epoch_begin() first") |
|
batch = next(self.val_iterator) |
|
with torch.no_grad(): |
|
loss = self.step(batch, "val") |
|
|
|
return loss |
|
|
|
def test_step(self): |
|
if self.test_iterator is None: |
|
raise ValueError("test_iterator is None, please call test_epoch_begin() first") |
|
batch = next(self.test_iterator) |
|
with torch.no_grad(): |
|
return self.step(batch, "test") |
|
|
|
def interpret_step(self, batch): |
|
with torch.no_grad(): |
|
return self.step(batch, "interpret") |
|
|
|
def step(self, batch, stage): |
|
with torch.set_grad_enabled(stage == "train"): |
|
if isinstance(batch, dict): |
|
extra_args = copy.deepcopy(batch) |
|
batch = sn(**batch) |
|
else: |
|
extra_args = batch.to_dict() |
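            # Drop the tensors that are passed to forward() explicitly; whatever
            # remains in extra_args is handed to the model as auxiliary inputs.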
|
|
|
for a in ('y', 'x', 'x_mask', 'x_alt', 'pos', 'batch', |
|
'edge_index', 'edge_attr', |
|
'edge_index_star', 'edge_attr_star', |
|
'node_vec_attr'): |
|
if a in extra_args: |
|
del extra_args[a] |
|
y, x_embed, attn_weight_layers = self.forward( |
|
x=batch.x.to(self.device, non_blocking=True), |
|
x_mask=batch.x_mask.to(self.device, non_blocking=True), |
|
x_alt=batch.x_alt.to(self.device, non_blocking=True), |
|
pos=batch.pos.to(self.device, non_blocking=True) if hasattr(batch, "pos") and batch.pos is not None else None, |
|
batch=batch.batch.to(self.device, non_blocking=True) if hasattr(batch, "batch") and batch.batch is not None else None, |
|
edge_index=batch.edge_index.to(self.device, non_blocking=True) if hasattr(batch, "edge_index") and batch.edge_index is not None else None, |
|
edge_index_star=batch.edge_index_star.to(self.device, non_blocking=True) if hasattr(batch, "edge_index_star") and batch.edge_index_star is not None else None, |
|
edge_attr=batch.edge_attr.to(self.device, non_blocking=True) if hasattr(batch, "edge_attr") and batch.edge_attr is not None else None, |
|
edge_attr_star=batch.edge_attr_star.to(self.device, non_blocking=True) if hasattr(batch, "edge_attr_star") and batch.edge_attr_star is not None else None, |
|
node_vec_attr=batch.node_vec_attr.to(self.device, non_blocking=True) if hasattr(batch, "node_vec_attr") and batch.node_vec_attr is not None else None, |
|
extra_args=extra_args, |
|
return_attn=stage == "interpret", |
|
) |
|
if stage == "test": |
|
if self.hparams.dataset.startswith("Mask"): |
|
|
|
self.predictions['y'].append(y[batch.x_mask == False].detach().cpu().numpy()) |
|
else: |
|
self.predictions['y'].append(y.detach().cpu().numpy()) |
|
loss_y = 0 |
|
|
|
if stage != "interpret": |
|
if hasattr(batch, 'y'): |
|
if batch.y.ndim == 1 and self.hparams.loss_fn != "cross_entropy": |
|
batch.y = batch.y.unsqueeze(1) |
|
|
|
|
|
if self.hparams.dataset.startswith("Mask"): |
|
y = y[batch.x_mask==False] |
|
batch.y = batch.y[batch.x_mask==False] |
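                    # map labels (assumed {-1, 1}) to {0, 1} for the GP likelihood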
|
if self.hparams.loss_fn == "GP_loss": |
|
batch.y = (batch.y + 1) / 2 |
|
if hasattr(batch, 'score_mask'): |
|
loss_y = self.loss_fn(input=y, |
|
target=batch.y.to(self.device, non_blocking=True), |
|
weight=batch.score_mask.to(self.device, non_blocking=True)) |
|
else: |
|
loss_y = self.loss_fn(y, batch.y.to(self.device, non_blocking=True)) |
|
if loss_y.ndim > 0: |
|
loss_y = loss_y.mean() |
|
if self.contrastive_loss is not None: |
|
loss_cont = self.contrastive_loss(x_embed, batch.y.to(self.device)) |
|
else: |
|
loss_cont = 0 |
|
|
|
if self.hparams.y_weight != 0 and stage != "interpret": |
|
self.losses[stage + "_y"].append(loss_y.detach().cpu() * self.hparams.y_weight) |
|
|
|
|
|
loss = loss_y * self.hparams.y_weight + loss_cont |
|
self.losses[stage].append(loss.detach().cpu()) |
|
return loss |
|
else: |
|
if self.hparams.loss_fn == "GP_loss": |
|
return self.model.output_model.likelihood(y).variance, self.model.output_model.likelihood(y).mean, x_embed, attn_weight_layers |
|
else: |
|
return None, y, x_embed, attn_weight_layers |
|
|
|
def optimizer_step(self, loss=None): |
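        # Linear learning-rate warm-up over the first lr_warmup_steps global steps,
        # then apply the accumulated gradients and reset them.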
|
|
|
if self.global_step < self.hparams.lr_warmup_steps: |
|
lr_scale = min( |
|
1.0, |
|
float(self.global_step + 1) |
|
/ float(self.hparams.lr_warmup_steps), |
|
) |
|
for pg in self.optimizer.param_groups: |
|
pg["lr"] = lr_scale * float(self.hparams.lr) |
|
|
|
self.optimizer.step() |
|
self.optimizer.zero_grad() |
|
self.updated = True |
|
|
|
def scheduler_step(self, val_loss): |
|
self.scheduler.step(val_loss) |
|
|
|
def training_epoch_begin(self): |
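        # Drop any open dataset environment/transaction handles (e.g. an LMDB env)
        # so dataloader workers re-open them, then start a fresh training iterator.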
|
if hasattr(self.dataset, 'env') and self.dataset.env is not None: |
|
self.dataset.env.close() |
|
self.dataset.env = None |
|
if hasattr(self.dataset, 'txn') and self.dataset.txn is not None: |
|
self.dataset.txn = None |
|
self.train_iterator = iter(self.train_dataloader) |
|
|
|
self.model.train() |
|
|
|
def training_epoch_end(self): |
|
self.train_iterator = None |
|
self._reset_losses_dict() |
|
self.current_epoch += 1 |
|
if self.reset_train_dataloader_each_epoch: |
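            # Re-shuffle the training indices with the split-specific reshuffle helper
            # and rebuild the training dataloader for the next epoch.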
|
idx_train = getattr(utils.configs, "reshuffle_train" + self.split_fn)(self.idx_train, self.hparams.batch_size, |
|
self.dataset, |
|
seed=self.current_epoch if self.reset_train_dataloader_each_epoch_seed else None) |
|
self.train_dataset = Subset(self.dataset, idx_train) |
|
dataloader_args = { |
|
"batch_size": self.hparams.batch_size, |
|
"num_workers": min(1, self.hparams.num_workers), |
|
"pin_memory": True, |
|
"shuffle": False |
|
} |
|
if self.hparams.num_workers == 0: |
|
dataloader_args['pin_memory_device'] = 'cpu' |
|
self.train_dataloader = DataLoader( |
|
dataset=self.train_dataset, |
|
**dataloader_args, |
|
) |
|
|
|
def validation_epoch_begin(self): |
|
if self.val_dataloader is None: |
|
self.val_iterator = iter(self.train_dataloader) |
|
else: |
|
self.val_iterator = iter(self.val_dataloader) |
|
|
|
self.model.eval() |
|
|
|
def validation_epoch_end(self, reset_train_loss=False): |
|
self.val_iterator = None |
|
|
|
result_dict = { |
|
"epoch": int(self.current_epoch), |
|
"lr": self.optimizer.param_groups[0]["lr"], |
|
"train_loss": torch.stack(self.losses["train"]).mean().item() if len(self.losses["train"]) > 0 else None, |
|
} |
|
if self.val_dataset is not None: |
|
result_dict["val_loss"] = torch.stack(self.losses["val"]).mean().item() if len(self.losses["val"]) > 0 else 0 |
|
self.write_loss_log("val", result_dict["val_loss"]) |
|
else: |
|
|
|
result_dict["val_loss"] = torch.stack(self.losses["train"]).mean().item() |
|
self.write_loss_log("val", torch.stack(self.losses["train"]).mean()) |
|
|
|
if len(self.losses["test"]) > 0: |
|
result_dict["test_loss"] = torch.stack(self.losses["test"]).mean().item() |
|
|
|
|
|
if len(self.losses["train_y"]) > 0: |
|
result_dict["train_loss_y"] = torch.stack(self.losses["train_y"]).mean().item() |
|
if self.val_dataset is not None: |
|
result_dict["val_loss_y"] = torch.stack(self.losses["val_y"]).mean().item() if len(self.losses["val_y"]) > 0 else 0 |
|
|
|
if len(self.losses["test"]) > 0: |
|
result_dict["test_loss_y"] = torch.stack( |
|
self.losses["test_y"] |
|
).mean().item() |
|
if reset_train_loss: |
|
self._reset_losses_dict() |
|
else: |
|
self._reset_val_losses_dict() |
|
|
|
self.model.train() |
|
return result_dict |
|
|
|
def testing_epoch_begin(self): |
|
self.test_iterator = iter(self.test_dataloader) |
|
|
|
self.model.eval() |
|
|
|
def testing_epoch_end(self): |
|
self.test_iterator = None |
|
|
|
result_dict = { |
|
"epoch": int(self.current_epoch), |
|
"lr": self.optimizer.param_groups[0]["lr"], |
|
"test_loss": torch.stack(self.losses["test"]).mean().item(), |
|
} |
|
|
|
if len(self.losses["test_y"]) > 0: |
|
if len(self.losses["test"]) > 0: |
|
result_dict["test_loss_y"] = torch.stack( |
|
self.losses["test_y"] |
|
).mean().item() |
|
self._reset_losses_dict() |
|
|
|
y_result = pd.DataFrame(np.concatenate(self.predictions['y'], axis=0), |
|
index=self.dataset.data.index) |
|
y_result.columns = [f'y.{i}' for i in y_result.columns] |
|
result_df = pd.concat( |
|
[self.dataset.data, |
|
y_result, |
|
], |
|
axis=1 |
|
) |
|
self._reset_predictions_dict() |
|
|
|
self.model.train() |
|
return result_dict, result_df |
|
|
|
def write_loss_log(self, stage, loss): |
|
if self.device_id is None: |
|
scalar_name = f"loss/{stage}" |
|
else: |
|
scalar_name = f"loss/ddp_rank.{self.device_id}.{stage}" |
|
self.writer.add_scalar(scalar_name, loss, self.global_step) |
|
if stage == "train" and self.device_id == 0: |
|
for tag, value in self.model.named_parameters(): |
|
tag = tag.replace('.', '/') |
|
self.writer.add_histogram('weights/'+tag, value.data.cpu().numpy(), self.global_step) |
|
try: |
|
|
|
if value.grad is not None: |
|
self.writer.add_histogram('grads/'+tag, value.grad.data.cpu().numpy(), self.global_step) |
|
                except Exception:
                    print(f"failed to add grad histogram for '{tag}' at step {self.global_step}")
|
|
|
def write_model(self, epoch=None, step=None, save_optimizer=False, optimizer_rank=None): |
|
if save_optimizer: |
|
assert optimizer_rank is not None |
|
if epoch is None: |
|
if step is None: |
|
model_save_file_name = f"{self.hparams.log_dir}/model.epoch.{self.current_epoch}.step.{self.global_step}.pt" |
|
if save_optimizer: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.epoch.{self.current_epoch}.step.{self.global_step}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.epoch.{self.current_epoch}.step.{self.global_step}.rank.{optimizer_rank}.pt" |
|
else: |
|
model_save_file_name = f"{self.hparams.log_dir}/model.step.{step}.pt" |
|
if save_optimizer: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.step.{step}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.step.{step}.rank.{optimizer_rank}.pt" |
|
else: |
|
if step is None: |
|
model_save_file_name = f"{self.hparams.log_dir}/model.epoch.{epoch}.pt" |
|
if save_optimizer: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.epoch.{epoch}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.epoch.{epoch}.rank.{optimizer_rank}.pt" |
|
else: |
|
model_save_file_name = f"{self.hparams.log_dir}/model.epoch.{epoch}.step.{step}.pt" |
|
if save_optimizer: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.epoch.{epoch}.step.{step}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.epoch.{epoch}.step.{step}.rank.{optimizer_rank}.pt" |
|
if isinstance(self.model, DDP): |
|
if self.use_lora: |
|
state_dic = lora.lora_state_dict(self.model.module) |
|
|
|
output_model_state_dic = self.model.module.output_model.state_dict() |
|
for key, value in output_model_state_dic.items(): |
|
state_dic[f"module.output_model.{key}"] = value |
|
torch.save(state_dic, model_save_file_name) |
|
else: |
|
torch.save(self.model.module.state_dict(), model_save_file_name) |
|
else: |
|
if self.use_lora: |
|
state_dic = lora.lora_state_dict(self.model) |
|
|
|
output_model_state_dic = self.model.output_model.output_network.state_dict() |
|
for key, value in output_model_state_dic.items(): |
|
state_dic[f"output_model.output_network.{key}"] = value |
|
torch.save(state_dic, model_save_file_name) |
|
else: |
|
torch.save(self.model.state_dict(), model_save_file_name) |
|
if save_optimizer: |
|
torch.save(self.optimizer.state_dict(), optimizer_save_file_name) |
|
torch.save(self.scheduler.state_dict(), scheduler_save_file_name) |
|
|
|
def write_optimizer(self, epoch=None, step=None, optimizer_rank=None): |
|
if epoch is None: |
|
if step is None: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.epoch.{self.current_epoch}.step.{self.global_step}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.epoch.{self.current_epoch}.step.{self.global_step}.rank.{optimizer_rank}.pt" |
|
else: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.step.{step}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.step.{step}.rank.{optimizer_rank}.pt" |
|
else: |
|
if step is None: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.epoch.{epoch}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.epoch.{epoch}.rank.{optimizer_rank}.pt" |
|
else: |
|
optimizer_save_file_name = f"{self.hparams.log_dir}/optimizer.epoch.{epoch}.step.{step}.rank.{optimizer_rank}.pt" |
|
scheduler_save_file_name = f"{self.hparams.log_dir}/scheduler.epoch.{epoch}.step.{step}.rank.{optimizer_rank}.pt" |
|
torch.save(self.optimizer.state_dict(), optimizer_save_file_name) |
|
torch.save(self.scheduler.state_dict(), scheduler_save_file_name) |
|
|
|
def load_model(self, epoch=None, step=None, update_count=False): |
|
|
|
if (epoch is not None and epoch == 0) or (step is not None and step == 0): |
|
return |
|
if epoch is None: |
|
if step is None: |
|
_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/model.epoch.{self.current_epoch}.step.{self.global_step}.pt", |
|
                map_location=self.device
|
) |
|
else: |
|
_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/model.step.{step}.pt", |
|
map_location=self.device |
|
) |
|
if update_count: |
|
self.global_step = step |
|
self.current_epoch = step // self.batchs_per_epoch |
|
else: |
|
if step is None: |
|
_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/model.epoch.{epoch}.pt", |
|
map_location=self.device |
|
) |
|
if update_count: |
|
self.current_epoch = epoch |
|
self.global_step = epoch * self.batchs_per_epoch |
|
else: |
|
_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/model.epoch.{epoch}.step.{step}.pt", |
|
map_location=self.device |
|
) |
|
if update_count: |
|
self.current_epoch = epoch |
|
self.global_step = step |
|
_state_dict_is_ddp = list(_state_dict.keys())[0].startswith("module.") |
|
if isinstance(self.model, DDP): |
|
if _state_dict_is_ddp: |
|
self.model.load_state_dict(_state_dict, strict=self.use_lora==False) |
|
else: |
|
self.model.module.load_state_dict(_state_dict, strict=self.use_lora==False) |
|
else: |
|
if _state_dict_is_ddp: |
|
|
|
from collections import OrderedDict |
|
new_state_dict = OrderedDict() |
|
for k, v in _state_dict.items(): |
|
name = k[7:] |
|
new_state_dict[name] = v |
|
|
|
self.model.load_state_dict(new_state_dict, strict=self.use_lora==False) |
|
else: |
|
self.model.load_state_dict(_state_dict, strict=self.use_lora==False) |
|
|
|
def load_optimizer(self, epoch=None, step=None, optimizer_rank=0): |
|
if epoch is None: |
|
if step is None: |
|
optimizer_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/optimizer.epoch.{self.current_epoch}.step.{self.global_step}.rank.{optimizer_rank}.pt", |
|
                map_location=self.device
|
) |
|
scheduler_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/scheduler.epoch.{self.current_epoch}.step.{self.global_step}.rank.{optimizer_rank}.pt", |
|
                map_location=self.device
|
) |
|
else: |
|
optimizer_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/optimizer.step.{step}.rank.{optimizer_rank}.pt", |
|
map_location=self.device |
|
) |
|
scheduler_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/scheduler.step.{step}.rank.{optimizer_rank}.pt", |
|
map_location=self.device |
|
) |
|
else: |
|
if step is None: |
|
optimizer_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/optimizer.epoch.{epoch}.rank.{optimizer_rank}.pt", |
|
map_location=self.device |
|
) |
|
scheduler_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/scheduler.epoch.{epoch}.rank.{optimizer_rank}.pt", |
|
map_location=self.device |
|
) |
|
else: |
|
optimizer_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/optimizer.epoch.{epoch}.step.{step}.rank.{optimizer_rank}.pt", |
|
map_location=self.device |
|
) |
|
scheduler_state_dict = torch.load( |
|
f"{self.hparams.log_dir}/scheduler.epoch.{epoch}.step.{step}.rank.{optimizer_rank}.pt", |
|
map_location=self.device |
|
) |
|
self.optimizer.load_state_dict(optimizer_state_dict) |
|
self.scheduler.load_state_dict(scheduler_state_dict) |
|
|
|
def _reset_predictions_dict(self): |
|
self.predictions = { |
|
"y": [], |
|
} |
|
|
|
def _reset_losses_dict(self): |
|
self.losses = { |
|
"train": [], |
|
"val": [], |
|
"test": [], |
|
"train_y": [], |
|
"val_y": [], |
|
"test_y": [], |
|
} |
|
|
|
def _reset_val_losses_dict(self): |
|
self.losses["val"] = [] |
|
self.losses["val_y"] = [] |
|
|
|
|
|
def setup(rank, world_size): |
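    """Initialize the default (gloo) process group for single-node DDP training."""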
|
os.environ['MASTER_ADDR'] = 'localhost' |
|
os.environ['MASTER_PORT'] = '15433' |
|
|
|
dist.init_process_group("gloo", rank=rank, world_size=world_size) |
|
|
|
|
|
def cleanup(): |
|
dist.destroy_process_group() |
|
|
|
|
|
def data_distributed_parallel_gpu(rank, model, hparams, dataset_att, dataset_extra_args, trainer_fn=None, checkpoint_epoch=None): |
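    """Per-process DDP entry point: wrap the model in DDP, build the rank-local
    dataset and trainer, then run the train / validate / checkpoint loop."""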
|
|
|
|
|
global result_dict |
|
if isinstance(hparams, dict): |
|
|
|
hparams = sn(**hparams) |
|
torch.set_num_threads(6) |
|
world_size = hparams.ngpus |
|
epochs = hparams.num_epochs |
|
save_every_step = hparams.num_save_batches |
|
save_every_epoch = hparams.num_save_epochs |
|
setup(rank, world_size) |
|
device = f'cuda:{rank}' |
|
torch.cuda.set_per_process_memory_fraction(1.0, rank) |
|
if hparams.dataset.startswith("FullGraph"): |
|
model = torch.compile(model.to(device)) |
|
print(f'Compiled model in rank {rank}') |
|
else: |
|
model = model.to(device) |
|
|
|
ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=hparams.model.startswith("lora")) |
|
ddp_model.train() |
|
|
|
|
|
print(f'Begin loading dataset in rank {rank}') |
|
dataset = getattr(data, hparams.dataset)( |
|
data_file=f"{hparams.data_file_train_ddp_prefix}.{rank}.csv", |
|
gpu_id=rank, |
|
**dataset_att, |
|
**dataset_extra_args, |
|
) |
|
print(f'Loaded dataset in rank {rank}') |
|
trainer = trainer_fn(hparams=hparams, model=ddp_model, dataset=dataset, device_id=rank) |
|
print(f"number of trainable parameters: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)}, " + |
|
f"percentage = {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad) / sum(p.numel() for p in trainer.model.parameters())}") |
|
|
|
if checkpoint_epoch is not None: |
|
while trainer.current_epoch < checkpoint_epoch - 1: |
|
epoch_start_time = time.time() |
|
|
|
|
|
trainer.current_epoch += 1 |
|
epoch_end_time = time.time() |
|
print(f"Dry run load: Epoch {trainer.current_epoch} time: ", epoch_end_time - epoch_start_time) |
|
dist.barrier() |
|
|
|
trainer.training_epoch_end() |
|
trainer.load_model(epoch=checkpoint_epoch, update_count=True) |
|
trainer.load_optimizer(epoch=checkpoint_epoch, optimizer_rank=rank) |
|
print(f"Finished dry run, loaded model from epoch {checkpoint_epoch}") |
|
else: |
|
print("No checkpoint epoch, start from scratch") |
|
checkpoint_epoch = 0 |
|
|
|
dist.barrier() |
|
with Join([trainer.model]): |
|
for i in range(checkpoint_epoch, epochs): |
|
epoch_start_time = time.time() |
|
train_finished = False |
|
trainer.training_epoch_begin() |
|
while not train_finished: |
|
try: |
|
batch_start_time = time.time() |
|
loss = trainer.training_step() |
|
if trainer.global_step % hparams.num_steps_update == 0: |
|
dist.barrier() |
|
|
|
trainer.optimizer_step(loss) |
|
batch_end_time = time.time() |
|
print(f"Rank {rank} batch {trainer.global_step} time: {batch_end_time - batch_start_time}") |
|
if trainer.global_step % save_every_step == 0: |
|
if rank == 0: |
|
trainer.write_model(step=trainer.global_step) |
|
|
|
if trainer.val_dataset is not None: |
|
val_finished = False |
|
val_begin_time = time.time() |
|
trainer.validation_epoch_begin() |
|
while not val_finished: |
|
try: |
|
trainer.validation_step() |
|
except StopIteration: |
|
val_finished = True |
|
val_end_time = time.time() |
|
dist.barrier() |
|
result_dict = trainer.validation_epoch_end(reset_train_loss=True) |
|
print(f"Rank {rank} batch {trainer.global_step} result: {result_dict}") |
|
with open( |
|
f"{hparams.log_dir}/result_dict.batch.{trainer.global_step}.ddp_rank.{rank}.json", "w" |
|
) as f: |
|
json.dump(result_dict, f) |
|
dist.barrier() |
|
all_val_loss = [] |
|
for k in range(world_size): |
|
with open( |
|
f"{hparams.log_dir}/result_dict.batch.{trainer.global_step}.ddp_rank.{k}.json", "r" |
|
) as f: |
|
if trainer.val_dataset is not None: |
|
all_val_loss.append(json.load(f)["val_loss"]) |
|
else: |
|
|
|
all_val_loss.append(json.load(f)["train_loss"]) |
|
print(f"Batch {trainer.global_step} all val loss: {np.mean(all_val_loss)}") |
|
print(f"Batch {trainer.global_step} val time: {val_end_time - val_begin_time}") |
|
trainer.scheduler_step(np.mean(all_val_loss)) |
|
dist.barrier() |
|
except StopIteration: |
|
train_finished = True |
|
|
|
if not trainer.updated: |
|
trainer.optimizer_step(loss) |
|
dist.barrier() |
|
|
|
if trainer.val_dataset is not None: |
|
val_finished = False |
|
trainer.validation_epoch_begin() |
|
while not val_finished: |
|
try: |
|
trainer.validation_step() |
|
dist.barrier() |
|
except StopIteration: |
|
val_finished = True |
|
result_dict = trainer.validation_epoch_end() |
|
print(f"Rank {rank} epoch {i} result: {result_dict}") |
|
with open(f"{hparams.log_dir}/result_dict.epoch.{i}.ddp_rank.{rank}.json", "w") as f: |
|
json.dump(result_dict, f) |
|
|
|
dist.barrier() |
|
trainer.training_epoch_end() |
|
epoch_end_time = time.time() |
|
print(f"Epoch {i} time: ", epoch_end_time - epoch_start_time) |
|
dist.barrier() |
|
if trainer.current_epoch % save_every_epoch == 0: |
|
if rank == 0: |
|
trainer.write_model(epoch=trainer.current_epoch, save_optimizer=True, optimizer_rank=rank) |
|
else: |
|
trainer.write_optimizer(epoch=trainer.current_epoch, optimizer_rank=rank) |
|
|
|
trainer.dataset.clean_up() |
|
cleanup() |
|
|
|
return trainer |
|
|
|
|
|
def single_thread_gpu(rank, model, hparams, dataset, trainer_fn=None, checkpoint_epoch=None, trial_id=None): |
|
|
|
|
|
if isinstance(hparams, dict): |
|
|
|
hparams = sn(**hparams) |
|
|
|
if trial_id is not None: |
|
print(f"Trial id: {trial_id}") |
|
hparams.log_dir = f"{hparams.log_dir}/trial.{trial_id}" |
|
os.makedirs(hparams.log_dir, exist_ok=True) |
|
if hparams.hp_tune: |
|
from ray.air import Checkpoint, session |
|
epochs = hparams.num_epochs |
|
save_every_step = hparams.num_save_batches |
|
save_every_epoch = hparams.num_save_epochs |
|
device = f'cuda:{rank}' |
|
torch.cuda.set_per_process_memory_fraction(1.0, rank) |
|
|
|
|
|
|
|
|
|
model = model.to(device) |
|
model.train() |
|
|
|
trainer = trainer_fn(hparams=hparams, model=model, dataset=dataset, device_id=rank) |
|
print(f"number of trainable parameters: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)}, " + |
|
f"percentage = {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad) / sum(p.numel() for p in trainer.model.parameters())}") |
|
|
|
if checkpoint_epoch is not None: |
|
while trainer.current_epoch < checkpoint_epoch: |
|
epoch_start_time = time.time() |
|
trainer.training_epoch_begin() |
|
trainer.training_epoch_end() |
|
epoch_end_time = time.time() |
|
print(f"Dry run load: Epoch {trainer.current_epoch} time: ", epoch_end_time - epoch_start_time) |
|
trainer.load_model(epoch=checkpoint_epoch, update_count=True) |
|
trainer.load_optimizer(epoch=checkpoint_epoch, optimizer_rank=rank) |
|
print(f"Finished dry run, loaded model from epoch {checkpoint_epoch}") |
|
else: |
|
print("No checkpoint epoch, start from scratch") |
|
checkpoint_epoch = 0 |
|
for i in range(checkpoint_epoch, epochs): |
|
epoch_start_time = time.time() |
|
train_finished = False |
|
trainer.training_epoch_begin() |
|
while not train_finished: |
|
try: |
|
batch_start_time = time.time() |
|
loss = trainer.training_step() |
|
if trainer.global_step % hparams.num_steps_update == 0: |
|
|
|
trainer.optimizer_step(loss) |
|
batch_end_time = time.time() |
|
print(f"Rank {rank} batch {trainer.global_step} time: {batch_end_time - batch_start_time}") |
|
if trainer.global_step % save_every_step == 0: |
|
trainer.write_model(step=trainer.global_step) |
|
|
|
val_finished = False |
|
val_start_time = time.time() |
|
trainer.validation_epoch_begin() |
|
while not val_finished: |
|
try: |
|
trainer.validation_step() |
|
except StopIteration: |
|
val_finished = True |
|
result_dict = trainer.validation_epoch_end() |
|
print(f"Rank {rank} batch {trainer.global_step} result: {result_dict}") |
|
with open( |
|
f"{hparams.log_dir}/result_dict.batch.{trainer.global_step}.ddp_rank.{rank}.json", "w" |
|
) as f: |
|
json.dump(result_dict, f) |
|
all_val_loss = result_dict["val_loss"] |
|
print(f"Batch {trainer.global_step} all val loss: {all_val_loss}") |
|
trainer.scheduler_step(all_val_loss) |
|
|
|
if hparams.hp_tune: |
|
checkpoint_data = { |
|
"epoch": trainer.current_epoch, |
|
"batch": trainer.global_step, |
|
"net_state_dict": trainer.model.state_dict(), |
|
"optimizer_state_dict": trainer.optimizer.state_dict(), |
|
"scheduler_state_dict": trainer.scheduler.state_dict(), |
|
} |
|
checkpoint = Checkpoint.from_dict(checkpoint_data) |
|
session.report( |
|
{"loss": all_val_loss}, |
|
checkpoint=checkpoint, |
|
) |
|
val_end_time = time.time() |
|
print(f"Rank {rank} batch {trainer.global_step} validation time: {val_end_time - val_start_time}") |
|
except StopIteration: |
|
train_finished = True |
|
|
|
if not trainer.updated: |
|
trainer.optimizer_step(loss) |
|
|
|
val_finished = False |
|
trainer.validation_epoch_begin() |
|
while not val_finished: |
|
try: |
|
trainer.validation_step() |
|
except StopIteration: |
|
val_finished = True |
|
result_dict = trainer.validation_epoch_end() |
|
print(f"Rank {rank} epoch {i} result: {result_dict}") |
|
with open(f"{hparams.log_dir}/result_dict.epoch.{i}.ddp_rank.{rank}.json", "w") as f: |
|
json.dump(result_dict, f) |
|
trainer.training_epoch_end() |
|
|
|
all_val_loss = result_dict["val_loss"] |
|
if hparams.hp_tune: |
|
checkpoint_data = { |
|
"epoch": trainer.current_epoch, |
|
"batch": trainer.global_step, |
|
"net_state_dict": trainer.model.state_dict(), |
|
"optimizer_state_dict": trainer.optimizer.state_dict(), |
|
"scheduler_state_dict": trainer.scheduler.state_dict(), |
|
} |
|
checkpoint = Checkpoint.from_dict(checkpoint_data) |
|
session.report( |
|
{"loss": all_val_loss}, |
|
checkpoint=checkpoint, |
|
) |
|
epoch_end_time = time.time() |
|
print(f"Epoch {i} time: ", epoch_end_time - epoch_start_time) |
|
if trainer.current_epoch % save_every_epoch == 0: |
|
trainer.write_model(epoch=trainer.current_epoch, save_optimizer=True, optimizer_rank=rank) |
|
|
|
|
|
trainer.dataset.clean_up() |
|
return trainer |
|
|
|
|
|
def single_thread_gpu_4_fold(rank, model, hparams, dataset, trainer_fn=None, checkpoint_epoch=None): |
|
|
|
|
|
|
|
|
|
if isinstance(hparams, dict): |
|
|
|
        hparams = sn(**hparams)

    if hparams.hp_tune:
        # match single_thread_gpu: session.report / Checkpoint are used below when tuning
        from ray.air import Checkpoint, session
|
|
|
|
|
|
|
np.random.seed(0) |
|
|
|
gof_indices = dataset.data.index[dataset.data["score"] == 1] |
|
lof_indices = dataset.data.index[dataset.data["score"] == -1] |
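    # Split the two annotated classes (gof: score == 1, lof: score == -1) into four
    # folds separately so every fold keeps roughly the same class balance.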
|
|
|
|
|
gof_fold_split_sz = max(len(gof_indices) // 4, 1) |
|
lof_fold_split_sz = max(len(lof_indices) // 4, 1) |
|
gof_fold_split = np.split(np.random.permutation(gof_indices), [gof_fold_split_sz, 2*gof_fold_split_sz, 3*gof_fold_split_sz]) |
|
lof_fold_split = np.split(np.random.permutation(lof_indices), [lof_fold_split_sz, 2*lof_fold_split_sz, 3*lof_fold_split_sz]) |
|
|
|
with open(f"{hparams.log_dir}/fold_split.pkl", "wb") as f: |
|
pickle.dump([gof_fold_split, lof_fold_split], f) |
|
main_log_dir = hparams.log_dir |
|
for FOLD in range(4): |
|
print(f"Begin Fold id: {FOLD}") |
|
hparams.log_dir = f"{main_log_dir}/FOLD.{FOLD}/" |
|
hparams.data_split_fn = "_by_anno" |
|
os.makedirs(hparams.log_dir, exist_ok=True) |
|
|
|
dataset_fold = copy.deepcopy(dataset) |
|
|
|
dataset_fold.data["split"] = 'train' |
|
|
|
dataset_fold.data.loc[gof_fold_split[FOLD], "split"] = 'val' |
|
dataset_fold.data.loc[lof_fold_split[FOLD], "split"] = 'val' |
|
|
|
epochs = hparams.num_epochs |
|
save_every_step = hparams.num_save_batches |
|
save_every_epoch = hparams.num_save_epochs |
|
|
|
if os.path.exists(f"{hparams.log_dir}/model.epoch.{epochs}.pt"): |
|
print(f"Fold {FOLD} already trained, skip") |
|
continue |
|
device = f'cuda:{rank}' |
|
torch.cuda.set_per_process_memory_fraction(1.0, rank) |
|
|
|
model_fold = copy.deepcopy(model) |
|
model_fold = model_fold.to(device) |
|
model_fold.train() |
|
|
|
trainer = trainer_fn(hparams=hparams, model=model_fold, dataset=dataset_fold, device_id=rank) |
|
print(f"number of trainable parameters: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)}, " + |
|
f"percentage = {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad) / sum(p.numel() for p in trainer.model.parameters())}") |
|
|
|
for i in range(epochs): |
|
epoch_start_time = time.time() |
|
train_finished = False |
|
trainer.training_epoch_begin() |
|
while not train_finished: |
|
try: |
|
batch_start_time = time.time() |
|
loss = trainer.training_step() |
|
if trainer.global_step % hparams.num_steps_update == 0: |
|
|
|
trainer.optimizer_step(loss) |
|
batch_end_time = time.time() |
|
print(f"Rank {rank} batch {trainer.global_step} time: {batch_end_time - batch_start_time}") |
|
if trainer.global_step % save_every_step == 0: |
|
trainer.write_model(step=trainer.global_step) |
|
|
|
val_finished = False |
|
val_start_time = time.time() |
|
trainer.validation_epoch_begin() |
|
while not val_finished: |
|
try: |
|
trainer.validation_step() |
|
except StopIteration: |
|
val_finished = True |
|
result_dict = trainer.validation_epoch_end() |
|
print(f"Rank {rank} batch {trainer.global_step} result: {result_dict}") |
|
with open( |
|
f"{hparams.log_dir}/result_dict.batch.{trainer.global_step}.ddp_rank.{rank}.json", "w" |
|
) as f: |
|
json.dump(result_dict, f) |
|
all_val_loss = result_dict["val_loss"] |
|
print(f"Batch {trainer.global_step} all val loss: {all_val_loss}") |
|
trainer.scheduler_step(all_val_loss) |
|
|
|
if hparams.hp_tune: |
|
checkpoint_data = { |
|
"epoch": trainer.current_epoch, |
|
"batch": trainer.global_step, |
|
"net_state_dict": trainer.model.state_dict(), |
|
"optimizer_state_dict": trainer.optimizer.state_dict(), |
|
"scheduler_state_dict": trainer.scheduler.state_dict(), |
|
} |
|
checkpoint = Checkpoint.from_dict(checkpoint_data) |
|
session.report( |
|
{"loss": all_val_loss}, |
|
checkpoint=checkpoint, |
|
) |
|
val_end_time = time.time() |
|
print(f"Rank {rank} batch {trainer.global_step} validation time: {val_end_time - val_start_time}") |
|
except StopIteration: |
|
train_finished = True |
|
|
|
if not trainer.updated: |
|
trainer.optimizer_step(loss) |
|
|
|
val_finished = False |
|
trainer.validation_epoch_begin() |
|
while not val_finished: |
|
try: |
|
trainer.validation_step() |
|
except StopIteration: |
|
val_finished = True |
|
result_dict = trainer.validation_epoch_end() |
|
print(f"Rank {rank} epoch {i} result: {result_dict}") |
|
with open(f"{hparams.log_dir}/result_dict.epoch.{i}.ddp_rank.{rank}.json", "w") as f: |
|
json.dump(result_dict, f) |
|
trainer.training_epoch_end() |
|
|
|
all_val_loss = result_dict["val_loss"] |
|
if hparams.hp_tune: |
|
checkpoint_data = { |
|
"epoch": trainer.current_epoch, |
|
"batch": trainer.global_step, |
|
"net_state_dict": trainer.model.state_dict(), |
|
"optimizer_state_dict": trainer.optimizer.state_dict(), |
|
"scheduler_state_dict": trainer.scheduler.state_dict(), |
|
} |
|
checkpoint = Checkpoint.from_dict(checkpoint_data) |
|
session.report( |
|
{"loss": all_val_loss}, |
|
checkpoint=checkpoint, |
|
) |
|
epoch_end_time = time.time() |
|
print(f"Epoch {i} time: ", epoch_end_time - epoch_start_time) |
|
if trainer.current_epoch % save_every_epoch == 0: |
|
trainer.write_model(epoch=trainer.current_epoch, save_optimizer=True, optimizer_rank=rank) |
|
|
|
|
|
trainer.dataset.clean_up() |
|
return trainer |
|
|
|
|
|
def ray_tune(config, dataset=None, trial_id=None): |
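    """Ray Tune objective: build a model from the sampled config and run single-GPU training for this trial."""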
|
args = sn(**config) |
|
model_class = args.model_class |
|
|
|
    if args.load_model in (None, "None", "null"):
|
my_model = create_model(config, model_class=model_class) |
|
else: |
|
my_model = create_model_and_load(config, model_class=model_class) |
|
if args.trainer_fn == "PreMode_trainer": |
|
trainer_fn = PreMode_trainer |
|
else: |
|
raise ValueError(f"trainer_fn {args.trainer_fn} not supported") |
|
check_point_epoch = None |
|
return single_thread_gpu(args.gpu_id, my_model, config, dataset, trainer_fn=trainer_fn, checkpoint_epoch=check_point_epoch, trial_id=trial_id) |