import gradio as gr
import timm
import torch
import torchvision.transforms as transforms

# Build the Swin-L backbone with a 101-way head (Food-101 classes) and load the
# fine-tuned weights. `model.pth` and `labels.txt` are expected alongside this script;
# `labels.txt` holds one class name per line, in the index order used during training.
inference_model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=101)
inference_model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
inference_model.eval()

with open('labels.txt', 'r') as f:
    idx_to_class = [s.strip() for s in f.readlines()]

# Standard ImageNet preprocessing for a 224x224 input.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def inference(input_image):
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        inference_model.to('cuda')
    with torch.inference_mode():
        output = inference_model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    # Map the top-5 predictions to {label: probability} for gr.Label.
    result = {idx_to_class[int(idx)]: val.item()
              for val, idx in zip(top5_prob.cpu(), top5_catid.cpu())}
    return result


title = "See Food 101"
description = ("Gradio demo for See Food 101, the expansion edition of See Food from Silicon Valley. "
               "Simply upload your image, or click on the example(s) to load them. "
               "See the link below for the architecture used.")
article = ("[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows]"
           "(https://arxiv.org/abs/2103.14030) | Data")

examples = [['Screenshot 2023-05-05 085533.png']]

iface = gr.Interface(fn=inference,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Label(num_top_classes=5),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     analytics_enabled=False)
iface.launch()