| | """ |
| | Full definition of a VAE model, all of it in this single file. |
| | References: |
| | 1) An Introduction to Variational Autoencoders: |
| | https://arxiv.org/abs/1906.02691 |
| | """ |
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class VAE(nn.Module):
    """Convolutional VAE for square RGB images (64x64 faces by default).

    Encoder: one ``Conv2d(k=3, s=2, p=1) -> BatchNorm2d -> ReLU`` block per
    entry of ``hiddens``, each halving the spatial size. Two linear heads map
    the flattened feature map to the mean and the log-variance of a diagonal
    Gaussian posterior.

    Decoder: a linear projection back to the final encoder feature-map shape,
    one ``ConvTranspose2d(k=3, s=2, p=1, output_padding=1)`` block per
    ``hiddens`` entry (channels in reverse order), then a 3x3 ``Conv2d`` to 3
    channels followed by ``ReLU``. Because of that final ``ReLU`` the
    reconstruction is non-negative; presumably training targets are scaled to
    a non-negative range such as [0, 1] -- confirm against the data pipeline.

    The hidden dimensions can be tuned.

    Args:
        hiddens: Channel counts of the encoder blocks, in order. Any non-empty
            sequence of ints is accepted; it is copied into a tuple.
        latent_dim: Size of the latent vector ``z``.
        img_length: Height/width of the square input images. Must be a
            positive multiple of ``2 ** len(hiddens)`` so every stride-2
            convolution halves the size exactly and the decoder reproduces
            the input resolution.

    Raises:
        ValueError: If ``hiddens`` is empty, or ``img_length`` is not a
            positive multiple of ``2 ** len(hiddens)``.
    """

    def __init__(self,
                 hiddens=(16, 32, 64, 128, 256),
                 latent_dim: int = 128,
                 img_length: int = 64) -> None:
        super().__init__()
        # Copy so a caller-owned list can never be shared or mutated here.
        hiddens = tuple(hiddens)
        if not hiddens:
            raise ValueError('hiddens must contain at least one channel count')
        downscale = 2**len(hiddens)
        if img_length <= 0 or img_length % downscale != 0:
            # Otherwise the tracked size below diverges from the real conv
            # output (and can reach 0), producing a Linear layer whose
            # in_features do not match the flattened encoder output.
            raise ValueError(
                f'img_length={img_length} must be a positive multiple of '
                f'2 ** len(hiddens) = {downscale}')

        # ---- Encoder --------------------------------------------------------
        prev_channels = 3
        modules = []
        for cur_channels in hiddens:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,
                              cur_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1), nn.BatchNorm2d(cur_channels),
                    nn.ReLU()))
            prev_channels = cur_channels
            # k=3, s=2, p=1 maps an even length L to exactly L // 2.
            img_length //= 2
        self.encoder = nn.Sequential(*modules)
        flat_dim = prev_channels * img_length * img_length
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        # Despite its name, this head predicts log(sigma^2); see forward().
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # ---- Decoder --------------------------------------------------------
        self.decoder_projection = nn.Linear(latent_dim, flat_dim)
        self.decoder_input_chw = (prev_channels, img_length, img_length)
        modules = []
        # Walk the channel list backwards: (h[-1] -> h[-2]), ..., (h[1] -> h[0]).
        for in_channels, out_channels in zip(reversed(hiddens[1:]),
                                             reversed(hiddens[:-1])):
            modules.append(
                nn.Sequential(
                    # output_padding=1 makes each block exactly double H and W.
                    nn.ConvTranspose2d(in_channels,
                                       out_channels,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_channels), nn.ReLU()))
        # Last upsampling (undoes the first encoder block) plus RGB head.
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],
                                   hiddens[0],
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))
        self.decoder = nn.Sequential(*modules)

    def encode(self, x: torch.Tensor):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``.

        ``x`` is expected to have 3 channels and the ``img_length`` spatial
        size given at construction; the linear heads' input width is fixed by
        that size.
        """
        encoded = torch.flatten(self.encoder(x), 1)
        return self.mean_linear(encoded), self.var_linear(encoded)

    @staticmethod
    def reparameterize(mean: torch.Tensor,
                       logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        return eps * std + mean

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors ``z`` of width ``latent_dim`` into images."""
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        return self.decoder(x)

    def forward(self, x: torch.Tensor):
        """Encode, sample via the reparameterization trick, and decode.

        Returns:
            ``(decoded, mean, logvar)``: the reconstruction and the posterior
            parameters (needed for the KL term of the loss).
        """
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

    def sample(self, device='cuda', n_samples: int = 1) -> torch.Tensor:
        """Decode ``n_samples`` latent vectors drawn from the N(0, I) prior.

        The noise is drawn with the default (CPU) generator and then moved to
        ``device``; ``device`` should match where the model's parameters live.
        """
        z = torch.randn(n_samples, self.latent_dim).to(device)
        return self.decode(z)