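"""Gradio demo for the Trashify object detection model.

Loads a DETR model fine-tuned on a mixture of real and synthetic data and
draws predicted bounding boxes (bin, hand, trash, not_trash) on an uploaded image.
"""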
import gradio as gr
import torch
from PIL import Image, ImageDraw

from transformers import AutoImageProcessor, AutoModelForObjectDetection

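# Fine-tuned Trashify DETR checkpoint on the Hugging Face Hub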
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_and_real_data"

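# Load the image processor and the fine-tuned object detection model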
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

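# Map label IDs to class names and pick a drawing colour for each class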
id2label = model.config.id2label
color_dict = {
    "not_trash": "red",
    "bin": "green",
    "trash": "blue",
    "hand": "purple"
}

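# Use a GPU if one is available, otherwise fall back to the CPU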
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def predict_on_image(image, conf_threshold=0.25):
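    """Detect objects in a PIL image and draw the predicted boxes, scores and labels."""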
    with torch.no_grad():
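        # Preprocess the image and run it through the model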
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # height, width 

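        # Convert raw outputs into boxes, scores and labels above the confidence threshold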
        results = image_processor.post_process_object_detection(outputs,
                                                                threshold=conf_threshold,
                                                                target_sizes=target_sizes)[0]
    # Move the result tensors (boxes, scores, labels) back to the CPU for plotting
    for key, value in results.items():
        results[key] = value.cpu()

    # Can return results as plotted on a PIL image (then display the image)
    draw = ImageDraw.Draw(image)

    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Unpack the box coordinates (top-left and bottom-right corners)
        x1, y1, x2, y2 = box.tolist()

        # Get label_name
        label_name = id2label[label.item()]
        targ_color = color_dict[label_name]

        # Draw the rectangle
        draw.rectangle(xy=(x1, y1, x2, y2),
                       outline=targ_color,
                       width=3)
        
        # Create a text string to display
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"

        # Draw the text on the image
        draw.text(xy=(x1, y1),
                  text=text_string_to_show,
                  fill="white")
    
    # Clean up the drawing object now that the annotations are on the image
    del draw
    
    return image

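# Build the Gradio interface: image upload + confidence slider in, annotated image out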
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="pil"),
    title="🚮 Trashify Object Detection Demo (real and synthetic data)",
    description="Upload an image to detect whether there's a bin, a hand or trash in it. Trained on a mixture of real and synthetic data."
)

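# Launch the demo when the script is run directly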
if __name__ == "__main__":
    demo.launch()