Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import torch | |
import numpy | |
from timeit import default_timer as timer | |
from model import create_effnetb2_model | |
from typing import Tuple , Dict | |
# 1. Class names setup
# The 101 food category labels, in alphabetical order. predict() pairs the
# model's i-th output probability with the i-th name here, so the order must
# match the order the model was trained with.
class_names = """
    apple_pie
    baby_back_ribs baklava beef_carpaccio beef_tartare beet_salad beignets
    bibimbap bread_pudding breakfast_burrito bruschetta
    caesar_salad cannoli caprese_salad carrot_cake ceviche cheese_plate
    cheesecake chicken_curry chicken_quesadilla chicken_wings chocolate_cake
    chocolate_mousse churros clam_chowder club_sandwich crab_cakes
    creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings
    edamame eggs_benedict escargots
    falafel filet_mignon fish_and_chips foie_gras french_fries
    french_onion_soup french_toast fried_calamari fried_rice frozen_yogurt
    garlic_bread gnocchi greek_salad grilled_cheese_sandwich grilled_salmon
    guacamole gyoza
    hamburger hot_and_sour_soup hot_dog huevos_rancheros hummus
    ice_cream
    lasagna lobster_bisque lobster_roll_sandwich
    macaroni_and_cheese macarons miso_soup mussels
    nachos
    omelette onion_rings oysters
    pad_thai paella pancakes panna_cotta peking_duck pho pizza pork_chop
    poutine prime_rib pulled_pork_sandwich
    ramen ravioli red_velvet_cake risotto
    samosa sashimi scallops seaweed_salad shrimp_and_grits
    spaghetti_bolognese spaghetti_carbonara spring_rolls steak
    strawberry_shortcake sushi
    tacos takoyaki tiramisu tuna_tartare
    waffles
""".split()
# Startup diagnostic: log the installed NumPy version.
print(numpy.__version__)

from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url


def get_state_dict(self, *args, **kwargs):
    """Replacement for ``WeightsEnum.get_state_dict`` that skips hash checks.

    Removes any ``check_hash`` keyword argument the caller supplied and
    forwards everything else to ``torch.hub.load_state_dict_from_url`` for
    ``self.url``. With ``check_hash`` absent, that function falls back to its
    own default (no checksum verification).

    Args:
        self: The ``WeightsEnum`` member whose ``url`` is downloaded.
        *args: Positional arguments forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged, minus ``check_hash``.

    Returns:
        Whatever ``load_state_dict_from_url`` returns (the loaded state dict).
    """
    # Default of None: callers that never pass check_hash must not hit KeyError.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)


# HACK: monkey-patch torchvision so pretrained-weight downloads bypass hash
# verification (presumably to work around a checksum mismatch on the hosted
# weights — confirm). SECURITY NOTE: this disables integrity checking of
# downloaded weight files for every torchvision model in this process.
WeightsEnum.get_state_dict = get_state_dict
# 2. Model and transforms preparation
# Build the EfficientNet-B2 model and its matching preprocessing transforms.
# The classifier size is derived from class_names so the model head and the
# label list cannot silently disagree: a mismatch with the checkpoint fails
# loudly in load_state_dict below instead of mislabelling predictions.
effnetb2_model, effnet_b2_transforms = create_effnetb2_model(
    num_classes=len(class_names),
    seed=42,
)

# Load the saved fine-tuned weights, mapping all tensors onto the CPU so the
# app can run on a CPU-only host.
effnetb2_model.load_state_dict(
    torch.load(
        f='11_pretrained_effnet_feature_extractor_food101_fine_tune.pth',
        map_location=torch.device('cpu'),
    )
)
# 3. Prediction function (predict())
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and measure how long the prediction took.

    Args:
        img: The input image as delivered by the Gradio
            ``gr.Image(type='pil')`` component; it is passed directly to
            ``effnet_b2_transforms``.

    Returns:
        A tuple ``(pred_label_and_probs, pred_time)``:
        ``pred_label_and_probs`` maps each name in ``class_names`` to its
        softmax probability, and ``pred_time`` is the elapsed wall-clock time
        in seconds, rounded to 4 decimal places.
    """
    start_time = timer()

    # Preprocess, then add a leading batch dimension of size 1.
    image = effnet_b2_transforms(img).unsqueeze(0)

    # Evaluation mode (affects layers such as dropout / batch-norm), and no
    # autograd tracking during the forward pass.
    effnetb2_model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb2_model(image), dim=1)

    # Convert the single row of probabilities to Python floats in one call,
    # then pair each probability with its class name by position.
    probs = pred_probs[0].tolist()
    pred_label_and_probs = dict(zip(class_names, probs))

    end_time = timer()
    pred_time = round(end_time - start_time, 4)
    return pred_label_and_probs, pred_time
### 4. Gradio app - our Gradio interface + launch command
title = 'FoodVision Big'
description = ('Fine-tuned the last 4 Sequential layers of an EfficientNetB2 '
               'model to classify 101 food images.')
article = 'created at PyTorch Model Deployment'


def _build_example_list(directory='examples'):
    """Return Gradio example rows, one ``[path]`` per file in *directory*.

    Files are listed in sorted order so the examples appear in a stable order
    (``os.listdir`` order is arbitrary). Hidden files such as ``.DS_Store``
    and sub-directories are skipped because they are not usable image
    examples. Returns ``None`` (no examples) when *directory* does not exist,
    instead of crashing app startup.
    """
    if not os.path.isdir(directory):
        return None
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    ]


# Create example list
example_list = _build_example_list()

# Create the Gradio demo: one PIL image in; the top-3 labels and the
# prediction time out.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type='pil'),
                    outputs=[gr.Label(num_top_classes=3, label='prediction'),
                             gr.Number(label='Prediction time (s)')],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo
demo.launch(debug=False)