"""Gradio demo: sample decoded MNIST digits from a trained VAE.

Loads a ``VariationalAutoEncoder`` checkpoint and serves a web UI that takes a
28x28 grayscale image, encodes it to latent ``(mu, sigma)``, draws several
latent samples around it, decodes each one, and shows the results in a gallery.
"""

import torch
from model import VariationalAutoEncoder
from torchvision import transforms
from PIL import Image
import gradio as gr

INPUT_DIM = 784  # 28 * 28 flattened pixels
H_DIM = 512
Z_DIM = 256
NUM_SAMPLES = 5  # number of latent samples decoded per request
MODEL_PATH = "MnistVAEmodel.pt"

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only host (the typical deployment target for a demo like this).
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()

# Stateless transforms: build them once instead of on every request.
_to_tensor = transforms.ToTensor()
_to_pil = transforms.ToPILImage()


def predict(img):
    """Decode NUM_SAMPLES latent samples drawn around the encoding of *img*.

    Args:
        img: the image supplied by the Gradio input component. It is presumably
            a 28x28 grayscale array, given ``shape=(28, 28)`` and
            ``image_mode="L"`` below (TODO confirm). It is ``None`` when the
            user submits without an image.

    Returns:
        A list of PIL images, one per latent sample. The list is empty when
        *img* is ``None``.
    """
    if img is None:
        # Gradio passes None when the input is cleared; return an empty gallery
        # instead of crashing inside ToTensor.
        return []

    x = _to_tensor(img).reshape(1, INPUT_DIM)
    samples = []
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        mu, sigma = model.encode(x)
        for _ in range(NUM_SAMPLES):
            # Reparameterisation z = mu + sigma * eps, with eps ~ N(0, I).
            # NOTE(review): this assumes encode() returns a standard deviation,
            # not a log-variance. Confirm against model.py.
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            samples.append(_to_pil(out[0]))
    return samples


title = "Variational-Autoencoder-on-MNIST "
description = "TO DO"  # TODO: write a user-facing description of the demo.
examples = ["original_8.png"]

# NOTE(review): `gr.inputs.Image(shape=...)` is the legacy Gradio API. It was
# removed in Gradio 4, so pin gradio<4 or migrate to `gr.Image`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(28, 28), image_mode="L"),
    outputs=gr.Gallery(),
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(inline=False)