# GV05's picture
# Update app.py
# ce00128
# raw
# history blame
# 1.05 kB
import torch
from model import VariationalAutoEncoder
from torchvision import transforms
from PIL import Image
import gradio as gr
# Flattened 28x28 input image (predict() crops to 28x28 and views it as 784).
INPUT_DIM = 784
# Presumably the hidden-layer width and latent size of the VAE; both must
# match the shapes stored in the checkpoint or load_state_dict() raises.
H_DIM = 512
Z_DIM = 256

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load
# on a CPU-only host; the model built above lives on the CPU regardless.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
# Put the model in inference mode (affects layers such as dropout or
# batch-norm, if the model has any).
model.eval()
def predict(img, num_samples=10):
    """Encode a digit image with the VAE and decode several random samples.

    The image is converted to greyscale and scaled so its shorter side is
    28 px. It is then binarised, centre-cropped to 28x28, flattened to
    ``(1, INPUT_DIM)`` and encoded. ``num_samples`` latent vectors are drawn
    around the encoding and each one is decoded back to an image.

    Args:
        img: Input image, either a ``PIL.Image.Image`` or an array accepted
            by ``PIL.Image.fromarray``. Gradio's ``Image`` input passes a
            NumPy array unless ``type="pil"`` is set.
        num_samples: Number of images to generate. Defaults to 10.

    Returns:
        A list of ``num_samples`` PIL images, one per latent sample.
    """
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    # Scale before binarising, so a large drawing is shrunk as a whole rather
    # than having an arbitrary 28x28 patch cropped out of it. Resize(28) does
    # nothing to an image whose shorter side is already 28 px.
    img = img.convert('L')
    img = transforms.Resize(28)(img)
    img = img.convert('1')
    img = transforms.ToTensor()(img)
    # Trim the longer side so the tensor is exactly 1x28x28.
    img = transforms.CenterCrop(size=28)(img)

    samples = []
    # Inference only: no gradients are needed.
    with torch.no_grad():
        # NOTE(review): treats the second encoder output as a standard
        # deviation (z = mu + sigma * eps); confirm against model.py.
        mu, sigma = model.encode(img.view(1, INPUT_DIM))
        for _ in range(num_samples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            samples.append(transforms.ToPILImage()(out[0]))
    return samples
title = "Variational-Autoencoder-on-MNIST "
description = "TO DO"
examples = ["original_5.png", "original_8.png"]

# predict() calls PIL methods on its input, so ask Gradio for a PIL image
# explicitly instead of relying on the default NumPy array.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Gallery(),
    # Gradio expects one inner list per example row, holding one value per
    # input component.
    examples=[[path] for path in examples],
    title=title,
    description=description,
).launch(inline=False)