Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder.

    Pipeline: input -> hidden layer -> (mu, sigma) -> reparameterization
    trick -> decoder -> reconstruction squashed by a sigmoid into [0, 1].

    Args:
        inpud_dim: Size of the last dimension of the input (and of the
            reconstruction). The misspelled name is kept so existing callers
            keep working; new code should prefer the keyword ``input_dim``.
        h_dim: Width of the hidden layer used by both encoder and decoder.
        z_dim: Size of the latent vector (last dimension of mu and sigma).
        input_dim: Keyword-only, correctly spelled alias for ``inpud_dim``.

    Raises:
        TypeError: If neither ``inpud_dim`` nor ``input_dim`` is given, or if
            both are given with different values.
    """

    def __init__(self, inpud_dim=None, h_dim=200, z_dim=20, *, input_dim=None):
        if inpud_dim is None and input_dim is None:
            raise TypeError(
                "VariationalAutoEncoder() missing required argument: "
                "'input_dim' (legacy spelling: 'inpud_dim')"
            )
        if (
            inpud_dim is not None
            and input_dim is not None
            and inpud_dim != input_dim
        ):
            raise TypeError(
                f"conflicting sizes: inpud_dim={inpud_dim!r}, "
                f"input_dim={input_dim!r}; pass only one of them"
            )
        in_dim = input_dim if input_dim is not None else inpud_dim

        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialization is reproducible, and attribute names (including the
        # "z_2hi" spelling) are kept because they become state_dict keys.

        # Encoder.
        self.img_2hid = nn.Linear(in_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        # NOTE(review): plain linear output, so sigma can be negative or zero.
        # Sampling in forward() is unaffected in distribution (epsilon is
        # symmetric), but a loss using log(sigma**2) can hit log(0).
        # Predicting log-variance instead would fix that, but would change
        # what existing hid_2sigma weights mean.
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # Decoder.
        self.z_2hi = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, in_dim)
        self.relu = nn.ReLU()

    def encode(self, x: torch.Tensor):
        """Encoder q_phi(z | x): map inputs to the latent parameters.

        Args:
            x: Input tensor whose last dimension is the input size.

        Returns:
            Tuple ``(mu, sigma)``, each with last dimension ``z_dim``.
        """
        h = self.relu(self.img_2hid(x))
        mu = self.hid_2mu(h)
        sigma = self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decoder p_theta(x | z): map latent vectors back to input space.

        Args:
            z: Latent tensor whose last dimension is ``z_dim``.

        Returns:
            Reconstruction with last dimension equal to the input size; the
            final sigmoid keeps image values between zero and one.
        """
        h = self.relu(self.z_2hi(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample z via the reparameterization trick, decode z.

        Args:
            x: Input tensor whose last dimension is the input size.

        Returns:
            ``(x_reconstructed, mu, sigma)``. The loss has two parts: a term
            pushing (mu, sigma) toward a standard normal distribution, and a
            term making x_reconstructed match x.
        """
        mu, sigma = self.encode(x)
        # Reparameterization trick: epsilon ~ N(0, 1) with sigma's shape, so
        # z is a differentiable function of mu and sigma.
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma
if __name__ == "__main__":
    # Smoke test: run a random batch of four flattened 28x28 inputs through
    # the model and report what comes back.
    batch = torch.randn(4, 28 * 28)
    model = VariationalAutoEncoder(inpud_dim=batch.shape[1])
    reconstruction, latent_mean, latent_std = model(batch)
    for result in (reconstruction, latent_mean, latent_std):
        print(result.shape)
    print(torch.mean(latent_mean))