# import the essentials
import os
import time

import torch
import torchvision
import numpy as np
import gradio as gr
from pathlib import Path

from model import create_effnet_b2_model
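
# NOTE: `create_effnet_b2_model` is defined in model.py (not shown here). Below is a
# minimal sketch of what it is assumed to return -- an EffNetB2 feature extractor with
# a frozen backbone and a new classification head, plus the matching transforms -- kept
# commented out so it does not shadow the real import. The `seed` argument is hypothetical.
#
#   def create_effnet_b2_model(num_classes: int, seed: int = 42):
#       weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
#       transform = weights.transforms()
#       model = torchvision.models.efficientnet_b2(weights=weights)
#       for param in model.features.parameters():
#           param.requires_grad = False  # freeze the feature extractor layers
#       torch.manual_seed(seed)
#       model.classifier = torch.nn.Sequential(
#           torch.nn.Dropout(p=0.3, inplace=True),
#           torch.nn.Linear(in_features=1408, out_features=num_classes),
#       )
#       return model, transform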
# load the class names (class_names.txt holds one class name per line)
with open('class_names.txt') as f:
    class_names = [class_name.strip() for class_name in f.readlines()]
# set up device-agnostic code
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# creating the EffNetB2 model and loading it with the state_dict of our trained model
effnetb2_model, effnetb2_transform = create_effnet_b2_model(num_classes=len(class_names))
effnetb2_model.load_state_dict(
    torch.load(f='effnetb2_20percent_101classes.pth',
               map_location=torch.device(device))  # map to CPU when no GPU is available
)
# create the predict function
def predict(img):
    """Classify an input image with the trained EffNetB2 model.

    Args:
        img: a PIL image to classify.

    Returns:
        The predicted class, the prediction probability and the time taken
        to make the prediction (in seconds).
    """
    # transform the image and add a batch dimension
    tr_img = effnetb2_transform(img).unsqueeze(dim=0).to(device)

    # make predictions with the EffNetB2 model
    model = effnetb2_model.to(device)

    # start the timer
    start_time = time.perf_counter()

    model.eval()
    with torch.inference_mode():
        pred_logit = model(tr_img)
        pred_label = torch.argmax(pred_logit, dim=1).cpu().item()
        pred_prob = torch.max(torch.softmax(pred_logit, dim=1)).cpu().item()

    # stop the timer
    end_time = time.perf_counter()

    pred_prob = float(np.round(pred_prob, 3))
    pred_class = class_names[pred_label]
    time_taken = float(np.round(end_time - start_time, 3))

    return pred_class, pred_prob, time_taken
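
# Quick local sanity check (hypothetical usage outside of Gradio; the image path
# below is an assumption -- point it at any image in the examples/ directory):
#
#   from PIL import Image
#   sample_img = Image.open('examples/pizza.jpg')
#   pred_class, pred_prob, time_taken = predict(sample_img)
#   print(pred_class, pred_prob, time_taken)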
# create example list (the example images are assumed to sit in an examples/ directory next to app.py)
example_dir = Path('examples')
example_list = [[str(example_dir / filename)] for filename in os.listdir(example_dir)]
# create Gradio interface
description = 'An EffNetB2 feature extractor that classifies food images into one of 101 food classes.'
title = 'Food Image Classifier'
demo = gr.Interface(
    fn=predict,  # this function maps the inputs to the outputs
    inputs=gr.Image(type='pil'),  # Pillow image
    outputs=[
        gr.Label(num_top_classes=1, label='Prediction'),
        gr.Number(label='Prediction probability'),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=example_list,
    description=description,
    title=title,
)
demo.launch()