"""Full definition of a VAE model, all of it in this single file.

References:
    1) An Introduction to Variational Autoencoders:
       https://arxiv.org/abs/1906.02691
"""

from typing import Optional, Sequence

import torch
import torch.nn as nn


class VAE(nn.Module):
    """VAE for 64x64 face generation.

    The encoder is a stack of stride-2 convolution stages (one per entry of
    ``hiddens``), each halving the spatial size of a 3-channel 64x64 input.
    The decoder mirrors it with stride-2 transposed convolutions and ends in
    a 3-channel convolution. The hidden dimensions can be tuned.

    Args:
        hiddens: Channel counts of the encoder stages, input side first. The
            decoder uses them in reverse order. Must contain between 1 and 6
            entries (64 = 2**6, and each stage halves the image side).
        latent_dim: Size of the latent vector ``z``.

    Raises:
        ValueError: If ``hiddens`` is empty or has more stages than a 64x64
            image can be halved into.
    """

    def __init__(self,
                 hiddens: Sequence[int] = (16, 32, 64, 128, 256),
                 latent_dim: int = 128) -> None:
        super().__init__()
        hiddens = list(hiddens)
        img_length = 64
        if not hiddens:
            raise ValueError('hiddens must contain at least one channel count')
        if (img_length >> len(hiddens)) < 1:
            raise ValueError(
                f'too many stages ({len(hiddens)}) for a {img_length}x'
                f'{img_length} input; at most 6 are supported')

        # encoder
        prev_channels = 3
        modules = []
        for cur_channels in hiddens:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels, cur_channels,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(cur_channels),
                    nn.ReLU()))
            prev_channels = cur_channels
            # kernel 3 / stride 2 / padding 1 halves an even spatial size.
            img_length //= 2
        self.encoder = nn.Sequential(*modules)

        flat_dim = prev_channels * img_length * img_length
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        self.decoder_projection = nn.Linear(latent_dim, flat_dim)
        self.decoder_input_chw = (prev_channels, img_length, img_length)
        modules = []
        # Walk the encoder channels backwards: hiddens[i] -> hiddens[i - 1].
        for out_channels, in_channels in reversed(
                list(zip(hiddens, hiddens[1:]))):
            modules.append(
                nn.Sequential(
                    # output_padding=1 makes each stage exactly double the
                    # spatial size, undoing one encoder stage.
                    nn.ConvTranspose2d(in_channels, out_channels,
                                       kernel_size=3, stride=2, padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU()))
        # Last upsampling stage plus a projection back to 3 channels.
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0], hiddens[0],
                                   kernel_size=3, stride=2, padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]),
                nn.ReLU(),
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                # NOTE(review): ReLU keeps outputs non-negative but not
                # bounded above; if targets are normalised to [0, 1] a
                # Sigmoid may be intended -- confirm against training code.
                nn.ReLU()))
        self.decoder = nn.Sequential(*modules)

    def encode(self, x):
        """Map images ``x`` to the posterior parameters ``(mean, logvar)``.

        Both outputs have shape (batch, latent_dim).
        """
        encoded = torch.flatten(self.encoder(x), 1)
        return self.mean_linear(encoded), self.var_linear(encoded)

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, latent_dim) to images."""
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        return self.decoder(x)

    def forward(self, x):
        """Encode, sample ``z`` and decode.

        Returns:
            Tuple ``(decoded, mean, logvar)``. ``logvar`` is the log-variance
            of the approximate posterior.
        """
        mean, logvar = self.encode(x)
        # Reparameterisation trick: z = mean + eps * std with eps ~ N(0, I),
        # so gradients flow through mean and logvar.
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        return self.decode(z), mean, logvar

    def sample(self, device=None, num_samples: int = 1):
        """Decode ``num_samples`` latent vectors drawn from N(0, I).

        Args:
            device: Device on which to draw ``z``. Defaults to the device
                the model's parameters are on.
            num_samples: Number of images to generate.

        Returns:
            Decoded images with batch size ``num_samples``.
        """
        if device is None:
            device = next(self.parameters()).device
        z = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decode(z)