SeeFood101v1 / app.py
HangenYuu's picture
Update app.py
8a62832
import gradio as gr
import timm
import torch
import torchvision.transforms as transforms
# Swin-L backbone (224x224 input, per the model name) with a 101-way head,
# one logit per class listed in labels.txt.
# pretrained=False: no ImageNet weights are fetched because the fine-tuned
# checkpoint loaded below (strict by default) supplies every parameter.
inference_model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=101)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
# NOTE: torch.load unpickles the file -- only ever point it at a trusted checkpoint.
inference_model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
# Move the weights to the GPU once at start-up (when one is available) rather
# than lazily on the first request, so that request does not pay the full
# host-to-device copy of a large model.
inference_model.to('cuda' if torch.cuda.is_available() else 'cpu')
# Inference only: switch off training-time behaviour (dropout, etc.).
inference_model.eval()
# Class names, one per line: line i (0-based) is the label for output logit i,
# which is how `inference` indexes this list with top-k class indices.
# An explicit encoding keeps decoding independent of the host's locale default.
with open('labels.txt', 'r', encoding='utf-8') as f:
    idx_to_class = [s.strip() for s in f]
# Evaluation-time image pipeline applied to every uploaded picture.
preprocess = transforms.Compose(
    [
        # Scale so the shorter side is 256 px (aspect ratio preserved).
        transforms.Resize(size=256),
        # Keep the central 224x224 window, the input size the model expects.
        transforms.CenterCrop(size=224),
        # PIL image -> float CHW tensor with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation with the usual ImageNet statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def inference(input_image):
    """Classify a food photo and return its five most likely labels.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        Dict mapping label -> softmax probability for the top-5 classes,
        in descending order of probability (the format ``gr.Label``
        renders), or ``None`` when no image was supplied.
    """
    if input_image is None:
        # Nothing uploaded: leave the output empty instead of letting
        # torchvision fail with an opaque TypeError on None.
        return None
    # Normalize() below has 3 channel statistics; palette, RGBA (e.g. PNG
    # screenshots) or grayscale images would yield the wrong channel count.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Module.to() moves parameters in place and returns immediately once they
    # already live on `device`, so after the first call this is cheap.
    inference_model.to(device)
    # unsqueeze(0) adds the batch dimension expected by the model.
    input_batch = preprocess(input_image).unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = inference_model(input_batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    # Label -> probability; .tolist() copies to host and yields plain
    # Python floats/ints.
    return {
        idx_to_class[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }
# Text rendered around the demo: page title, a blurb above the widget, and an
# HTML footer linking to the Swin Transformer paper and the Food-101 data.
title = "See Food 101"
description = (
    "Gradio demo for See Food 101, the expansion edition of See Food from Silicon Valley. "
    "Simply upload your image, or click on the example(s) to load them. "
    "Read more about the architecture and dataset used at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2103.14030'>"
    "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows</a>"
    " | <a href='https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/'>Data</a>"
    "</p>"
)
# One inner list per example, holding one value per input component
# (a single image file path here, relative to the working directory).
examples = [
    ['Screenshot 2023-05-05 085533.png'],
]
# Wire the classifier into a single-image -> label UI and start the server.
# The image component hands `inference` a PIL image; the label component
# renders the returned label -> probability dict, showing at most 5 classes.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5)
iface = gr.Interface(
    fn=inference,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,  # opt out of Gradio usage analytics
)
iface.launch()