# creating app.py
"""Gradio app that serves the FoodVision food-image classifier."""

########## imports ############
import functools
import os
from pathlib import Path

import PIL
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

import gradio as gr
from model import create_model
###############################

CLASSES_PATH = 'classes.txt'
MODEL_WEIGHTS_PATH = 'models/swin40%newer.pth'
EXAMPLES_DIR = 'examples'


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Build the model, load its trained weights and read the class names.

    Cached so that model construction, weight loading and the class-file read
    happen once (on the first prediction) instead of on every request.

    Returns:
        A ``(model, transform, class_names)`` tuple: the model returned by
        ``create_model()`` with the weights from MODEL_WEIGHTS_PATH loaded onto
        the CPU and switched to eval mode, the transform returned alongside it
        by ``create_model()``, and the whitespace-stripped lines of
        CLASSES_PATH (one class name per line, index = model output index).
    """
    swinb, transform = create_model()
    with open(CLASSES_PATH, 'r', encoding='utf-8') as f:
        class_names = [foodname.strip() for foodname in f.readlines()]
    # NOTE(review): torch.load unpickles the file, so only load trusted weights.
    # On torch versions that support it, weights_only=True would be safer.
    swinb.load_state_dict(
        torch.load(f=MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'))
    )
    swinb.eval()
    return swinb, transform, class_names


def predict(img):
    """Classify a food image and return the predicted class name.

    Args:
        img: The value of the ``gr.Image(type='pil')`` input, i.e. a PIL image,
            or ``None`` when the user submits without providing an image.

    Returns:
        The class name (a line of classes.txt) at the index of the model's
        highest-scoring output, or a short message when no image was given.
    """
    if img is None:
        return 'Please upload an image.'
    swinb, transform, class_names = _load_classifier()
    # Defensive: force 3 channels so RGBA/grayscale uploads don't break the
    # transform (presumably a no-op when Gradio already delivers RGB).
    img_tensor = transform(img.convert('RGB')).unsqueeze(0)  # add batch dim
    with torch.inference_mode():
        logits = swinb(img_tensor)
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    pred_label = class_names[logits.argmax(dim=1).item()]
    print(pred_label)
    return pred_label


###############################
title = 'FoodVision Project'
# NOTE(review): the weights file ('swin40%newer.pth') and the `swinb` variable
# suggest a Swin-B model rather than EfficientNet_B2 as this text says —
# confirm and update the user-facing description.
description = 'FoodVision is an image classification model based on EfficientNet_B2 which has been trained on a 101 different classes using the Food101 dataset'
###############################

# Each Gradio example is a one-element list holding an image path. Sorted for a
# stable display order (os.listdir order is arbitrary); hidden files such as
# .DS_Store are skipped because they are not images.
example_list = [
    [os.path.join(EXAMPLES_DIR, example)]
    for example in sorted(os.listdir(EXAMPLES_DIR))
    if not example.startswith('.')
]
##############
for examples in example_list:
    print(examples)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[gr.Textbox()],
    examples=example_list,
    title=title,
    description=description,
)

demo.launch(debug=True)