# app.py: page header captured with this file (author GV05, commit af4c588
# "Update app.py", 947 bytes).
import torch
from model import VariationalAutoEncoder
from torchvision import transforms
from PIL import Image
import gradio as gr
# Architecture sizes passed to VariationalAutoEncoder. They must match the
# checkpoint below, or load_state_dict (strict by default) will reject it.
INPUT_DIM = 784  # 28 * 28 flattened grayscale pixels; predict() reshapes to this
H_DIM = 512  # presumably the hidden-layer width (see model.py)
Z_DIM = 256  # presumably the latent size (see model.py)

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# The model is never moved to a GPU, so load tensors onto the CPU explicitly.
# Without this, a checkpoint saved from a CUDA process fails on a CPU-only host.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
# Inference-only app: switch off training-mode behaviour.
model.eval()
def predict(img, n_samples=5):
    """Encode an input digit and decode several random samples of its latent code.

    The image is flattened to ``INPUT_DIM`` values and encoded by the global
    ``model`` into ``mu`` and ``sigma``. Then ``n_samples`` latent vectors
    ``z = mu + sigma * eps`` (with ``eps`` drawn from a standard normal) are
    decoded back into 28x28 images.

    Args:
        img: Image from the Gradio input component (configured in this file
            with ``shape=(28, 28)`` and ``image_mode="L"``). ``None`` is
            handled defensively by returning no images.
        n_samples: Number of latent samples to draw and decode. Defaults to 5.

    Returns:
        A list of ``n_samples`` PIL images for the output gallery, or an
        empty list when ``img`` is ``None``.
    """
    if img is None:
        return []

    to_pil = transforms.ToPILImage()
    x = transforms.ToTensor()(img)

    # Pure inference: don't build an autograd graph for every request.
    with torch.no_grad():
        mu, sigma = model.encode(x.view(1, INPUT_DIM))
        samples = []
        for _ in range(n_samples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            # ToPILImage scales float tensors by 255 and casts to uint8.
            # Clamping keeps out-of-range decoder values from wrapping around.
            # NOTE(review): this is a no-op if decode() already ends in a
            # sigmoid; confirm in model.py.
            samples.append(to_pil(out[0].clamp(0.0, 1.0)))
    return samples
# Text shown on the Gradio page.
title = "Variational-Autoencoder-on-MNIST"
description = (
    "Upload a handwritten digit. The VAE encodes it and shows five images "
    "decoded from random samples of its latent distribution."
)
# One inner list per example, with one entry per input component.
# There is a single input here.
examples = [["original_8.png"]]

# 28x28 grayscale input matches the flattened INPUT_DIM = 784 used by predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(28, 28), image_mode="L"),
    outputs=gr.Gallery(),
    examples=examples,
    title=title,
    description=description,
)
demo.launch(inline=False)