# mrdbourke's picture
# Upload 2 files
# df8b8a4 verified
# raw
# history blame
# No virus
# 2.8 kB
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
from PIL import Image
# Identifier handed to `from_pretrained` for both the image processor and the model.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_and_real_data"

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the preprocessing pipeline and the detector, placing the model on `device`.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Lookup from predicted class index to class name, taken from the model config.
id2label = model.config.id2label

# Outline colour used when drawing a box for each class name.
color_dict = {"not_trash": "red", "bin": "green", "trash": "blue", "hand": "purple"}
def predict_on_image(image, conf_threshold=0.25):
    """Run object detection on an image and draw the detections onto it.

    The image is preprocessed with the module-level ``image_processor``, passed
    through ``model`` on ``device``, and the raw outputs are post-processed into
    boxes, scores and labels. Each detection is drawn as a coloured rectangle
    with a ``"<label> (<score>)"`` caption at its top-left corner.

    Args:
        image: A PIL image (``.size`` is ``(width, height)``), or ``None``.
            The boxes are drawn directly onto this image object.
        conf_threshold: Score threshold forwarded to the post-processor as
            ``threshold``.

    Returns:
        The same image object with detections drawn on it, or ``None`` when
        ``image`` is ``None``.
    """
    # Nothing to detect on (e.g. the form was submitted without an image):
    # return early instead of failing on `image.size` below.
    if image is None:
        return None

    width, height = image.size

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        # The post-processor wants (height, width); PIL reports (width, height).
        target_sizes = torch.tensor([[height, width]])
        results = image_processor.post_process_object_detection(
            outputs,
            threshold=conf_threshold,
            target_sizes=target_sizes,
        )[0]

    # Bring every result tensor back to the CPU before reading values from it.
    results = {key: value.cpu() for key, value in results.items()}

    draw = ImageDraw.Draw(image)
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()

        label_name = id2label[label.item()]
        # The model may predict classes that have no entry in `color_dict`;
        # fall back to a default colour rather than raising KeyError.
        targ_color = color_dict.get(label_name, "yellow")

        draw.rectangle(xy=(x, y, x2, y2), outline=targ_color, width=3)

        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
        draw.text(xy=(x, y), text=text_string_to_show, fill="white")

    return image
# Gradio interface: an uploaded PIL image and a confidence slider are passed to
# `predict_on_image`, and the image it returns is shown as the output.
demo = gr.Interface(
    title="🚮 Trashify Object Detection Demo (real and synthetic data)",
    description="Upload an image to detect whether there's a bin, a hand or trash in it. Trained on a mixture of real and synthetic data.",
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="pil"),
)

# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()