# 1. Imports and class names setup
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch

from model import create_effnetb2_model

# Setup class names
class_names = ["pizza", "steak", "sushi"]

# 2. Model and transforms preparation
# Create the EffNetB2 model together with the transforms it expects.
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))

# Load the trained weights. map_location remaps every tensor to the CPU, so the
# app runs on machines without a GPU even if the checkpoint was saved on CUDA.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
effnetb2.load_state_dict(
    torch.load("effnetb2.pth", map_location=torch.device("cpu"))
)


# 3. Predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Transform and classify ``img``; return class probabilities and timing.

    Args:
        img: The input image handed over by the Gradio ``gr.Image(type="pil")``
            component, passed straight to ``effnetb2_transforms``.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``:
        ``pred_labels_and_probs`` maps every class name to its softmax
        probability, and ``pred_time`` is the elapsed wall-clock time of the
        call in seconds, rounded to 5 decimal places.
    """
    start_time = timer()

    # Transform the target image and add a batch dimension (batch of one).
    img = effnetb2_transforms(img).unsqueeze(0)

    # Evaluation mode for inference-time layer behaviour; inference_mode
    # disables autograd bookkeeping for the forward pass.
    effnetb2.eval()
    with torch.inference_mode():
        # Convert the logits into probabilities across the class dimension.
        pred_probs = torch.softmax(effnetb2(img), dim=1)

    # Pair each class name with the probability for the single batch item.
    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, pred_probs[0])
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


# 4. Gradio app
# Create title and description strings
title = "FoodVision Mini 🍕🥩🍣"
description = "An EfficientNetB2 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
# NOTE(review): the article text looks unfinished ("Created " with nothing
# after it) — fill in the intended credit line.
article = "Created "

# Create examples list from "examples/" directory
# example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=gr.Image(type="pil"),  # a single image, delivered as a PIL image
    # predict returns two values, so there are two output components.
    # num_top_classes follows class_names so every class is always shown.
    outputs=[
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    # examples=example_list,  # enable together with example_list above
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch()