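# NOTE: the next three lines are page-header residue from the hosting site,
# kept as comments so the file remains valid Python.
# mayhug's picture
# Update app.py
# e951858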
import json
from pprint import pprint
import torch
import torch.hub
from gradio import Interface, inputs, outputs
from PIL import Image
from torchvision import transforms
# Keep a handle on the stock loader; the wrapper below delegates to it.
real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    """Call the original ``torch.hub`` loader with ``map_location`` forced to CPU.

    All positional and keyword arguments are forwarded unchanged, except
    that any caller-supplied ``map_location`` keyword is replaced by
    ``"cpu"``.

    Returns:
        Whatever the original ``torch.hub.load_state_dict_from_url`` returns.
    """
    forced_cpu = {**kwargs, "map_location": "cpu"}
    return real_load(*args, **forced_cpu)


# Install the wrapper so later lookups of torch.hub.load_state_dict_from_url
# (e.g. from code run by torch.hub.load) go through the CPU-forcing version.
torch.hub.load_state_dict_from_url = load_state_dict_from_url
# Fetch the "resnet50" entry point from the RF5/danbooru-pretrained hub repo
# and switch the returned model to evaluation mode.
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50")
model.eval()

# Tag vocabulary; `main` indexes it with the model's output positions.
# NOTE(review): presumably a JSON array of tag names in class order — confirm.
with open("./tags.json", "rt", encoding="utf-8") as tags_file:
    tags = json.loads(tags_file.read())
def main(input_image: Image.Image, threshold: float):
    """Run the tagger on one image and return the tags above ``threshold``.

    Args:
        input_image: PIL image to classify.
        threshold: Minimum sigmoid score (exclusive) a tag must exceed to be
            included in the result.

    Returns:
        A dict mapping ``tags[i]`` to its sigmoid score as a Python float,
        inserted in descending score order and limited to the entries whose
        score is strictly greater than ``threshold``. The same dict is also
        pretty-printed to stdout.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize(360),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
            ),
        ]
    )
    # Add a leading batch axis: the model is fed a one-image mini-batch.
    batch = pipeline(input_image).unsqueeze(0)

    with torch.no_grad():
        # Keep only the first element of the model's result.
        # NOTE(review): presumably the single row of the batch — confirm.
        logits, *_ = model(batch)

    probs = torch.sigmoid(logits)
    # Number of scores above the threshold; because the ranking below is in
    # descending order, its first `kept` positions are exactly those scores.
    kept = len(probs[probs > threshold])
    ranking = probs.argsort(descending=True)

    tag_confidences = {
        tags[idx]: float(probs[idx].cpu().numpy()) for idx in ranking[0:kept]
    }
    pprint(tag_confidences)
    return tag_confidences
# Input widgets: the picture to tag (handed to `main` as a PIL image) and the
# confidence cut-off slider, which ranges over [0, 1] and starts at 0.2.
image = inputs.Image(type="pil", label="Upload your image here!")
threshold = inputs.Slider(
    minimum=0,
    maximum=1,
    default=0.2,
    label="Hide images confidence under",
)

# Output widget: a label component fed with the tag -> confidence dict
# returned by `main`.
labels = outputs.Label(type="confidences", label="Tags")

# Wire the widgets to `main` and start the app.
interface = Interface(fn=main, inputs=[image, threshold], outputs=[labels])
interface.launch()