GV05's picture
built space
093675e
raw
history blame
1.71 kB
import torch
from torch import nn
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder.

    Data flow: input -> hidden layer -> (mu, sigma) -> reparameterized
    latent sample z -> hidden layer -> reconstruction squashed into (0, 1).

    Args:
        inpud_dim: size of the last dimension of the input (e.g. 784 for a
            flattened 28x28 image). The misspelled name is kept because
            callers pass it by keyword.
        h_dim: width of the hidden layer used by both encoder and decoder.
        z_dim: dimensionality of the latent space (size of mu and sigma).
    """

    def __init__(self, inpud_dim, h_dim=200, z_dim=20):
        super().__init__()
        # The attribute names below are state_dict keys, and the creation
        # order fixes both parameter order and seeded initialization, so
        # neither should change (the "z_2hi" spelling included).
        # Encoder: input -> hidden -> (mu, sigma).
        self.img_2hid = nn.Linear(inpud_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # Decoder: latent -> hidden -> reconstructed input.
        self.z_2hi = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, inpud_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Approximate posterior q_phi(z|x): map x to the pair (mu, sigma)."""
        hidden = self.img_2hid(x)
        hidden = self.relu(hidden)
        # NOTE(review): hid_2sigma is a plain Linear with no activation, so
        # sigma is unconstrained and may be negative.
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Likelihood p_theta(x|z): map a latent code z back to input space."""
        hidden = self.relu(self.z_2hi(z))
        logits = self.hid_2img(hidden)
        # Sigmoid keeps every output value between zero and one, as image
        # pixel intensities should be.
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode x, sample z via the reparameterization trick, then decode.

        Returns:
            A tuple (x_reconstructed, mu, sigma). The loss has two parts:
            mu and sigma are pushed towards a standard normal distribution,
            and x_reconstructed should match x.
        """
        mu, sigma = self.encode(x)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1)
        # drawn elementwise in sigma's shape, so gradients reach mu and sigma.
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise
        return self.decode(latent), mu, sigma
if __name__ == "__main__":
    # Smoke test: run a batch of 4 random flattened 28x28 inputs through the
    # model, print the shapes of (reconstruction, mu, sigma), then the mean
    # of mu. The input is created before the model, as originally, so the
    # random-number stream is consumed in the same order.
    batch = torch.randn(4, 28 * 28)
    model = VariationalAutoEncoder(inpud_dim=784)
    outputs = model(batch)
    for tensor in outputs:
        print(tensor.shape)
    print(torch.mean(outputs[1]))