"""Gradio demo app for a CIFAR-10 image classifier.

Loads trained weights from ``model.pth`` into the ``model`` object exported by
the local ``model`` module, then serves a live Gradio interface that reports
the top predicted class and its softmax confidence.
"""

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms

from model import model, classes

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))

# The checkpoint is a dict; the weights are stored under 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])

# Switch to evaluation mode. This affects layers such as dropout and batch norm,
# if the architecture contains any.
model.eval()

# Preprocessing. ToTensor maps pixels to [0, 1], and Normalize(0.5, 0.5) then
# maps them to [-1, 1]. Presumably this matches the training pipeline (32x32
# CIFAR-10 inputs); TODO confirm and keep the two in sync.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def classify(img):
    """Classify a single image and report the top class with its confidence.

    Args:
        img: Image from the Gradio ``"image"`` input. It is assumed to be an
            HxWxC uint8 array accepted by ``ToPILImage`` (TODO confirm for the
            installed Gradio version). It may be ``None`` when the input is
            empty.

    Returns:
        A string like ``"<class name> (87.12%)"``, or ``""`` if no image was
        provided.
    """
    # With live=True the interface may fire while the input is empty or was
    # just cleared; transform(None) would raise, so return an empty result.
    if img is None:
        return ""

    # Preprocess and add the batch dimension the model expects.
    x = transform(img).unsqueeze(0)

    # Pure inference: skip building an autograd graph on every request.
    with torch.no_grad():
        logits = model(x)

    # Softmax over the class dimension, as percentages, for the only sample.
    perc = F.softmax(logits, dim=1)[0] * 100
    # Use a plain int so that indexing `classes` does not depend on tensor
    # __index__ support.
    idx = int(torch.argmax(logits[0]))
    top_perc = perc[idx].item()
    class_name = classes[idx]
    return f"{class_name} ({top_perc:.2f}%)"


iface = gr.Interface(
    classify,
    "image",
    "text",
    theme="huggingface",
    # The previous title, "Digit Recognition", was wrong: this is a CIFAR-10
    # object classifier.
    title="CIFAR10 Image Classification",
    description=(
        "Upload Image of any Airplane Automobile Bird Cat Deer Dog Frog Horse "
        "Ship Truck and the algorithm will detect it in real time! This is CNN "
        "trained on CIFAR10 Dataset"
    ),
    # Previously this literal spanned physical lines, which is a SyntaxError.
    article="CIFAR10 Classification | Demo Model by Jugal",
    live=True,
)
iface.launch(debug=True)