"""Gradio demo that classifies garbage images with a fine-tuned ResNet checkpoint."""

from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

# Model output index -> class name shown in the UI.
# NOTE(review): indices 2 and 3 look like the same class ("organice" is
# presumably a typo inherited from the training folder names). They are kept
# distinct because the checkpoint has a separate output for each; confirm
# against the training data before merging them.
labels = {0: 'glass', 1: 'metal', 2: 'organic-waste', 3: 'organice-waste',
          4: 'paper', 5: 'plastic', 6: 'textiles'}


def _load_model(path):
    """Load a fully pickled ``nn.Module`` checkpoint from *path* onto the CPU."""
    cpu = torch.device('cpu')
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # a whole model object. This checkpoint ships with the app and is
        # trusted -- never point this loader at user-supplied files.
        return torch.load(path, map_location=cpu, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` keyword.
        return torch.load(path, map_location=cpu)


inference = _load_model('fine_tune_resnet.pth')
inference.eval()

# Example images offered in the UI; sorted so their order is stable between runs.
example = sorted(str(i) for i in Path('examples').glob('*'))

# Resize/center-crop to 224x224 and normalize with the ImageNet channel
# statistics. Built once at import instead of on every request.
test_transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classifier(image):
    """Return a ``{label: probability}`` dict for a PIL *image*, or None if absent.

    Probabilities are the softmax over all model outputs, inserted in
    descending order of probability.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image;
        # returning None clears the Label output instead of raising.
        return None
    # ToTensor keeps the source channel count (RGBA -> 4, grayscale -> 1),
    # which would not match the 3-channel Normalize statistics or the model.
    image = image.convert('RGB')
    with torch.no_grad():
        output = inference(test_transform(image).unsqueeze(0))
    probs = torch.softmax(output, 1)
    values, indices = torch.topk(probs[0], k=len(labels))
    return {labels[i.item()]: v.item() for i, v in zip(indices, values)}


iface = gr.Interface(fn=classifier,
                     inputs=gr.Image(type="pil"),
                     outputs='label',
                     examples=example,
                     title='Garbage Image Classification')

if __name__ == "__main__":
    iface.launch(share=True)