Spaces:
Running
Running
File size: 2,802 Bytes
df8b8a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
from PIL import Image
# Identifier handed to ``from_pretrained`` below. Despite the variable name it
# looks like a Hugging Face Hub repo id rather than a local path.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_and_real_data"

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pre/post-processing pipeline and detection model from the same checkpoint;
# the model is moved to ``device`` straight after loading.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Integer class index -> class name, as stored in the checkpoint's config.
id2label = model.config.id2label

# Outline colour used when drawing boxes for each class name.
color_dict = dict(not_trash="red", bin="green", trash="blue", hand="purple")
def predict_on_image(image, conf_threshold=0.25):
    """Detect objects in *image* and return an annotated copy of it.

    Uses the module-level ``image_processor``, ``model``, ``device``,
    ``id2label`` and ``color_dict``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        conf_threshold: Minimum detection score to keep; forwarded as
            ``threshold`` to ``post_process_object_detection``.

    Returns:
        A new RGB PIL image with one rectangle and one ``"label (score)"``
        caption per kept detection, or ``None`` if no image was supplied.
        The caller's image object is never modified.
    """
    if image is None:
        # Nothing uploaded: an empty output is the sensible response.
        return None

    # convert() always returns a new image, so the drawing below never
    # mutates the caller's object. It also normalises RGBA/L/P uploads to
    # the 3-channel input the image processor is configured for.
    image = image.convert("RGB")

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

    # Post-processing wants (height, width); PIL's .size is (width, height).
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])
    results = image_processor.post_process_object_detection(
        outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]

    # Bring every result tensor back to the CPU. The previous
    # `.item().cpu()` / bare-except dance always ended in `value.cpu()`
    # anyway, but it hid any real error raised here.
    results = {key: value.cpu() for key, value in results.items()}

    draw = ImageDraw.Draw(image)
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        # Labels absent from color_dict get a fallback colour instead of
        # raising KeyError mid-drawing.
        outline_color = color_dict.get(label_name, "yellow")
        draw.rectangle(xy=(x, y, x2, y2), outline=outline_color, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
        )

    return image
def _build_demo():
    """Assemble the Gradio UI that wraps ``predict_on_image``.

    Inputs are an uploaded PIL image and a 0-1 confidence slider; the output
    is the annotated PIL image that ``predict_on_image`` returns.
    """
    upload_box = gr.Image(type="pil", label="Upload Target Image")
    threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
    annotated_view = gr.Image(type="pil")

    return gr.Interface(
        fn=predict_on_image,
        inputs=[upload_box, threshold_slider],
        outputs=annotated_view,
        title="🚮 Trashify Object Detection Demo (real and synthetic data)",
        description="Upload an image to detect whether there's a bin, a hand or trash in it. Trained on a mixture of real and synthetic data.",
    )


demo = _build_demo()

# Only start the web server when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|