"""
Stochastic variational inference (SVI) training utilities for the
variational models in `atoms_detection`.
"""
import warnings
from typing import Optional

import torch
import torch.nn as nn
import pyro
import pyro.infer as infer
import pyro.optim as optim

from atoms_detection.vae_model import set_deterministic_mode

warnings.filterwarnings("ignore", module="torchvision.datasets")
class SVItrainer:
    """
    Stochastic variational inference (SVI) trainer for
    unsupervised and class-conditioned variational models.
    """
    def __init__(self,
                 model: nn.Module,
                 optimizer: Optional[optim.PyroOptim] = None,
                 loss: Optional[infer.ELBO] = None,
                 seed: int = 1
                 ) -> None:
        """
        Initializes the trainer's parameters.
        """
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if optimizer is None:
            optimizer = optim.Adam({"lr": 1.0e-3})
        if loss is None:
            loss = infer.Trace_ELBO()
        self.svi = infer.SVI(model.model, model.guide, optimizer, loss=loss)
        self.loss_history = {"training_loss": [], "test_loss": []}
        self.current_epoch = 0
    def train(self,
              train_loader: torch.utils.data.DataLoader,
              **kwargs: float) -> float:
        """
        Trains a single epoch and returns the average training loss.
        """
        # initialize loss accumulator
        epoch_loss = 0.
        # take an ELBO gradient step on each mini-batch returned by the loader
        for data in train_loader:
            if len(data) == 1:  # VAE mode
                x = data[0]
                loss = self.svi.step(x.to(self.device), **kwargs)
            else:  # VED or cVAE mode
                x, y = data
                loss = self.svi.step(
                    x.to(self.device), y.to(self.device), **kwargs)
            # accumulate the mini-batch loss
            epoch_loss += loss
        return epoch_loss / len(train_loader.dataset)
    def evaluate(self,
                 test_loader: torch.utils.data.DataLoader,
                 **kwargs: float) -> float:
        """
        Evaluates the current model state and returns the average test loss.
        """
        # initialize loss accumulator
        test_loss = 0.
        # compute the loss over the entire test set; use evaluate_loss rather
        # than step so no gradient updates are performed during evaluation
        with torch.no_grad():
            for data in test_loader:
                if len(data) == 1:  # VAE mode
                    x = data[0]
                    loss = self.svi.evaluate_loss(x.to(self.device), **kwargs)
                else:  # VED or cVAE mode
                    x, y = data
                    loss = self.svi.evaluate_loss(
                        x.to(self.device), y.to(self.device), **kwargs)
                test_loss += loss
        return test_loss / len(test_loader.dataset)
    def step(self,
             train_loader: torch.utils.data.DataLoader,
             test_loader: Optional[torch.utils.data.DataLoader] = None,
             **kwargs: float) -> None:
        """
        Runs a single training epoch and (optionally) evaluates on the test set.
        """
        self.loss_history["training_loss"].append(self.train(train_loader, **kwargs))
        if test_loader is not None:
            self.loss_history["test_loss"].append(self.evaluate(test_loader, **kwargs))
        self.current_epoch += 1
    def print_statistics(self) -> None:
        """
        Prints the training and (if available) test losses for the current epoch.
        """
        e = self.current_epoch
        if len(self.loss_history["test_loss"]) > 0:
            template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1],
                                  self.loss_history["test_loss"][-1]))
        else:
            template = 'Epoch: {} Training loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1]))
def init_dataloader(*args: torch.Tensor, **kwargs: int
                    ) -> torch.utils.data.DataLoader:
    """
    Wraps the input tensor(s) in a TensorDataset and returns a shuffling
    DataLoader over it (batch size defaults to 100).
    """
    batch_size = kwargs.get("batch_size", 100)
    tensor_set = torch.utils.data.TensorDataset(*args)
    data_loader = torch.utils.data.DataLoader(
        dataset=tensor_set, batch_size=batch_size, shuffle=True)
    return data_loader
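
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# Pyro VAE exposing the `model`/`guide` callables that SVItrainer expects.
# The ToyVAE architecture, shapes, and random data below are assumptions
# made for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pyro.distributions as dist

    class ToyVAE(nn.Module):
        """Minimal VAE over flattened 28x28 binary images (hypothetical)."""
        def __init__(self, z_dim: int = 2) -> None:
            super().__init__()
            self.z_dim = z_dim
            self.encoder = nn.Sequential(nn.Linear(784, 128), nn.ReLU())
            self.z_loc = nn.Linear(128, z_dim)
            self.z_log_scale = nn.Linear(128, z_dim)
            self.decoder = nn.Sequential(
                nn.Linear(z_dim, 128), nn.ReLU(),
                nn.Linear(128, 784), nn.Sigmoid())

        def model(self, x: torch.Tensor) -> None:
            # p(z)p(x|z): standard normal prior, Bernoulli likelihood
            pyro.module("decoder", self.decoder)
            with pyro.plate("data", x.shape[0]):
                z_loc = x.new_zeros(x.shape[0], self.z_dim)
                z_scale = x.new_ones(x.shape[0], self.z_dim)
                z = pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))
                probs = self.decoder(z)
                pyro.sample("obs", dist.Bernoulli(probs).to_event(1), obs=x)

        def guide(self, x: torch.Tensor) -> None:
            # q(z|x): amortized diagonal-Gaussian posterior
            pyro.module("encoder", self.encoder)
            pyro.module("z_loc", self.z_loc)
            pyro.module("z_log_scale", self.z_log_scale)
            with pyro.plate("data", x.shape[0]):
                h = self.encoder(x)
                z_loc = self.z_loc(h)
                z_scale = torch.exp(self.z_log_scale(h))
                pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))

    # Train on random binary data for a few epochs
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vae = ToyVAE().to(device)
    x_train = torch.bernoulli(torch.rand(256, 784))
    loader = init_dataloader(x_train, batch_size=64)
    trainer = SVItrainer(vae)
    for _ in range(3):
        trainer.step(loader)
        trainer.print_statistics()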