File size: 1,709 Bytes
093675e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torch import nn


class VariationalAutoEncoder(nn.Module):
    """Fully connected variational auto-encoder.

    Data flow: input -> hidden layer -> (mu, sigma) -> reparameterization
    trick -> decoder -> reconstruction squashed into [0, 1] by a sigmoid.

    Args:
        inpud_dim: size of the flattened input, and of the reconstruction.
            The misspelled name is kept because callers pass it by keyword.
        h_dim: width of the hidden layer; the encoder and decoder each have
            one hidden layer of this size.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, inpud_dim, h_dim=200, z_dim=20):
        super().__init__()
        # NOTE: attribute names (including the ``z_2hi`` spelling) and the
        # order in which the layers are created are load-bearing. The names
        # become state_dict keys, and the order fixes how each layer draws
        # from the RNG during initialization.

        # Encoder half: x -> hidden -> (mu, sigma)
        self.img_2hid = nn.Linear(inpud_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder half: z -> hidden -> reconstruction
        self.z_2hi = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, inpud_dim)

        # Single parameter-free activation shared by both halves.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return the parameters of q_phi(z | x) as the pair ``(mu, sigma)``.

        Both values come from linear heads applied to the same ReLU hidden
        layer. No activation follows ``hid_2sigma``, so ``sigma`` is not
        constrained to be positive.
        """
        hidden = self.relu(self.img_2hid(x))
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map a latent code ``z`` back to input space, p_theta(x | z).

        A final sigmoid keeps every output value between zero and one, the
        range expected for image intensities.
        """
        logits = self.hid_2img(self.relu(self.z_2hi(z)))
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            ``(x_reconstructed, mu, sigma)``. The loss has two parts: keep
            ``x_reconstructed`` close to ``x``, and push ``(mu, sigma)``
            toward a standard normal. The caller computes both.
        """
        mu, sigma = self.encode(x)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1).
        # All randomness sits in eps, so gradients flow through mu and sigma.
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise
        return self.decode(latent), mu, sigma

if __name__ == "__main__":
    # Smoke test: run a random batch of four flattened 28x28 inputs through
    # the model, then print each returned tensor's shape and the mean of mu.
    batch = torch.randn(4, 28 * 28)
    model = VariationalAutoEncoder(inpud_dim=784)
    outputs = model(batch)  # (x_reconstructed, mu, sigma)
    for tensor in outputs:
        print(tensor.shape)
    print(torch.mean(outputs[1]))