import time
import uuid

import cv2
import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
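
# Load the OmDet-Turbo processor and model once at startup so every
# request reuses the same weights.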
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "omlab/omdet-turbo-swin-tiny-hf"
).to(device)
css = """ | |
.feedback textarea {font-size: 24px !important} | |
""" | |
classes = "person, bike, car"
detections = None
labels = None
threshold = 0.2
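
# The supervision annotators are stateless and reused for every frame.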
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()

SUBSAMPLE = 2  # run the detector on every 2nd frame only


def annotate_image(input_image, detections, labels) -> np.ndarray:
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image
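

# process_video is a generator: it encodes roughly one second of annotated
# video at a time into a small MP4 segment and yields it, so the streaming
# Video component can start playing while later frames are still being
# processed.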
# `spaces` is imported but otherwise unused and the demo text mentions ZeroGPU,
# so the @spaces.GPU decorator was presumably lost in extraction; it is a
# no-op when running off ZeroGPU.
@spaces.GPU
def process_video(
    input_video,
    confidence_threshold,
    classes_new,
    progress=gr.Progress(track_tqdm=True),
):
    global detections
    global labels
    global classes
    global threshold

    classes = classes_new
    threshold = confidence_threshold
    result_file_name = f"output_{uuid.uuid4()}.mp4"

    cap = cv2.VideoCapture(input_video)
    video_codec = cv2.VideoWriter_fourcc(*"mp4v")  # type: ignore
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    desired_fps = fps // SUBSAMPLE

    iterating, frame = cap.read()
    segment_file = cv2.VideoWriter(
        result_file_name, video_codec, desired_fps, (width, height)
    )  # type: ignore

    batch = []  # frames the detector will actually run on
    frames = []  # every frame, for annotation and writing
    predict_index = []  # positions in `frames` that have fresh predictions
    n_frames = 0

    while iterating:
        # frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
        if n_frames % SUBSAMPLE == 0:
            predict_index.append(len(frames))
            batch.append(frame)
        frames.append(frame)
        if len(batch) == desired_fps:
            # Split "person, bike" into stripped class names.
            classes_list = [c.strip() for c in classes.split(",")]
            # inference_rate is batches per second; multiplied by the batch
            # size it gives the effective frames per second.
            results, inference_rate = query(
                batch, classes_list, threshold, (width, height)
            )

            for i in range(len(frames)):
                if i in predict_index:
                    batch_index = predict_index.index(i)
                    detections = sv.Detections(
                        xyxy=results[batch_index]["boxes"].cpu().detach().numpy(),
                        confidence=results[batch_index]["scores"]
                        .cpu()
                        .detach()
                        .numpy(),
                        class_id=np.array(
                            [
                                classes_list.index(results_class)
                                for results_class in results[batch_index]["classes"]
                            ]
                        ),
                        data={"class_name": results[batch_index]["classes"]},
                    )
                    labels = results[batch_index]["classes"]
                # Frames between predictions reuse the most recent detections.
                frame = annotate_image(
                    input_image=frames[i],
                    detections=detections,
                    labels=labels,
                )
                segment_file.write(frame)

            segment_file.release()
            yield (
                result_file_name,
                gr.Markdown(
                    f'<h3 style="text-align: center;">Model inference FPS (batched): {inference_rate * len(batch):.2f}</h3>',
                    visible=True,
                ),
            )

            # Start a fresh segment for the next chunk.
            result_file_name = f"output_{uuid.uuid4()}.mp4"
            segment_file = cv2.VideoWriter(
                result_file_name, video_codec, desired_fps, (width, height)
            )  # type: ignore
            batch = []
            frames = []
            predict_index = []

        iterating, frame = cap.read()
        n_frames += 1

    cap.release()
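

# query runs a single batched forward pass over all frames collected for one
# segment and post-processes the raw outputs into pixel-space boxes, scores,
# and class names; it also returns the measured rate (batches per second).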
def query(frames, classes, confidence_threshold, size=(640, 480)):
    inputs = processor(
        images=frames, text=[classes] * len(frames), return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        start = time.time()
        outputs = model(**inputs)
        rate = 1 / (time.time() - start)
    # target_sizes expects (height, width) pairs, hence the reversal.
    target_sizes = torch.tensor([size[::-1]] * len(frames))
    results = processor.post_process_grounded_object_detection(
        outputs=outputs,
        classes=[classes] * len(frames),
        score_threshold=confidence_threshold,
        target_sizes=target_sizes,
    )
    return results, rate
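

# These setters are wired to the textbox and slider below; because the
# streaming generator reads the globals on every batch, edits take effect
# while a video is still being processed.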
def set_classes(classes_input):
    global classes
    classes = classes_input


def set_confidence_threshold(confidence_threshold_input):
    global threshold
    threshold = confidence_threshold_input


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("## Real Time Open Vocabulary Object Detection with OmDet-Turbo")
    gr.Markdown(
        """
        This is a demo for real-time open vocabulary object detection using OmDet-Turbo.<br>
        It runs on ZeroGPU, which allocates a GPU the first time you run inference.<br>
        Combined with the video processing time, this makes the demo feel slower than the model's actual inference speed.<br>
        The model's average inference FPS is displayed under the processed video after inference.
        """
    )
    gr.Markdown(
        "Simply upload a video or try one of the examples below, and press run. You can then change the detected objects live in the text box! You can also play with the confidence threshold and see how it impacts the objects detected in real time."
    )
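
    # Layout: input video on the left, streaming output plus live FPS readout
    # on the right; the controls below can be changed while a video plays.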
    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video")
        with gr.Column():
            output_video = gr.Video(label="Output Video", streaming=True, autoplay=True)
            actual_fps = gr.Markdown("", visible=False)
    with gr.Row():
        # Named classes_box (not classes) to avoid shadowing the module-level
        # classes string that the generator reads.
        classes_box = gr.Textbox(
            "person, cat, dog",
            label="Objects to detect. Change this as you like and press enter!",
            elem_classes="feedback",
            scale=3,
        )
        conf = gr.Slider(
            label="Confidence Threshold",
            minimum=0.1,
            maximum=1.0,
            value=0.2,
            step=0.05,
        )
    with gr.Row():
        submit = gr.Button(variant="primary")
    example = gr.Examples(
        examples=[
            ["./newyorkstreets_small.mp4", 0.3, "person, car, shoe"],
        ],
        inputs=[input_video, conf, classes_box],
        outputs=[output_video, actual_fps],
    )
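
    # Wire the live updates: submitting the textbox or moving the slider
    # updates the globals read by process_video mid-stream; the button starts
    # a run.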
    classes_box.submit(set_classes, classes_box)
    conf.change(set_confidence_threshold, conf)
    submit.click(
        fn=process_video,
        inputs=[input_video, conf, classes_box],
        outputs=[output_video, actual_fps],
    )


if __name__ == "__main__":
    demo.launch(show_error=True)