# Hugging Face Spaces status: Runtime error
# import the essentials
import os
import time
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torchvision

from demos.foodvision_mini.model import create_vit_b_16_model
class_names = ['pizza', 'steak', 'sushi']

# torch.cuda.is_available is a function and must be *called*: a bare reference
# is always truthy, which would pick 'cuda' even on CPU-only hosts and crash as
# soon as anything is moved to the device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the ViT-B/16 model (head sized to our class list) together with its
# matching preprocessing transform, then load our fine-tuned weights into it.
vit_b_16_model, vit_b_16_transform = create_vit_b_16_model(num_classes=len(class_names))
# map_location lets a checkpoint saved from a GPU session load on a CPU-only host.
vit_b_16_model.load_state_dict(
    torch.load(f='vit_b_16_20_percent_data.pth', map_location=device)
)
# create the predict function
def predict(img):
    """Classify one food image with the ViT-B/16 model and time the forward pass.

    Args:
        img: The image from the ``gr.Image(type='pil')`` input (a PIL image),
            or ``None`` if the user submitted without uploading one.

    Returns:
        A tuple ``(pred_class, pred_prob, time_taken)``:
        the predicted class name from ``class_names``; the softmax probability
        of that class, rounded to 3 decimals; and the time in seconds, rounded
        to 3 decimals, spent in the model forward pass plus reading the result
        back (preprocessing is not included).

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message.
    """
    if img is None:
        raise gr.Error('Please upload an image to classify.')

    # Preprocess with the transform that matches the ViT weights and add a
    # batch dimension of 1.
    tr_img = vit_b_16_transform(img).unsqueeze(dim=0).to(device)

    # Module.to() is a no-op once the model already lives on `device`;
    # eval() switches off training-only layers such as dropout.
    model = vit_b_16_model.to(device)
    model.eval()

    start_time = time.perf_counter()
    with torch.inference_mode():
        pred_logit = model(tr_img)
        # Softmax once, then take the top probability and its index for the
        # single image in the batch.
        top_prob, top_index = torch.softmax(pred_logit, dim=1)[0].max(dim=0)
        # .item() copies the values to the host, so any pending GPU work has
        # finished before the timer stops.
        pred_prob = top_prob.item()
        pred_label = top_index.item()
    end_time = time.perf_counter()

    pred_class = class_names[pred_label]
    time_taken = round(end_time - start_time, 3)
    return pred_class, round(pred_prob, 3), time_taken
# create example list
example_dir = Path('demos/foodvision_mini/examples')
# Build every example path from `example_dir` itself, so the directory that is
# listed and the file paths handed to Gradio always agree. Gradio expects one
# inner list per example (one value per input component). Sorting gives a
# stable example order; if the folder is absent the app still starts, just
# without examples.
example_list = (
    [[str(path)] for path in sorted(example_dir.iterdir()) if path.is_file()]
    if example_dir.is_dir()
    else []
)
# create Gradio interface
title = 'Image Classifier'
description = 'A machine learning model to classify images into pizza,steak and sushi appropriately'

# One PIL image in; predicted class, its probability and the prediction time out.
image_input = gr.Image(type='pil')
prediction_outputs = [
    gr.Label(num_top_classes=1, label='Prediction'),
    gr.Number(label='prediction probability'),
    gr.Number(label='prediction time(s)'),
]

demo = gr.Interface(
    fn=predict,  # maps the input image to the three outputs above
    inputs=image_input,
    outputs=prediction_outputs,
    examples=example_list,
    description=description,
    title=title,
)

# debug stays off; share=True asks Gradio for a public share link.
demo.launch(debug=False, share=True)