Spaces:
Runtime error
Runtime error
built space
Browse files- MnistVAEmodel.pt +3 -0
- app.py +35 -0
- model.py +50 -0
- original_5.png +0 -0
- original_8.png +0 -0
- requirements.txt +3 -0
MnistVAEmodel.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d6ab1a824858a37b3dbeffce09cd2de481906e689b4817e505cb2550e992d3d
|
3 |
+
size 4796991
|
app.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from model import VariationalAutoEncoder
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
|
8 |
+
# Architecture hyper-parameters. INPUT_DIM is the flattened 28x28 input that
# predict() feeds to the encoder; all three are passed to the model constructor
# and therefore have to agree with the layer sizes stored in the checkpoint.
INPUT_DIM = 784
H_DIM = 512
Z_DIM = 256

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# The checkpoint shipped next to this script is "MnistVAEmodel.pt" (not
# ".pth"); loading the wrong name raises FileNotFoundError at startup.
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only host.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
# Inference only: switch the module to evaluation mode.
model.eval()
|
15 |
+
def predict(img):
    """Encode an input digit image and return ten sampled reconstructions.

    Args:
        img: The input image, either a PIL image or an array accepted by
            ``PIL.Image.fromarray`` (Gradio's image input passes a numpy
            array unless it is configured with ``type="pil"``).

    Returns:
        A list of ten PIL images, each decoded from a fresh latent sample
        ``z = mu + sigma * epsilon`` drawn around the encoding of ``img``.
    """
    # Accept numpy input as well: the original called img.convert() directly,
    # which fails with AttributeError on an ndarray.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    img = img.convert('1')
    img = transforms.ToTensor()(img)
    img = transforms.CenterCrop(size=28)(img)

    # Build the converter once instead of once per sample.
    to_pil = transforms.ToPILImage()

    res = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # reshape() rather than view(): the centre crop of an image wider than
        # 28 px is a non-contiguous slice, on which view() raises.
        mu, sigma = model.encode(img.reshape(1, INPUT_DIM))

        for _ in range(10):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            res.append(to_pil(out[0]))
    return res
|
30 |
+
|
31 |
+
title = "Variational-Autoencoder-on-MNIST "
description = "TO DO"
examples = ["original_5.png", "original_8.png"]

# Use the top-level Gradio components: the legacy gr.inputs / gr.outputs
# namespaces are deprecated and absent from current Gradio releases, and
# requirements.txt installs an unpinned gradio. type="pil" hands predict() a
# PIL image, which is what its img.convert() call expects.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(),
    examples=examples,
    title=title,
    description=description,
).launch(inline=False)
|
model.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    Data flow: input image -> hidden layer -> (mu, sigma) ->
    reparameterization trick -> decoder -> output image.
    """

    def __init__(self, inpud_dim, h_dim=200, z_dim=20):
        """Create the encoder and decoder layers.

        Args:
            inpud_dim: Size of a flattened input (and reconstructed) image.
            h_dim: Width of the hidden layer used by encoder and decoder.
            z_dim: Size of the latent vector.
        """
        super().__init__()

        # Encoder: image -> hidden, then two parallel heads for mu and sigma.
        self.img_2hid = nn.Linear(inpud_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder: latent -> hidden -> image.
        self.z_2hi = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, inpud_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        """Map inputs to the parameters of q_phi(z|x).

        Returns:
            A ``(mu, sigma)`` tuple; both are raw linear-head outputs.
        """
        hidden = self.relu(self.img_2hid(x))
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent vectors to images, i.e. p_theta(x|z)."""
        hidden = self.z_2hi(z)
        hidden = self.relu(hidden)
        logits = self.hid_2img(hidden)
        # The sigmoid keeps every output value between zero and one.
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(x_reconstructed, mu, sigma)``. mu and sigma are returned next
            to the reconstruction so a caller can build both loss terms: one
            pushing (mu, sigma) towards a standard normal, and one pulling
            the reconstruction towards ``x``.
        """
        mu, sigma = self.encode(x)
        # Reparameterization trick: standard-normal noise with sigma's shape
        # is scaled by sigma and shifted by mu.
        noise = torch.randn_like(sigma)
        return self.decode(mu + sigma * noise), mu, sigma
|
41 |
+
|
42 |
+
if __name__ == "__main__":
    # Smoke test: push a random batch of four flattened 28x28 inputs through
    # a default-sized model and print the resulting shapes and the mean of mu.
    batch = torch.randn(4, 28 * 28)
    net = VariationalAutoEncoder(inpud_dim=784)
    reconstruction, mean, spread = net(batch)
    for tensor in (reconstruction, mean, spread):
        print(tensor.shape)
    print(torch.mean(mean))
|
original_5.png
ADDED
original_8.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
torch
|
3 |
+
torchvision
|