Spaces:
Runtime error
Runtime error
import torch | |
from model import VariationalAutoEncoder | |
from torchvision import transforms | |
from PIL import Image | |
import gradio as gr | |
# Network sizes; these must match the architecture the checkpoint in
# "MnistVAEmodel.pth" was trained with, or load_state_dict will fail.
INPUT_DIM = 784  # flattened 28 x 28 image (predict() reshapes to 1 x 784)
H_DIM = 512
Z_DIM = 256

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location="cpu": a checkpoint saved from a CUDA session otherwise tries
# to restore its tensors onto a GPU, which crashes on CPU-only hosts. The
# model itself is built on CPU, so the loaded weights end up in the same place
# either way.
model.load_state_dict(torch.load("MnistVAEmodel.pth", map_location="cpu"))
# Put layers such as dropout/batch-norm (if the model has any) in inference mode.
model.eval()
def predict(img, num_samples=10):
    """Encode a digit image with the VAE and decode random latent samples.

    The image is turned into a 28x28 grayscale tensor, encoded into the
    latent parameters ``(mu, sigma)``, and ``num_samples`` latent vectors
    ``z = mu + sigma * epsilon`` with ``epsilon ~ N(0, I)`` are decoded back
    into images.

    Args:
        img: The input image, either a ``PIL.Image.Image`` or an array that
            ``PIL.Image.fromarray`` accepts (Gradio's default numpy input).
            ``None`` (nothing submitted) yields an empty result.
        num_samples: How many latent samples to decode. Defaults to 10, the
            number the original implementation always produced.

    Returns:
        A list of ``num_samples`` 28x28 PIL images.
    """
    if img is None:
        return []
    if not isinstance(img, Image.Image):
        # Gradio's image input defaults to a numpy array, which has no
        # .convert(); normalise to PIL first.
        img = Image.fromarray(img)

    # Grayscale + resize instead of 1-bit '1' mode + CenterCrop: '1' mode
    # dithers away intermediate intensities, and CenterCrop cuts out the middle
    # of larger uploads instead of scaling them onto the 28x28 grid.
    # NOTE(review): assumes the model was trained on grayscale MNIST tensors in
    # [0, 1] (the usual ToTensor pipeline) — confirm against the training script.
    img = img.convert("L").resize((28, 28))
    x = transforms.ToTensor()(img).view(1, INPUT_DIM)

    to_pil = transforms.ToPILImage()
    samples = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        mu, sigma = model.encode(x)
        for _ in range(num_samples):
            # Reparameterisation: sample z from the encoder's posterior.
            z = mu + sigma * torch.randn_like(sigma)
            out = model.decode(z).view(-1, 1, 28, 28)
            # Clamp so ToPILImage's *255 -> uint8 cast cannot wrap around if
            # the decoder output strays outside [0, 1].
            samples.append(to_pil(out[0].clamp(0.0, 1.0)))
    return samples
title = "Variational-Autoencoder-on-MNIST "
description = "TO DO"  # TODO: placeholder text shown in the UI; write a real description.
examples = ["original_5.png", "original_8.png"]

# gr.Image / gr.Gallery replace the deprecated gr.inputs.Image /
# gr.outputs.Gallery namespaces (removed in Gradio 4). type="pil" makes Gradio
# hand predict() a PIL image instead of its default numpy array.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(),
    examples=examples,
    title=title,
    description=description,
).launch(inline=False)