"""Gradio demo for FoodVision Mini: classify an image as pizza, steak or sushi.

Builds a ViT-B/16 classifier via ``create_vit_b_16_model``, loads trained
weights from ``vit_b_16_20_percent_data.pth`` and serves ``predict`` through
a ``gr.Interface``.
"""
# import the essentials
import time
from pathlib import Path

import gradio as gr
import torch
import torchvision  # noqa: F401  (not used directly in this script)

from demos.foodvision_mini.model import create_vit_b_16_model

class_names = ['pizza', 'steak', 'sushi']

# is_available must be *called*: the bare function object is always truthy,
# which would select 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create the ViT-B/16 model and load it with the state_dict of our trained model.
# map_location lets weights saved from a GPU load on a CPU-only host.
# NOTE(review): the checkpoint path is relative to the current working
# directory — confirm it matches how the app is launched.
vit_b_16_model, vit_b_16_transform = create_vit_b_16_model(num_classes=len(class_names))
vit_b_16_model.load_state_dict(
    torch.load(f='vit_b_16_20_percent_data.pth', map_location=device)
)
# Move to the device and switch to eval mode once at startup instead of on
# every prediction request.
vit_b_16_model = vit_b_16_model.to(device)
vit_b_16_model.eval()


def predict(img):
    """Classify a single image with the ViT-B/16 model.

    Args:
        img: the input image, as supplied by ``gr.Image(type='pil')``
            (a PIL image); it is passed through ``vit_b_16_transform``.

    Returns:
        A tuple ``(pred_class, pred_prob, time_taken)``: the predicted class
        name from ``class_names``, its softmax probability rounded to 3
        decimals, and the inference time in seconds rounded to 3 decimals.
    """
    # Transform the image and add a batch dimension of 1.
    tr_img = vit_b_16_transform(img).unsqueeze(dim=0).to(device)

    start_time = time.perf_counter()
    with torch.inference_mode():
        pred_probs = torch.softmax(vit_b_16_model(tr_img), dim=1)
        top_prob, top_label = pred_probs.max(dim=1)
        # .item() copies to the host, which also waits for any pending GPU
        # work, so the timing below covers the full forward pass.
        pred_prob = top_prob.item()
        pred_label = top_label.item()
    end_time = time.perf_counter()

    pred_class = class_names[pred_label]
    pred_prob = float(round(pred_prob, 3))
    time_taken = float(round(end_time - start_time, 3))
    return pred_class, pred_prob, time_taken


# Create the example list. Paths are built from example_dir itself so each one
# points at a file that was actually listed; sorting gives a stable order.
example_dir = Path('demos/foodvision_mini/examples')
example_list = [[str(path)] for path in sorted(example_dir.iterdir()) if path.is_file()]

# create Gradio interface
description = 'A machine learning model to classify images into pizza,steak and sushi appropriately'
title = 'Image Classifier'

demo = gr.Interface(
    fn=predict,  # this function maps the inputs to the output
    inputs=gr.Image(type='pil'),  # pillow image
    outputs=[
        gr.Label(num_top_classes=1, label='Prediction'),
        gr.Number(label='prediction probability'),
        gr.Number(label='prediction time(s)'),
    ],
    examples=example_list,
    description=description,
    title=title,
)

demo.launch(
    debug=False,  # print errors locally?
    share=True,  # share to the public?
)