"""Gradio demo: classify a hand-drawn digit with a pretrained MNIST model."""

import cv2
import gradio as gr
import torch
from torchvision import transforms

from model import model, classes  # NOTE(review): `classes` is currently unused here.

# Load the trained weights onto the CPU so the demo runs without a GPU.
checkpoint = torch.load('mnist_model.pth', map_location=torch.device('cpu'))
# The checkpoint is a dict; the weights are stored under 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])
# Put the model in evaluation mode (affects layers such as dropout or
# batch-norm, if the model has any).
model.eval()


def _to_grayscale(image):
    """Return a single-channel (H, W) array for an incoming sketchpad image.

    Different Gradio versions deliver the canvas differently, so this accepts:
    a 2-D grayscale array, an (H, W, 1) array, a 3-channel RGB array, a
    4-channel RGBA array, or a dict holding the drawn result under
    "composite".

    Raises:
        ValueError: if a dict is passed without a "composite" image.
    """
    if isinstance(image, dict):
        image = image.get("composite")
        if image is None:
            raise ValueError("sketchpad value has no 'composite' image")

    if image.ndim == 2:
        return image

    channels = image.shape[2]
    if channels == 1:
        return image[:, :, 0]
    if channels == 4:
        # Composite over white so transparent (undrawn) pixels read as light
        # background rather than black; dark strokes stay dark.
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY).astype("float32")
        alpha = image[:, :, 3].astype("float32") / 255.0
        return (gray * alpha + 255.0 * (1.0 - alpha)).astype("uint8")
    # Gradio passes numpy images in RGB order, not OpenCV's BGR.
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def preprocess_image(image):
    """Turn a sketchpad image into a normalized (1, 1, 28, 28) model input.

    Args:
        image: the raw value Gradio passes for the sketchpad input
            (see ``_to_grayscale`` for accepted forms).

    Returns:
        A float tensor of shape (1, 1, 28, 28).
    """
    gray = _to_grayscale(image)
    # THRESH_BINARY_INV maps dark strokes on a light background to light
    # strokes on a dark background (MNIST-style). Assumes the canvas delivers
    # dark-on-light; TODO confirm for the Gradio version in use.
    threshold = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                      cv2.THRESH_BINARY_INV, 11, 2)
    # INTER_AREA is the usual interpolation choice when shrinking an image.
    resized = cv2.resize(threshold, (28, 28), interpolation=cv2.INTER_AREA)
    # ToTensor: (28, 28) uint8 -> (1, 28, 28) float in [0, 1]; unsqueeze adds
    # the batch dimension.
    tensor = transforms.ToTensor()(resized).unsqueeze(0)
    # Maps [0, 1] to [-1, 1]; presumably matches the training transform -- verify.
    tensor = transforms.Normalize((0.5,), (0.5,))(tensor)
    return tensor


def classify(image):
    """Predict the digit drawn on the sketchpad.

    Args:
        image: the raw sketchpad value from Gradio.

    Returns:
        The predicted class index as a string, or None when nothing has been
        drawn (an empty or cleared canvas, which Gradio may pass as None in
        live mode). Returning None leaves the label output empty.
    """
    if image is None or (isinstance(image, dict) and image.get("composite") is None):
        return None
    tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(tensor)
    # Batch size is 1, so argmax over the class dimension yields one value.
    prediction = output.argmax(dim=1).item()
    return str(prediction)


iface = gr.Interface(
    fn=classify,
    inputs="sketchpad",
    outputs='label',
    theme="huggingface",
    title="Digit Recognition",
    description="Draw a Digit 0-9 and the algorithm will detect it in real time!",
    # A single-line literal: a plain "..." string cannot contain raw newlines.
    article="Digit Recognition | Demo Model by Jugal",
    live=True,
)
iface.launch(debug=True)