from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import gradio
import torch
from einops import rearrange
import numpy

# Pre-processor and fine-tuned image classifier, both loaded with from_pretrained
# at import time (downloaded from the Hugging Face Hub or read from the local cache).
# NOTE(review): newer transformers releases prefer AutoImageProcessor for vision
# checkpoints; confirm the pinned version before switching.
extractor = AutoFeatureExtractor.from_pretrained("vincentclaes/mit-indoor-scenes")
model = AutoModelForImageClassification.from_pretrained("vincentclaes/mit-indoor-scenes")
# Inference only: put the model in eval mode once here instead of on every request.
model.eval()

# Class index (as a string) -> human-readable scene name. Index i corresponds to
# position i of the model's logit vector.
labels = {
    "0": "airport_inside",
    "1": "artstudio",
    "2": "auditorium",
    "3": "bakery",
    "4": "bar",
    "5": "bathroom",
    "6": "bedroom",
    "7": "bookstore",
    "8": "bowling",
    "9": "buffet",
    "10": "casino",
    "11": "children_room",
    "12": "church_inside",
    "13": "classroom",
    "14": "cloister",
    "15": "closet",
    "16": "clothingstore",
    "17": "computerroom",
    "18": "concert_hall",
    "19": "corridor",
    "20": "deli",
    "21": "dentaloffice",
    "22": "dining_room",
    "23": "elevator",
    "24": "fastfood_restaurant",
    "25": "florist",
    "26": "gameroom",
    "27": "garage",
    "28": "greenhouse",
    "29": "grocerystore",
    "30": "gym",
    "31": "hairsalon",
    "32": "hospitalroom",
    "33": "inside_bus",
    "34": "inside_subway",
    "35": "jewelleryshop",
    "36": "kindergarden",
    "37": "kitchen",
    "38": "laboratorywet",
    "39": "laundromat",
    "40": "library",
    "41": "livingroom",
    "42": "lobby",
    "43": "locker_room",
    "44": "mall",
    "45": "meeting_room",
    "46": "movietheater",
    "47": "museum",
    "48": "nursery",
    "49": "office",
    "50": "operating_room",
    "51": "pantry",
    "52": "poolinside",
    "53": "prisoncell",
    "54": "restaurant",
    "55": "restaurant_kitchen",
    "56": "shoeshop",
    "57": "stairscase",
    "58": "studiomusic",
    "59": "subway",
    "60": "toystore",
    "61": "trainstation",
    "62": "tv_studio",
    "63": "videostore",
    "64": "waitingroom",
    "65": "warehouse",
    "66": "winecellar",
}


def classify(image):
    """Score a single image against every scene label.

    Args:
        image: One image accepted by ``extractor`` (e.g. a PIL image).

    Returns:
        A dict mapping each scene name in ``labels`` to its softmax probability
        (Python floats summing to ~1.0), in ``labels`` order.
    """
    with torch.no_grad():
        inputs = extractor(images=image, return_tensors='pt')
        logits = model(**inputs).logits
        # Drop the batch axis; the pattern also asserts exactly one image was scored.
        logits = rearrange(logits, '1 j->j')
        # torch.softmax is numerically stable (it subtracts the max internally);
        # a hand-rolled exp(x) / sum(exp(x)) overflows to inf/nan for large logits.
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return {name: float(probs[int(idx)]) for idx, name in labels.items()}
# Gradio UI: one uploaded RGB image in, a Label showing the top-5 scene
# predictions from classify() out. Flagging is disabled.
# NOTE(review): gradio.inputs / gradio.outputs are the legacy component
# namespaces (deprecated in Gradio 3.x, removed in 4.x); confirm the pinned
# gradio version before upgrading.
image_input = gradio.inputs.Image(
    shape=(224, 224),
    image_mode='RGB',
    source='upload',
    tool='editor',
    type='pil',
    label=None,
    optional=False,
)
label_output = gradio.outputs.Label(num_top_classes=5, type='auto')
demo = gradio.Interface(
    fn=classify,
    inputs=image_input,
    outputs=label_output,
    allow_flagging='never',
)
demo.launch(inbrowser=True, debug=True)