# Author: Jugal-sheth
# Last commit: "Update app.py" (42632eb)
import functools

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from model import model
# Load the trained weights into the model instance imported from the local
# `model` module. map_location="cpu" lets a checkpoint that was saved from a
# GPU run load on a CPU-only host; load_state_dict then copies the loaded
# tensors into the model's existing parameters.
model.load_state_dict(torch.load("mnist_model.pth", map_location="cpu"))
# Put the model in inference mode (matters for layers such as dropout or
# batch-norm, if the architecture has any -- not visible from this file).
model.eval()
@functools.lru_cache(maxsize=1)
def _get_transform():
    """Build (once) the pipeline that turns a drawing into a model input.

    Cached so the Compose pipeline is not rebuilt on every prediction; the
    interface runs with live=True, so inference can fire on every stroke.
    """
    return transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1];
        # presumably this matches the training normalization -- confirm.
        transforms.Normalize((0.5,), (0.5,)),
    ])


def preprocess_image(image):
    """Convert a drawing into a normalized (1, 1, 28, 28) float tensor.

    Args:
        image: The sketch, either as an array accepted by
            ``PIL.Image.fromarray`` or as an already-decoded
            ``PIL.Image.Image``.

    Returns:
        A 1x28x28 grayscale tensor with a leading batch dimension of 1,
        ready to be passed to the model.
    """
    # Only arrays need decoding; passing a PIL image through fromarray would
    # be a needless round-trip (and loses the palette of "P"-mode images).
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    return _get_transform()(image).unsqueeze(0)
def classify(image):
    """Return the model's predicted class for a sketch, as a string.

    Args:
        image: The sketchpad value supplied by Gradio, or None when the
            canvas is empty/cleared.

    Returns:
        The argmax class index (presumably the digit 0-9) converted to a
        string for the ``label`` output, or None when there is nothing to
        classify, so the label is cleared instead of an error being shown.
    """
    # With live=True the callback can also fire for an empty/cleared canvas;
    # preprocess_image would then call Image.fromarray(None) and raise.
    if image is None:
        return None
    tensor = preprocess_image(image)
    with torch.no_grad():
        logits = model(tensor)
    # Index of the highest score along the class dimension; .item() turns
    # the single-element result into a plain Python int.
    prediction = logits.argmax(dim=1).item()
    return str(prediction)
# Gradio UI: a sketchpad input feeding `classify`, with the prediction shown
# in a label. live=True re-runs classification as the user draws.
iface = gr.Interface(
    fn=classify,
    inputs="sketchpad",
    outputs="label",
    theme="huggingface",
    title="Digit Recognition",
    description=(
        "Draw a Digit 0-9 and the algorithm will detect it in real time! "
        "This is a tiny model, so kindly draw digits in the center of the "
        "drawing area."
    ),
    article="<p style='text-align: center'>Digit Recognition | Demo Model by Jugal</p>",
    live=True,
)

# Start the (blocking) server only when run as a script, so importing this
# module builds the interface without launching it.
if __name__ == "__main__":
    iface.launch(debug=True)