File size: 985 Bytes
6982e15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from Models import VisionModel
import huggingface_hub
from PIL import Image
import torch.amp.autocast_mode
from pathlib import Path


MODEL_REPO = "fancyfeast/joytag"


@torch.no_grad()
def predict(image: Image.Image):
	"""Score every known tag for a single image.

	Args:
		image: The PIL image delivered by the Gradio ``gr.Image(type='pil')`` input.

	Returns:
		A dict mapping each tag name in ``top_tags`` to its sigmoid confidence as a
		plain Python float, which is the numeric mapping ``gr.Label`` renders.

	Raises:
		ValueError: If the model produces fewer tag scores than there are
			entries in ``top_tags``.
	"""
	# Requesting CUDA autocast on a CPU-only host only produces a torch warning
	# and is then disabled, so enable it exactly when a CUDA device exists.
	with torch.amp.autocast_mode.autocast('cuda', enabled=torch.cuda.is_available()):
		# NOTE(review): the raw PIL image is handed straight to the model; confirm
		# VisionModel performs its own resize/normalisation/device placement.
		preds = model(image)
		tag_preds = preds['tags'].sigmoid().cpu()

	# Flatten so scores line up one-to-one with top_tags even if the model
	# returns a leading batch dimension (presumably (1, num_tags) — TODO confirm).
	scores = tag_preds.reshape(-1)
	if scores.numel() < len(top_tags):
		raise ValueError(
			f"Model returned {scores.numel()} tag scores but top_tags.txt lists {len(top_tags)} tags"
		)

	# gr.Label expects numeric confidences, so convert each 0-d tensor to float.
	# zip stops at len(top_tags), ignoring any extra scores beyond the tag list.
	return {tag: float(score) for tag, score in zip(top_tags, scores)}


# Fetch the model repository snapshot from the Hugging Face Hub;
# snapshot_download returns the local directory holding the files.
print("Downloading model...")
path = huggingface_hub.snapshot_download(MODEL_REPO)
print("Loading model...")
model = VisionModel.load_model(path)
# Switch to evaluation mode for inference (presumably a torch nn.Module, where
# this changes the behaviour of layers such as dropout).
model.eval()

# Tag vocabulary shipped alongside the model: one tag per line, blank lines skipped.
# Position i in this list names the i-th tag score produced in predict().
# Decode explicitly as UTF-8 so non-ASCII tag names read identically on every
# platform instead of depending on the locale's default encoding.
with open(Path(path) / 'top_tags.txt', 'r', encoding='utf-8') as f:
	top_tags = [tag for tag in (line.strip() for line in f) if tag]

print("Starting server...")

# Web UI: a single image input (file upload or webcam capture) passed to
# predict() as a PIL image, and a label widget showing the five
# highest-confidence tags from the returned mapping.
gradio_app = gr.Interface(
	fn=predict,
	title="JoyTag",
	inputs=gr.Image(
		label="Source",
		sources=['upload', 'webcam'],
		type='pil',
	),
	outputs=[
		gr.Label(
			label="Result",
			num_top_classes=5,
		),
	],
)


# Start serving only when run as a script; importing this module does not
# launch the server.
if __name__ == '__main__':
	gradio_app.launch()