import json
from pprint import pprint

import torch
import torch.hub
from gradio import Interface, inputs, outputs
from PIL import Image
from torchvision import transforms
# Patch torch.hub so downloaded checkpoints are always mapped onto the CPU.
real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    kwargs["map_location"] = "cpu"
    return real_load(*args, **kwargs)


torch.hub.load_state_dict_from_url = load_state_dict_from_url
# Load the Danbooru-pretrained ResNet-50 tagger and the list of tag names
# (entry i of tags.json labels output unit i of the model).
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50")
model.eval()

with open("./tags.json", "rt", encoding="utf-8") as f:
    tags = json.load(f)
def main(input_image: Image.Image, threshold: float):
    # Preprocess with the normalisation statistics of the pretrained model.
    preprocess = transforms.Compose(
        [
            transforms.Resize(360),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
            ),
        ]
    )
    input_tensor = preprocess(input_image)
    # Create a mini-batch of size 1, as expected by the model.
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        output, *_ = model(input_batch)

    # Multi-label prediction: an independent sigmoid confidence per tag.
    probs = torch.sigmoid(output)
    results = probs[probs > threshold]
    inds = probs.argsort(descending=True)

    # Keep only the tags whose confidence exceeds the threshold, highest first.
    tag_confidences = {}
    for index in inds[: len(results)]:
        tag_confidences[tags[index]] = float(probs[index].cpu().numpy())

    pprint(tag_confidences)
    return tag_confidences
image = inputs.Image(label="Upload your image here!", type="pil")
threshold = inputs.Slider(
    label="Hide tags with confidence under", minimum=0, maximum=1, default=0.2
)
labels = outputs.Label(label="Tags", type="confidences")

interface = Interface(main, inputs=[image, threshold], outputs=[labels])
interface.launch()