Spaces:
Runtime error
Runtime error
import gradio as gr | |
import timm | |
import torch | |
import torchvision.transforms as transforms | |
# Build the Swin-L architecture without pretrained weights and with a 101-way
# classification head, then fill it with the fine-tuned weights in model.pth.
# The checkpoint is mapped onto the CPU first so it can be opened on machines
# that have no GPU.
# NOTE: torch.load unpickles the file, so model.pth must come from a trusted
# source.
inference_model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=101)
inference_model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
# Place the model on the GPU once at start-up (when one is available) rather
# than paying the host->device copy inside the first inference request.
if torch.cuda.is_available():
    inference_model.to('cuda')
# Switch every submodule to evaluation behaviour (e.g. dropout disabled).
inference_model.eval()
# Class names, one per line: idx_to_class[i] is the label reported for model
# output index i.
# NOTE(review): assumes labels.txt lists exactly one name per output class, in
# training index order; confirm against the training script.
# An explicit encoding keeps decoding independent of the host's locale.
with open('labels.txt', 'r', encoding='utf-8') as f:
    idx_to_class = [line.strip() for line in f]
# Evaluation-time image pipeline applied before the Swin model:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 patch (the "_224" input size in the model name),
#   3. convert the PIL image to a C x H x W float tensor,
#   4. standardise each channel with the widely used ImageNet mean/std values.
# NOTE(review): presumably this matches the normalisation used when model.pth
# was trained; confirm against the training script.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
preprocess = transforms.Compose(_preprocess_steps)
def inference(input_image, top_k=5):
    """Classify an image and return the most likely class labels.

    Args:
        input_image: PIL image delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits with the image component empty.
        top_k: how many of the highest-probability classes to return. The
            default of 5 matches the original behaviour; values larger than
            the number of classes are clamped.

    Returns:
        dict mapping class name -> softmax probability for the ``top_k`` most
        likely classes, or an empty dict when no image was supplied.
    """
    if input_image is None:
        # Without this guard, preprocess() would raise deep inside torchvision.
        return {}
    # The Normalize step uses 3-channel statistics. RGBA PNGs (e.g. screenshots)
    # or grayscale uploads would produce 4 or 1 channels and fail, so force RGB.
    input_tensor = preprocess(input_image.convert('RGB'))
    input_batch = input_tensor.unsqueeze(0)  # add a leading batch dimension of 1
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        inference_model.to('cuda')  # no-op once the model already lives on the GPU
    with torch.inference_mode():
        output = inference_model(input_batch)
    # Only one image is in the batch, so take row 0 and turn logits into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.numel())  # torch.topk raises if k exceeds the size
    top_prob, top_catid = torch.topk(probabilities, k)
    # Label:probability
    return {
        idx_to_class[int(idx)]: prob.item()
        for prob, idx in zip(top_prob.cpu(), top_catid.cpu())
    }
# Text and sample inputs shown on the Gradio page.
title = "See Food 101"
description = (
    "Gradio demo for See Food 101, the expansion edition of See Food from Silicon Valley. "
    "Simply upload your image, or click on the example(s) to load them. "
    "Read more about the architecture used at the links below."
)
# Footer links: the Swin Transformer paper and the Food-101 dataset page.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2103.14030'>Swin Transformer: Hierarchical Vision Transformer using Shifted Windows</a>"
    " | "
    "<a href='https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/'>Data</a>"
    "</p>"
)
# Clickable example image(s), one value per input component per row.
# NOTE(review): assumes the screenshot file ships next to this script.
examples = [
    ['Screenshot 2023-05-05 085533.png'],
]
# Wire the classifier into a single-image -> top-5-labels web UI and serve it.
_image_input = gr.Image(type="pil")          # hands `inference` a PIL image
_label_output = gr.Label(num_top_classes=5)  # displays the 5 best-scoring classes
iface = gr.Interface(
    fn=inference,
    inputs=_image_input,
    outputs=_label_output,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
iface.launch()