"""Gradio demo for a variational autoencoder trained on MNIST.

An uploaded digit image is encoded into the VAE latent space; several random
latent samples drawn around that encoding are decoded and shown as a gallery.
"""
import torch
from PIL import Image
from torchvision import transforms

import gradio as gr

from model import VariationalAutoEncoder

INPUT_DIM = 784  # 28 * 28 flattened pixels
H_DIM = 512
Z_DIM = 256
IMG_SIZE = 28
NUM_SAMPLES = 10

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only host.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
model.eval()

# Bring arbitrary uploads to the model's input format: a single-channel
# 28x28 image as a float tensor in [0, 1]. Resize (not CenterCrop) so that a
# larger upload keeps the whole digit instead of a 28x28 patch of its center.
_preprocess = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])
_to_pil = transforms.ToPILImage()


def predict(img, num_samples=NUM_SAMPLES):
    """Encode *img* and decode *num_samples* random latent samples around it.

    Args:
        img: The uploaded image (a PIL image; a numpy array is also accepted),
            or None when nothing was uploaded.
        num_samples: How many latent samples to draw and decode.

    Returns:
        A list of decoded 28x28 grayscale PIL images (empty if *img* is None).
    """
    if img is None:
        return []
    if not isinstance(img, Image.Image):
        # Defensive: some gradio versions deliver uploads as numpy arrays.
        img = Image.fromarray(img)

    # Grayscale ('L') rather than bilevel ('1'): mode '1' dithers the digit
    # into noisy 0/1 pixels, unlike the grayscale MNIST images.
    x = _preprocess(img.convert("L")).view(1, INPUT_DIM)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): sigma is used directly as a standard deviation, as in
        # the original reparameterization — confirm against model.encode.
        mu, sigma = model.encode(x)
        epsilon = torch.randn(
            num_samples, sigma.shape[-1], dtype=sigma.dtype, device=sigma.device
        )
        z = mu + sigma * epsilon  # broadcast: one latent sample per row
        out = model.decode(z).view(-1, 1, IMG_SIZE, IMG_SIZE)
        # ToPILImage scales by 255 and casts to uint8; clamp so values outside
        # [0, 1] cannot wrap around.
        out = out.clamp(0.0, 1.0)

    return [_to_pil(sample) for sample in out]


title = "Variational-Autoencoder-on-MNIST"
description = (
    "Upload an image of a handwritten digit. The VAE encodes it and decodes "
    f"{NUM_SAMPLES} random samples drawn around its latent encoding."
)
examples = ["original_5.png", "original_8.png"]

demo = gr.Interface(
    fn=predict,
    # type="pil" so predict receives a PIL image (the default is numpy).
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Gallery(),
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(inline=False)