|
import os |
|
|
|
import gradio as gr |
|
import timm |
|
from huggingface_hub import login |
|
from torch import no_grad, softmax, topk |
|
|
|
# Runtime configuration is read from environment variables.
MODEL_NAME = os.getenv("MODEL_NAME")
HF_TOKEN = os.getenv("HF_TOKEN")

if not MODEL_NAME:
    # Fail fast with a clear message instead of letting timm try to fetch
    # the nonsensical repo "hf_hub:None".
    raise RuntimeError(
        "The MODEL_NAME environment variable must be set to a Hugging Face Hub "
        "model id (e.g. 'user/model-name')."
    )

# Only authenticate when a token is supplied: login(token=None) falls back to
# an interactive prompt, which blocks or fails in a headless deployment.
if HF_TOKEN:
    login(token=HF_TOKEN)

# Load the pretrained weights from the Hub and switch the model to eval mode
# for inference.
model = timm.create_model(f"hf_hub:{MODEL_NAME}", pretrained=True)
model.eval()

# Build the preprocessing pipeline matching the model's pretrained config.
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
|
|
|
|
|
def classify_image(input):
    """Classify an image and return up to three labels with their probabilities.

    Args:
        input: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user submits without an image. (The name is
            kept for backward compatibility even though it shadows a builtin.)

    Returns:
        A ``{label: probability}`` dict (plain Python floats) with at most
        three entries, suitable for ``gr.Label``; ``None`` if no image was
        given.
    """
    if input is None:
        # Nothing to classify; avoid crashing inside the transform.
        return None

    # Preprocess and run the model on a batch of one without tracking grads.
    batch = transform(input).unsqueeze(0)
    with no_grad():
        logits = model(batch)

    probabilities = softmax(logits[0], dim=0)
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(3, probabilities.numel())
    values, indices = topk(probabilities, k)

    label_names = model.pretrained_cfg["label_names"]
    results = {}
    # .tolist() yields plain ints/floats rather than 0-d tensors.
    for class_index, prob in zip(indices.tolist(), values.tolist()):
        # label_names may be a dict keyed by the stringified class index (as
        # the original code assumed) or a plain list indexed by position.
        if isinstance(label_names, dict):
            name = label_names[str(class_index)]
        else:
            name = label_names[class_index]
        results[name.title()] = prob
    return results
|
|
|
|
|
# Build the I/O components first so the Interface call reads fn -> in -> out.
image_input = gr.Image(type="pil", sources=["upload", "clipboard"])
label_output = gr.Label(num_top_classes=3)

# Web UI: one image in, the top-3 predictions out; example inputs are taken
# from the relative "examples" path and flagging is disabled.
demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    examples="examples",
    allow_flagging="never",
)

# Enable request queueing, then start the server in debug mode.
demo.queue()
demo.launch(debug=True)
|
|