import os
import time

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
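
# Load the OmDet-Turbo processor and model once at startup; the model is moved
# to the GPU when one is available so every request reuses the same weights.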
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "omlab/omdet-turbo-swin-tiny-hf"
).to(device)

css = """ | |
.feedback textarea {font-size: 24px !important} | |
""" | |
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
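
# Cap processing at the first five seconds of the source video (fps * 5 frames).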
def calculate_end_frame_index(source_video_path):
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return min(video_info.total_frames, video_info.fps * 5)
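
# Draw masks first, then boxes, then labels, so that label text stays on top.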
def annotate_image(input_image, detections, labels) -> np.ndarray:
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image
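
# Main entry point: read frames, run per-frame detection, and write the
# annotated result to output.mp4. @spaces.GPU (assumed here; standard for
# ZeroGPU Spaces) attaches a GPU to the worker for the duration of each call.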
@spaces.GPU
def process_video(
    input_video,
    confidence_threshold,
    classes,
    progress=gr.Progress(track_tqdm=True),
):
    # Parse the comma-separated class prompt, dropping surrounding whitespace.
    classes = [c.strip() for c in classes.split(",")]
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)

    result_file_name = "output.mp4"
    result_file_path = os.path.join(os.getcwd(), result_file_name)
    all_fps = []
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video.."):
            try:
                frame = next(frame_generator)
            except StopIteration:
                break
            results, fps = query(frame, classes, confidence_threshold)
            all_fps.append(fps)
            detections = sv.Detections(
                xyxy=results[0]["boxes"].cpu().detach().numpy(),
                confidence=results[0]["scores"].cpu().detach().numpy(),
                class_id=np.array(
                    [
                        classes.index(results_class)
                        for results_class in results[0]["classes"]
                    ]
                ),
                data={"class_name": results[0]["classes"]},
            )
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                labels=results[0]["classes"],
            )
            sink.write_frame(frame)

    avg_fps = np.mean(all_fps)
    return result_file_path, gr.Markdown(
        f'<h3 style="text-align: center;">Model inference FPS: {avg_fps:.2f}</h3>',
        visible=True,
    )
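
# Single-frame inference: returns the post-processed detections along with the
# FPS of the forward pass alone (excluding pre- and post-processing).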
def query(frame, classes, confidence_threshold):
    image = Image.fromarray(frame)
    inputs = processor(images=image, text=classes, return_tensors="pt").to(device)
    with torch.no_grad():
        # Time only the forward pass so the reported FPS reflects the model.
        start = time.time()
        outputs = model(**inputs)
        fps = 1 / (time.time() - start)
    target_sizes = [frame.shape[:2]]
    results = processor.post_process_grounded_object_detection(
        outputs=outputs,
        classes=classes,
        score_threshold=confidence_threshold,
        target_sizes=target_sizes,
    )
    return results, fps
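
# Gradio UI: video in/out, a textbox for the class prompt, a confidence
# slider, and bundled examples.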
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("## Real-Time Open Vocabulary Object Detection with OmDet-Turbo")
    gr.Markdown(
        """
        This is a demo for open vocabulary object detection using OmDet-Turbo.<br>
        It runs on ZeroGPU, which acquires a GPU the first time you run inference.<br>
        Combined with video processing time, this makes the demo's end-to-end latency higher than the model's actual inference time.<br>
        The model's average inference FPS is displayed under the processed video after inference.
        """
    )
    gr.Markdown(
        "Simply upload a video and write the objects you want to detect! You can also play with the confidence threshold or try the examples below."
    )
    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video")
        with gr.Column():
            output_video = gr.Video(label="Output Video (5s max)")
            actual_fps = gr.Markdown("", visible=False)
    with gr.Row():
        classes = gr.Textbox(
            "person, cat, dog",
            label="Objects to detect. Change this as you like!",
            elem_classes="feedback",
            scale=3,
        )
        conf = gr.Slider(
            label="Confidence Threshold",
            minimum=0.1,
            maximum=1.0,
            value=0.2,
            step=0.05,
        )
    with gr.Row():
        submit = gr.Button(variant="primary")
    example = gr.Examples(
        examples=[
            ["./football.mp4", 0.3, "person, ball, shoe"],
            ["./cat.mp4", 0.2, "cat"],
            ["./safari2.mp4", 0.3, "elephant, giraffe, springbok, zebra"],
        ],
        inputs=[input_video, conf, classes],
        outputs=output_video,
    )
    submit.click(
        fn=process_video,
        inputs=[input_video, conf, classes],
        outputs=[output_video, actual_fps],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)
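
# To run this locally (outside Spaces), the following dependencies are assumed;
# exact package versions are not part of the original listing:
#   gradio, spaces, supervision, torch, transformers, tqdm, pillow, numpy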