File size: 1,050 Bytes
093675e
 
 
 
 
 
 
 
 
 
 
 
ce00128
093675e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from model import VariationalAutoEncoder
from torchvision import transforms
from PIL import Image
import gradio as gr


# Architecture sizes passed to VariationalAutoEncoder; they must match the
# shapes stored in the checkpoint, or load_state_dict (strict by default) raises.
INPUT_DIM = 784  # flattened 28x28 image, matching the 28x28 reshape in predict()
H_DIM = 512  # presumably the hidden-layer width -- confirm against model.py
Z_DIM = 256  # presumably the latent dimension -- confirm against model.py

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location="cpu" lets a checkpoint that was saved from CUDA tensors load on
# a machine without a GPU; load_state_dict then copies the values into the
# model's existing parameters.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
# Inference mode for layers such as dropout/batch-norm, if the model has any.
model.eval()


def predict(img, n_samples=10):
    """Encode a digit image with the VAE and decode random samples of its latent.

    The input is converted to a 1-bit single-channel image and turned into a
    tensor. It is then center-cropped (or zero-padded) to 28x28 and encoded
    into ``(mu, sigma)``. ``n_samples`` latent vectors are drawn as
    ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``, and each is decoded back
    to a 28x28 image.

    Args:
        img: The input image, either a ``PIL.Image.Image`` or an array
            accepted by ``PIL.Image.fromarray``. Gradio's ``gr.inputs.Image``
            passes a NumPy array unless it is created with ``type="pil"``.
        n_samples: How many latent samples to decode. Defaults to 10.

    Returns:
        A list of ``n_samples`` PIL images, one per decoded sample.
    """
    # Accept NumPy arrays as well as PIL images; .convert below is PIL-only.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    img = img.convert('1')
    tensor = transforms.ToTensor()(img)
    # Note: CenterCrop crops larger images and zero-pads smaller ones; it does
    # not resize, so non-28x28 uploads keep only their central 28x28 region.
    tensor = transforms.CenterCrop(size=28)(tensor)

    to_pil = transforms.ToPILImage()
    samples = []
    # Pure inference: skip building autograd graphs.
    with torch.no_grad():
        # reshape rather than view: cropping a larger image yields a
        # non-contiguous tensor, on which .view(1, INPUT_DIM) raises.
        mu, sigma = model.encode(tensor.reshape(1, INPUT_DIM))
        # Assumes encode returns a standard deviation (not a log-variance),
        # as the reparameterization below treats it -- confirm in model.py.
        for _ in range(n_samples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            samples.append(to_pil(out[0]))
    return samples

title = "Variational-Autoencoder-on-MNIST "
description = (
    "Upload an image of a handwritten digit. The VAE encodes it into a latent "
    "Gaussian (mu, sigma) and shows 10 images decoded from random samples of "
    "that latent."
)
# Example inputs offered in the UI (relative file paths).
examples = ["original_5.png", "original_8.png"]
# type="pil": predict() uses PIL methods (img.convert), while gr.inputs.Image
# hands the function a NumPy array by default.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Gallery(),
    examples=examples,
    title=title,
    description=description,
)
demo.launch(inline=False)