Spaces:
Runtime error
Runtime error
File size: 985 Bytes
6982e15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import gradio as gr
from Models import VisionModel
import huggingface_hub
from PIL import Image
import torch.amp.autocast_mode
from pathlib import Path
# Hugging Face Hub repository fetched at startup with snapshot_download; the
# downloaded snapshot directory is what VisionModel.load_model reads and where
# top_tags.txt is looked up.
MODEL_REPO = "fancyfeast/joytag"
@torch.no_grad()
def predict(image: Image.Image):
    """Score every known tag for a single image.

    Runs the module-level ``model`` on *image*, applies a sigmoid to the
    ``'tags'`` output, and pairs each score with the tag at the same index
    in the module-level ``top_tags`` list.

    Args:
        image: The image delivered by the Gradio ``gr.Image`` input, which is
            configured with ``type='pil'``.

    Returns:
        A ``{tag: score}`` dict whose values are plain Python floats, the
        confidence format ``gr.Label`` expects.
    """
    # Request CUDA autocast only when a GPU exists. On CPU-only hosts torch
    # would warn "CUDA is not available. Disabling" and turn it off anyway,
    # so the computed scores are the same either way.
    with torch.amp.autocast_mode.autocast('cuda', enabled=torch.cuda.is_available()):
        # NOTE(review): the raw PIL image is passed straight to the model.
        # Confirm VisionModel accepts a PIL image; it may instead need a
        # preprocessed tensor / batch.
        preds = model(image)
        tag_preds = preds['tags'].sigmoid().cpu()

    # NOTE(review): presumably the model returns a leading batch dimension of
    # size 1 for a single image. Drop it so that index i addresses tag i
    # (otherwise indexing past 0 would raise IndexError).
    if tag_preds.dim() == 2 and tag_preds.shape[0] == 1:
        tag_preds = tag_preds[0]

    # float() turns each 0-d tensor into a plain number for gr.Label.
    return {tag: float(tag_preds[i]) for i, tag in enumerate(top_tags)}
print("Downloading model...")
path = huggingface_hub.snapshot_download(MODEL_REPO)

print("Loading model...")
model = VisionModel.load_model(path)
# Switch to inference mode (presumably VisionModel is a torch nn.Module).
model.eval()

# One tag per line. predict() pairs tag i with model score i, so line order
# matters. Blank lines are skipped. An explicit UTF-8 encoding keeps decoding
# independent of the host's locale.
with open(Path(path) / 'top_tags.txt', 'r', encoding='utf-8') as f:
    top_tags = [tag for tag in (line.strip() for line in f) if tag]
print("Starting server...")

# UI components: an image picker (file upload or webcam) delivering a PIL
# image, and a label widget showing the five highest-scoring tags.
_source_input = gr.Image(type='pil', sources=['upload', 'webcam'], label="Source")
_result_output = gr.Label(num_top_classes=5, label="Result")

gradio_app = gr.Interface(
    fn=predict,
    inputs=_source_input,
    outputs=[_result_output],
    title="JoyTag",
)

if __name__ == '__main__':
    # Start the web server only when this file is executed as a script.
    gradio_app.launch()
|