raihanp's picture
Update app.py
7f16b2f verified
raw
history blame
1.16 kB
from PIL import Image
import torch
import gradio as gr
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
from pathlib import Path
# Lookup table from the model's output class index to the label name shown
# to the user; `classifier` reads it to name each predicted class.
# NOTE(review): 'organice-waste' looks like a misspelling sitting next to
# 'organic-waste'. It is kept verbatim because it may mirror the class names
# the checkpoint was trained with -- confirm before renaming or merging.
labels = dict(enumerate((
    'glass',
    'metal',
    'organic-waste',
    'organice-waste',
    'paper',
    'plastic',
    'textiles',
)))
# The checkpoint is used directly as a callable model (see `.eval()` below),
# so it is a fully pickled object rather than a state dict. Recent PyTorch
# releases default `weights_only` to True, which refuses to unpickle such
# objects, so opt out explicitly.
# NOTE: full unpickling can execute arbitrary code -- only load checkpoints
# that come from a trusted source.
inference = torch.load(
    'fine_tune_resnet.pth',
    map_location=torch.device('cpu'),  # run on CPU regardless of where it was saved
    weights_only=False,
)
# Put the model in evaluation mode before serving predictions.
inference.eval()

# Every entry of the local `examples` directory, as string paths for the UI.
# Sorted because glob order is arbitrary and would otherwise vary between hosts.
example = [str(path) for path in sorted(Path('examples').glob('*'))]
def classifier(image):
    """Classify a garbage photo and score every known label.

    Args:
        image: PIL image to classify, or ``None`` when no image was
            submitted.

    Returns:
        A dict mapping label name to softmax probability, ordered from the
        most to the least likely class, or ``None`` when ``image`` is
        ``None``.
    """
    if image is None:
        # Nothing was uploaded; return an empty result instead of crashing
        # inside the transform pipeline.
        return None

    test_transform = Compose([
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Normalize is configured with three channel statistics, so grayscale or
    # RGBA inputs must be converted to RGB first.
    batch = test_transform(image.convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        output = inference(batch)
        out = torch.softmax(output, 1)
        # Rank as many classes as both the label table and the model output
        # provide, rather than a hard-coded count.
        k = min(len(labels), out.shape[1])
        values, indices = torch.topk(out[0], k=k)

    return {
        labels[index]: score
        for index, score in zip(indices.tolist(), values.tolist())
    }
# Web UI: a single image input (delivered to `classifier` as a PIL image)
# feeding a label output, with the files found in `examples` offered as
# ready-made inputs.
image_input = gr.Image(type="pil")
iface = gr.Interface(
    fn=classifier,
    inputs=image_input,
    outputs='label',
    examples=example,
    title='Garbage Image Classification',
)

# Start the app; share=True additionally requests a public share link.
iface.launch(share=True)