from typing import Optional, Type

import torch
import torch.nn as nn
import pyro
import pyro.infer as infer
import pyro.optim as optim

import warnings

#from vae_model import set_deterministic_mode as set_deterministic_mode
from atoms_detection.vae_model import set_deterministic_mode as set_deterministic_mode

warnings.filterwarnings("ignore", module="torchvision.datasets")


class SVItrainer:
    """
    Stochastic variational inference (SVI) trainer for
    unsupervised and class-conditioned variational models
    """

    def __init__(self,
                 model: nn.Module,
                 optimizer: Optional[optim.PyroOptim] = None,
                 loss: Optional[infer.ELBO] = None,
                 seed: int = 1
                 ) -> None:
        """
        Initializes the trainer's parameters

        Args:
            model: Object exposing ``model`` and ``guide`` attributes, which
                are handed to ``pyro.infer.SVI``.
            optimizer: Pyro optimizer. Defaults to Adam with a learning
                rate of 1e-3 when not given.
            loss: ELBO objective. Defaults to ``Trace_ELBO()`` when not given.
            seed: Value passed to ``set_deterministic_mode``.
        """
        # The Pyro parameter store is global; reset it so that parameters
        # left over from a previous run do not leak into this one.
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        # NOTE(review): only the data batches are moved to this device; the
        # model itself is not moved here -- presumably the caller does that.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if optimizer is None:
            optimizer = optim.Adam({"lr": 1.0e-3})
        if loss is None:
            loss = infer.Trace_ELBO()
        self.svi = infer.SVI(model.model, model.guide, optimizer, loss=loss)
        self.loss_history = {"training_loss": [], "test_loss": []}
        self.current_epoch = 0

    def _batch_to_device(self, data) -> tuple:
        """
        Unpacks a mini-batch and moves its tensors to the trainer's device

        A batch with a single element is treated as ``(x,)`` (VAE mode);
        otherwise it is unpacked as ``(x, y)`` (VED or cVAE mode).
        """
        if len(data) == 1:  # VAE mode
            return (data[0].to(self.device),)
        x, y = data  # VED or cVAE mode
        return x.to(self.device), y.to(self.device)

    def train(self,
              train_loader: torch.utils.data.DataLoader,
              **kwargs: float) -> float:
        """
        Trains a single epoch

        Args:
            train_loader: Data loader yielding the training mini-batches.
            **kwargs: Extra keyword arguments forwarded to ``SVI.step``.

        Returns:
            Sum of the per-batch losses divided by the number of samples
            in ``train_loader.dataset``.
        """
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch returned by the data loader
        for data in train_loader:
            # do ELBO gradient step and accumulate loss
            epoch_loss += self.svi.step(*self._batch_to_device(data), **kwargs)

        return epoch_loss / len(train_loader.dataset)

    def evaluate(self,
                 test_loader: torch.utils.data.DataLoader,
                 **kwargs: float) -> float:
        """
        Evaluates current models state on a single epoch

        Args:
            test_loader: Data loader yielding the test mini-batches.
            **kwargs: Extra keyword arguments forwarded to
                ``SVI.evaluate_loss``.

        Returns:
            Sum of the per-batch losses divided by the number of samples
            in ``test_loader.dataset``.
        """
        # initialize loss accumulator
        test_loss = 0.
        # compute the loss over the entire test set
        with torch.no_grad():
            for data in test_loader:
                # evaluate_loss only estimates the loss; SVI.step would also
                # invoke the optimizer, i.e. update parameters on test data.
                test_loss += self.svi.evaluate_loss(
                    *self._batch_to_device(data), **kwargs)

        return test_loss / len(test_loader.dataset)

    def step(self,
             train_loader: torch.utils.data.DataLoader,
             test_loader: Optional[torch.utils.data.DataLoader] = None,
             **kwargs: float) -> None:
        """
        Single training and (optionally) evaluation step

        Appends the epoch's training loss (and test loss, when a test
        loader is given) to ``loss_history`` and increments the epoch counter.
        """
        self.loss_history["training_loss"].append(
            self.train(train_loader, **kwargs))
        if test_loader is not None:
            self.loss_history["test_loss"].append(
                self.evaluate(test_loader, **kwargs))
        self.current_epoch += 1

    def print_statistics(self) -> None:
        """
        Prints training and test (if any) losses for current epoch
        """
        e = self.current_epoch
        if len(self.loss_history["test_loss"]) > 0:
            template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1],
                                  self.loss_history["test_loss"][-1]))
        else:
            template = 'Epoch: {} Training loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1]))


def init_dataloader(*args: torch.Tensor, **kwargs: int
                    ) -> torch.utils.data.DataLoader:
    """
    Wraps tensors into a shuffling data loader

    Args:
        *args: Tensors passed to ``TensorDataset``.
        **kwargs: Only ``batch_size`` is read (default: 100).

    Returns:
        ``DataLoader`` over the tensor dataset with ``shuffle=True``.
    """
    batch_size = kwargs.get("batch_size", 100)
    tensor_set = torch.utils.data.dataset.TensorDataset(*args)
    data_loader = torch.utils.data.DataLoader(
        dataset=tensor_set, batch_size=batch_size, shuffle=True)
    return data_loader