"""
Full definition of a VAE model, all of it in this single file.

References:
1) An Introduction to Variational Autoencoders:
   https://arxiv.org/abs/1906.02691
"""

import torch
import torch.nn as nn


class VAE(nn.Module):
    """VAE for 64x64 face generation.

    The hidden channel widths and the latent dimension can be tuned.
    """

    def __init__(self, hiddens=(16, 32, 64, 128, 256), latent_dim=128) -> None:
        super().__init__()

        # Encoder: a stack of stride-2 convolutions, each halving the spatial
        # resolution (64 -> 32 -> 16 -> 8 -> 4 -> 2).
        prev_channels = 3
        modules = []
        img_length = 64
        for cur_channels in hiddens:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,
                              cur_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(cur_channels),
                    nn.ReLU()))
            prev_channels = cur_channels
            img_length //= 2
        self.encoder = nn.Sequential(*modules)

        # Two linear heads map the flattened encoder features to the mean and
        # log-variance of the approximate posterior q(z | x).
        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim

        # Decoder: mirror the encoder with stride-2 transposed convolutions,
        # each doubling the spatial resolution back up to 64x64, followed by a
        # final 3x3 convolution down to 3 output channels.
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim, prev_channels * img_length * img_length)
        self.decoder_input_chw = (prev_channels, img_length, img_length)
        for i in range(len(hiddens) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hiddens[i],
                                       hiddens[i - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hiddens[i - 1]),
                    nn.ReLU()))
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],
                                   hiddens[0],
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]),
                nn.ReLU(),
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        # Encode to the parameters of q(z | x).
        encoded = self.encoder(x)
        encoded = torch.flatten(encoded, 1)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)

        # Reparameterization trick: z = mean + std * eps, with eps ~ N(0, I).
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean

        # Decode z back to image space.
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)

        return decoded, mean, logvar
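
    # The training objective is not defined in this file. The method below is
    # a minimal sketch of the usual ELBO loss (summed MSE reconstruction term
    # plus the closed-form Gaussian KL term) that pairs with forward(). The
    # name `elbo_loss` and the `kl_weight` knob are illustrative assumptions,
    # not part of the original model.
    @staticmethod
    def elbo_loss(decoded, x, mean, logvar, kl_weight=1.0):
        recon = nn.functional.mse_loss(decoded, x, reduction='sum')
        kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return recon + kl_weight * kl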

    def sample(self, device='cuda'):
        # Draw a single latent from the standard normal prior and decode it.
        z = torch.randn(1, self.latent_dim).to(device)
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)
        return decoded
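

# Minimal usage sketch (assumed, not part of any training script shipped with
# this file): build the model, push a dummy 64x64 batch through forward(),
# and draw one sample from the prior, checking only the output shapes.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = VAE().to(device)
    dummy = torch.randn(4, 3, 64, 64, device=device)
    recon, mean, logvar = model(dummy)
    print(recon.shape, mean.shape, logvar.shape)  # (4, 3, 64, 64), (4, 128), (4, 128)
    print(model.sample(device=device).shape)      # (1, 3, 64, 64)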