"""Gradio demo that classifies an image with a timm model hosted on the Hugging Face Hub.

Configuration is read from environment variables:

* ``MODEL_NAME`` -- Hub repo id of the timm model, loaded as ``hf_hub:<MODEL_NAME>`` (required).
* ``HF_TOKEN`` -- optional access token; when set, it is used to log in to the Hub.
"""

import os

import gradio as gr
import timm
from huggingface_hub import login
from torch import no_grad, softmax, topk

MODEL_NAME = os.getenv("MODEL_NAME")
HF_TOKEN = os.getenv("HF_TOKEN")
TOP_K = 3  # number of predictions returned by classify_image and shown by the Label

if not MODEL_NAME:
    # Fail fast with a clear message instead of trying to load "hf_hub:None".
    raise RuntimeError("The MODEL_NAME environment variable must be set to a Hub repo id.")

# login(token=None) falls back to an interactive prompt, which would block a
# headless deployment; only authenticate when a token was actually provided.
if HF_TOKEN:
    login(token=HF_TOKEN)

model = timm.create_model(f"hf_hub:{MODEL_NAME}", pretrained=True)
model.eval()

# Derive preprocessing (input size, mean/std, interpolation, crop) from the
# model's pretrained config via the documented `model=` argument; the first
# argument is reserved for user overrides.
data_cfg = timm.data.resolve_data_config({}, model=model)
transform = timm.data.create_transform(**data_cfg)

# May be absent, a list (position == class index), or a mapping loaded from JSON.
LABEL_NAMES = model.pretrained_cfg.get("label_names")


def _label_for(index: int) -> str:
    """Return the title-cased display name for classifier output ``index``.

    Mappings loaded from JSON have string keys, so the string form is tried
    first, then the integer form. Falls back to the numeric index when no
    name is available.
    """
    names = LABEL_NAMES
    if isinstance(names, dict):
        name = names.get(str(index), names.get(index))
    elif isinstance(names, (list, tuple)) and 0 <= index < len(names):
        name = names[index]
    else:
        name = None
    return str(name if name is not None else index).title()


def classify_image(input):  # noqa: A002 -- parameter name kept for compatibility
    """Classify a PIL image and return the top predictions.

    Args:
        input: The PIL image supplied by the Gradio ``Image`` component, or
            ``None`` when the user submitted without an image.

    Returns:
        A dict mapping label name to softmax probability (plain ``float``)
        for the ``TOP_K`` most likely classes, or fewer if the model has
        fewer classes.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if input is None:
        raise gr.Error("Please upload or paste an image first.")
    batch = transform(input).unsqueeze(0)  # add the batch dimension
    with no_grad():
        logits = model(batch)
    probabilities = softmax(logits[0], dim=0)
    # topk raises if k exceeds the number of classes.
    k = min(TOP_K, probabilities.numel())
    values, indices = topk(probabilities, k)
    return {
        _label_for(index.item()): prob.item()
        for index, prob in zip(indices, values)
    }


demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", sources=["upload", "clipboard"]),
    outputs=gr.Label(num_top_classes=TOP_K),
    allow_flagging="never",
    examples="examples",
)
demo.queue()
demo.launch(debug=True)