File size: 947 Bytes
093675e
 
 
 
 
 
 
 
 
 
 
 
ce00128
093675e
 
 
 
 
 
af4c588
093675e
 
 
 
 
 
 
 
 
af4c588
 
093675e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from model import VariationalAutoEncoder
from torchvision import transforms
from PIL import Image
import gradio as gr


# Model dimensions; these must match the architecture MnistVAEmodel.pt was
# trained with, otherwise load_state_dict (strict by default) will fail.
INPUT_DIM = 784  # 28 * 28 flattened pixels; predict() reshapes to/from 28x28
H_DIM = 512  # presumably the hidden-layer width of VariationalAutoEncoder
Z_DIM = 256  # presumably the latent size; confirm against model.py

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host. The model itself is constructed on CPU and predict() feeds it
# CPU tensors, so CPU is the right place for the weights anyway.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
# Switch to inference mode (affects layers such as dropout/batch-norm if present).
model.eval()
def predict(img, num_samples=5):
    """Encode a digit image with the VAE and decode several random samples.

    Args:
        img: Image supplied by the Gradio input (configured as 28x28,
            grayscale "L" mode); ``None`` when the user submits an empty input.
        num_samples: Number of latent samples to draw and decode. Defaults to
            5, which preserves the original behaviour.

    Returns:
        A list of ``num_samples`` PIL images, one per decoded latent sample,
        or an empty list when no image was provided.
    """
    if img is None:
        # Gradio passes None for an empty input; show an empty gallery
        # instead of crashing inside ToTensor.
        return []

    img = transforms.ToTensor()(img)
    to_pil = transforms.ToPILImage()

    res = []
    # Pure inference: no gradients are needed, so skip building the
    # autograd graph for encode/decode.
    with torch.no_grad():
        mu, sigma = model.encode(img.view(1, INPUT_DIM))
        for _ in range(num_samples):
            # Reparameterisation: z = mu + sigma * eps with eps ~ N(0, I).
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            res.append(to_pil(out[0]))
    return res

title = "Variational-Autoencoder-on-MNIST "
# Shown to users under the title; describes what predict() actually does.
description = (
    "Upload a 28x28 grayscale digit. The VAE encodes it and decodes five "
    "random samples drawn from the resulting latent distribution."
)
examples = ["original_8.png"]

# Bind the Interface to a name so it can be imported (e.g. by tooling that
# looks for a `demo` object) without immediately starting a server.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(28, 28), image_mode="L"),
    outputs=gr.Gallery(),
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(inline=False)