from typing import List
import os
import uuid

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

import gradio as gr
import spaces
import supervision as sv
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
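
# OWLv2 is an open-vocabulary detector: at inference time it scores image regions against the
# free-text labels passed through the processor, so no task-specific training is required.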
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
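
# Supervision annotators, created once and reused for every frame.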
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
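

# Only the first two seconds of the uploaded video are processed, to keep the demo responsive.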
def calculate_end_frame_index(source_video_path):
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return min(video_info.total_frames, video_info.fps * 2)
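

# annotate_image overlays segmentation masks (when present), bounding boxes, and text labels
# on a single frame and returns the annotated array.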
def annotate_image(input_image, detections, labels) -> np.ndarray:
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image
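

# process_video streams frames from the uploaded clip, runs zero-shot detection on each frame,
# draws the annotations, and writes the result to a new .mp4 file that Gradio displays.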
@spaces.GPU
def process_video(input_video, labels, progress=gr.Progress(track_tqdm=True)):
    # Turn the comma-separated label string into a list of clean text queries.
    labels = [label.strip() for label in labels.split(",")]
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)

    result_file_name = f"{uuid.uuid4()}.mp4"
    result_file_path = os.path.join("./", result_file_name)
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video.."):
            frame = next(frame_generator)

            results = query(frame, labels)
            detections = sv.Detections.from_transformers(results[0])
            # Map the predicted class indices back to the user-provided label strings.
            final_labels = [labels[class_id] for class_id in results[0]["labels"]]
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(frame)
    return result_file_path
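

# query runs a single OWLv2 forward pass on one frame and post-processes the raw outputs into
# per-image dictionaries of boxes, confidence scores, and label indices (score threshold 0.3).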
def query(image, texts):
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Post-processing expects the original (height, width) of each image; the frame is H x W x C.
    target_sizes = torch.Tensor([image.shape[:-1]])
    results = processor.post_process_object_detection(
        outputs=outputs, threshold=0.3, target_sizes=target_sizes
    )
    return results
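

# Gradio UI: input/output video players, a textbox for comma-separated labels, a submit button,
# and a bundled example clip.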
with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Object Tracking with OWLv2 🦉")
    gr.Markdown("This is a demo for zero-shot object tracking using the [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) model by Google.")
    gr.Markdown("Simply upload a video and enter the candidate labels, or try the example below. 👇")
    with gr.Tab(label="Video"):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Output Video")
        with gr.Row():
            candidate_labels = gr.Textbox(
                label="Labels",
                placeholder="Labels separated by a comma",
            )
            submit = gr.Button()
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "dog,cat"]],
            inputs=[input_video, candidate_labels],
            outputs=output_video,
        )

    submit.click(
        fn=process_video,
        inputs=[input_video, candidate_labels],
        outputs=output_video,
    )

demo.launch(debug=False, show_error=True)
|