import time
import uuid

import cv2
import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch

from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "omlab/omdet-turbo-swin-tiny-hf"
).to(device)

css = """
.feedback textarea {font-size: 24px !important}
"""

# Default detection settings. These module-level values are shared with the streaming
# loop and are updated live by the UI callbacks further below.
classes = "person, bike, car"
detections = None
labels = None
threshold = 0.2

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
SUBSAMPLE = 2
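# Only every SUBSAMPLE-th input frame is kept: those frames are batched for inference
# and written to the output video at fps // SUBSAMPLE.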


def annotate_image(input_image, detections, labels) -> np.ndarray:
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image


@spaces.GPU
def process_video(
    input_video,
    confidence_threshold,
    classes_new,
    progress=gr.Progress(track_tqdm=True),
):
    global detections
    global labels
    global classes
    global threshold
    classes = classes_new
    threshold = confidence_threshold
    result_file_name = f"output_{uuid.uuid4()}.mp4"

    cap = cv2.VideoCapture(input_video)

    video_codec = cv2.VideoWriter_fourcc(*"mp4v")  # type: ignore
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    desired_fps = fps // SUBSAMPLE
    iterating, frame = cap.read()
    segment_file = cv2.VideoWriter(
        result_file_name, video_codec, desired_fps, (width, height)
    )  # type: ignore
    batch = []
    frames = []
    predict_index = []
    n_frames = 0
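    # Read the video frame by frame. Every SUBSAMPLE-th frame is buffered; once a full
    # batch (one second of output video) is collected, it is run through the model and
    # an annotated segment is written and streamed to the UI.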
    while iterating:
        # frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
        if n_frames % SUBSAMPLE == 0:
            predict_index.append(len(frames))
            batch.append(frame)
            frames.append(frame)
        if len(batch) == desired_fps:
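            # A full batch (one second of output video) is ready for inference.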
            # Parse the (possibly live-updated) class prompt into a clean list of names.
            classes_list = [c.strip() for c in classes.split(",")]
            results, model_fps = query(batch, classes_list, threshold, (width, height))
            for i in range(len(frames)):
                if i in predict_index:
                    batch_index = predict_index.index(i)
                    detections = sv.Detections(
                        xyxy=results[batch_index]["boxes"].cpu().detach().numpy(),
                        confidence=results[batch_index]["scores"]
                        .cpu()
                        .detach()
                        .numpy(),
                        class_id=np.array(
                            [
                                classes_list.index(results_class)
                                for results_class in results[batch_index]["classes"]
                            ]
                        ),
                        data={"class_name": results[batch_index]["classes"]},
                    )
                    labels = results[batch_index]["classes"]
                frame = annotate_image(
                    input_image=frames[i],
                    detections=detections,
                    labels=labels,
                )
                segment_file.write(frame)
            segment_file.release()
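            # Stream the finished segment to the client along with the measured model throughput.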
            yield (
                result_file_name,
                gr.Markdown(
                    f'<h3 style="text-align: center;">Model inference FPS (batched): {model_fps * len(batch):.2f}</h3>',
                    visible=True,
                ),
            )
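            # Start a fresh segment writer and clear the buffers for the next batch.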
            result_file_name = f"output_{uuid.uuid4()}.mp4"
            segment_file = cv2.VideoWriter(
                result_file_name, video_codec, desired_fps, (width, height)
            )  # type: ignore
            batch = []
            frames = []
            predict_index = []
        iterating, frame = cap.read()
        n_frames += 1

    # Close the capture and the last (possibly empty) segment writer once the video ends.
    segment_file.release()
    cap.release()


def query(frame, classes, confidence_threshold, size=(640, 480)):
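    # frame is a list of frames (one batch); the same class prompt is repeated for every image.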
    inputs = processor(
        images=frame, text=[classes] * len(frame), return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        start = time.time()
        outputs = model(**inputs)
        fps = 1 / (time.time() - start)
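        # Forward passes per second for the whole batch; multiply by the batch size for per-frame FPS.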
    target_sizes = torch.tensor([size[::-1]] * len(frame))

    results = processor.post_process_grounded_object_detection(
        outputs=outputs,
        classes=[classes] * len(frame),
        score_threshold=confidence_threshold,
        target_sizes=target_sizes,
    )
    return results, fps


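# UI callbacks: rebinding these globals lets the detected classes and the confidence
# threshold be changed live while a video is still being processed.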
def set_classes(classes_input):
    global classes
    classes = classes_input


def set_confidence_threshold(confidence_threshold_input):
    global threshold
    threshold = confidence_threshold_input


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("## Real Time Open Vocabulary Object Detection with Omdet-Turbo")
    gr.Markdown(
        """
        This is a demo for real-time open-vocabulary object detection using OmDet-Turbo.<br>
        It runs on ZeroGPU, which has to acquire a GPU each time you start inference.<br>
        This, combined with the video processing overhead, makes the demo slower than the model's actual inference speed.<br>
        The model's average inference FPS is displayed under the processed video once inference is done.
        """
    )
    gr.Markdown(
        "Simply upload a video or try the examples below πŸ‘‡, and press run. You can then change the object detected live in the text box! You also play with the confidence threshold and see how it impacts the objects detected in real time."
    )

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video")
        with gr.Column():
            output_video = gr.Video(label="Output Video", streaming=True, autoplay=True)
            actual_fps = gr.Markdown("", visible=False)
    with gr.Row():
        classes = gr.Textbox(
            "person, cat, dog",
            label="Objects to detect. Change this as you like and press enter!",
            elem_classes="feedback",
            scale=3,
        )
        conf = gr.Slider(
            label="Confidence Threshold",
            minimum=0.1,
            maximum=1.0,
            value=0.2,
            step=0.05,
        )
    with gr.Row():
        submit = gr.Button(variant="primary")

    example = gr.Examples(
        examples=[
            ["./newyorkstreets_small.mp4", 0.3, "person, car, shoe"],
        ],
        inputs=[input_video, conf, classes],
        outputs=[output_video, actual_fps],
    )
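    # Live controls: these callbacks update the shared class / threshold settings while a video is streaming.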
    classes.submit(set_classes, classes)
    conf.change(set_confidence_threshold, conf)

    submit.click(
        fn=process_video,
        inputs=[input_video, conf, classes],
        outputs=[output_video, actual_fps],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)