# Hugging Face Spaces page header (status: Sleeping)
import functools
import json
import os

import gradio as gr
import torch
import torchvision.transforms.functional as TF

from model import NeuralNetwork
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def _asset_path(name):
    """Return the path of *name* resolved next to this script, not the CWD."""
    return os.path.join(os.path.dirname(__file__), name)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the network, load ``model_best.pt`` once, and keep it cached."""
    model = NeuralNetwork()
    state = torch.load(_asset_path("model_best.pt"), map_location=torch.device(device))
    model.load_state_dict(state)
    # load_state_dict copies values into the module's existing parameters,
    # which were created on the CPU; move the module so it sits on the same
    # device as the input tensor built in pokemon_classifier.
    model.to(device)
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_labels():
    """Read ``labels.json`` once; it is keyed by the stringified class index."""
    with open(_asset_path("labels.json"), encoding="utf-8") as f:
        return json.load(f)


def pokemon_classifier(inp):
    """Predict which Gen 1 Pokemon appears in an image.

    The model and label table are loaded on first use and reused afterwards,
    so individual requests do not re-read them from disk.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when "Run" is clicked before an image has been uploaded.

    Returns:
        The ``labels.json`` entry for the highest-scoring class index.

    Raises:
        gr.Error: if no image was provided.
    """
    if inp is None:
        # Without this guard TF.to_tensor(None) fails with an opaque TypeError.
        raise gr.Error("Please upload an image first.")
    model = _load_model()
    labels = _load_labels()
    x = TF.to_tensor(inp)
    # Resize to the fixed 64x64 input size the network is fed
    # (presumably its training resolution — TODO confirm).
    x = TF.resize(x, (64, 64), antialias=True)
    # Add a batch dimension of 1 and move to the model's device.
    x = x.unsqueeze(0).to(device)
    with torch.no_grad():
        y_pred = model(x)
    pokemon = torch.argmax(y_pred, dim=1).item()
    return labels[str(pokemon)]
# Bundled example images, resolved next to this script so they are found
# regardless of the working directory the app is started from.
EXAMPLE_IMAGES = [
    os.path.join(os.path.dirname(__file__), "images", name)
    for name in ("Aerodactyl.jpg", "Bulbasaur.jpg", "Charizard.jpg")
]

with gr.Blocks() as demo:
    gr.Markdown("# Gen 1 Pokemon classifier")
    with gr.Column(scale=4):
        inp = gr.Image(type="pil")
        out = gr.Textbox(label='Pokemon')
        # cache_examples=False: Gradio does not run the model over every
        # example at startup.
        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=inp,
            outputs=out,
            fn=pokemon_classifier,
            cache_examples=False,
        )
        btn = gr.Button("Run")
    btn.click(fn=pokemon_classifier, inputs=inp, outputs=out)

# Launch only when run as a script, so importing this module (e.g. from
# tests or another app) does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()