"""Gradio demo that classifies an uploaded image with a trained network."""

import functools
import json
import os

import gradio as gr
import torch
import torchvision.transforms.functional as TF

from model import NeuralNetwork

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the network, load the ``model_best.pt`` weights and cache it.

    The checkpoint path is relative, so it is resolved against the current
    working directory. The result is cached so the checkpoint is read from
    disk only once instead of on every prediction request.

    Returns:
        The ``NeuralNetwork`` instance, placed on ``device`` and switched to
        evaluation mode.
    """
    model = NeuralNetwork()
    model.load_state_dict(
        torch.load("model_best.pt", map_location=torch.device(device))
    )
    # The input tensor is moved to ``device`` before the forward pass, so the
    # parameters must live on the same device or the call fails on CUDA hosts.
    model.to(device)
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_labels():
    """Read and cache the mapping stored in ``labels.json``.

    Returns:
        The decoded JSON object; it is indexed with the stringified class
        index by ``pokemon_classifier``.
    """
    # Explicit encoding so decoding does not depend on the platform default.
    with open('labels.json', encoding="utf-8") as f:
        return json.load(f)


def pokemon_classifier(inp):
    """Predict the label for a single image.

    Args:
        inp: The image handed over by the ``gr.Image(type="pil")`` component,
            or ``None`` when the user has not supplied an image.

    Returns:
        The entry of ``labels.json`` whose key is the stringified index of
        the highest-scoring model output.

    Raises:
        gr.Error: If ``inp`` is ``None``.
    """
    if inp is None:
        # Surface a readable message in the UI instead of a TypeError raised
        # from inside the tensor conversion.
        raise gr.Error("Please provide an image before running the classifier.")

    model = _load_model()
    labels = _load_labels()

    x = TF.to_tensor(inp)
    x = TF.resize(x, (64, 64), antialias=True)
    # Add a leading batch dimension of size one and move to the model's device.
    x = x.unsqueeze(0).to(device)

    with torch.no_grad():
        y_pred = model(x)

    # argmax over dim=1 leaves one index for the single-image batch.
    pokemon = torch.argmax(y_pred, dim=1).item()
    return labels[str(pokemon)]


# NOTE(review): the original indentation was lost; the components placed
# inside the Column below are a reconstruction -- confirm the intended nesting.
with gr.Blocks() as demo:
    gr.Markdown("# Gen 1 Pokemon classifier")
    with gr.Column(scale=4):
        inp = gr.Image(type="pil")
        out = gr.Textbox(label='Pokemon')
    gr.Examples(
        # Example images are located next to this script, independent of the
        # current working directory.
        examples=[
            os.path.join(os.path.dirname(__file__), "images/Aerodactyl.jpg"),
            os.path.join(os.path.dirname(__file__), "images/Bulbasaur.jpg"),
            os.path.join(os.path.dirname(__file__), "images/Charizard.jpg"),
        ],
        inputs=inp,
        outputs=out,
        fn=pokemon_classifier,
        cache_examples=False,
    )
    btn = gr.Button("Run")
    btn.click(fn=pokemon_classifier, inputs=inp, outputs=out)

demo.launch()