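"""
Florence-2 for Videos 🎬

Gradio Space that runs Microsoft's Florence-2 on a video: every second
frame is captioned, the caption is grounded back into bounding boxes, and
the boxes are tracked across frames with ByteTrack before the annotated
video is written out.
"""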
import os
import spaces
from unittest.mock import patch
import gradio as gr
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor
from utils.imports import fixed_get_imports
from utils.models import (
    run_captioning,
    CAPTIONING_TASK,
    run_caption_to_phrase_grounding
)
from utils.video import (
    create_directory,
    remove_files_older_than,
    generate_file_name,
    calculate_end_frame_index
)

MARKDOWN = """
# Florence-2 for Videos 🎬
<div>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-florence-2-on-detection-dataset.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/florence-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://arxiv.org/abs/2311.06242">
<img src="https://img.shields.io/badge/arXiv-2311.06242-b31b1b.svg" alt="arXiv" style="display:inline-block;">
</a>
</div>
"""
RESULTS = "results"
CHECKPOINT = "microsoft/Florence-2-base-ft"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
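
# Florence-2's remote code tries to import flash_attn at load time;
# patching get_imports with fixed_get_imports (see utils/imports.py)
# strips that requirement so the checkpoint also loads on machines
# without flash-attn installed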
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    MODEL = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT, trust_remote_code=True).to(DEVICE)
    PROCESSOR = AutoProcessor.from_pretrained(
        CHECKPOINT, trust_remote_code=True)
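
# ColorLookup.TRACK colors boxes and labels by tracker_id, so each object
# keeps a stable color for as long as ByteTrack follows it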
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.TRACK)
LABEL_ANNOTATOR = sv.LabelAnnotator(color_lookup=sv.ColorLookup.TRACK)
TRACKER = sv.ByteTrack()
# creating video results directory
create_directory(directory_path=RESULTS)


def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections
) -> np.ndarray:
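    """Draw track-colored bounding boxes and labels on a copy of the frame."""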
    output_image = input_image.copy()
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
    return output_image


@spaces.GPU
def process_video(
    input_video: str,
    progress=gr.Progress(track_tqdm=True)
) -> str:
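    """Caption, ground, track and annotate every second frame of the video.

    Returns the path of the annotated MP4 written to the RESULTS directory.
    """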
    # remove stale result files so the Space's disk does not fill up
    remove_files_older_than(RESULTS, 30)

    video_info = sv.VideoInfo.from_video_path(input_video)
    # only every second frame is processed, so halve the output fps to keep
    # the annotated video playing at the original speed
    video_info.fps = video_info.fps // 2
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total,
        stride=2
    )

    result_file_name = generate_file_name(extension="mp4")
    result_file_path = os.path.join(RESULTS, result_file_name)

    TRACKER.reset()
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # with stride=2 the generator yields roughly total // 2 frames;
        # iterate it directly rather than calling next() inside
        # range(total), which would exhaust the generator halfway through
        for frame in tqdm(frame_generator, total=total // 2, desc="Processing video..."):
            caption = run_captioning(
                model=MODEL,
                processor=PROCESSOR,
                image=frame,
                device=DEVICE
            )[CAPTIONING_TASK]
            detections = run_caption_to_phrase_grounding(
                model=MODEL,
                processor=PROCESSOR,
                caption=caption,
                image=frame,
                device=DEVICE
            )
            # phrase grounding returns neither scores nor class ids; fill
            # in dummy values (with integer class ids) so ByteTrack can
            # consume the detections
            detections.confidence = np.ones(len(detections))
            detections.class_id = np.zeros(len(detections), dtype=int)
            detections = TRACKER.update_with_detections(detections)
            frame = annotate_image(
                input_image=frame,
                detections=detections
            )
            sink.write_frame(frame)
    return result_file_path
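

# minimal Gradio UI: a video input, a video output and a Submit button
# wired to process_video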
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        input_video_component = gr.Video(
            label='Input Video'
        )
        output_video_component = gr.Video(
            label='Output Video'
        )
    with gr.Row():
        submit_button_component = gr.Button(
            value='Submit',
            scale=1,
            variant='primary'
        )

    submit_button_component.click(
        fn=process_video,
        inputs=[
            input_video_component,
        ],
        outputs=output_video_component
    )
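
# max_threads=1 serializes request handling, which keeps the globally
# shared MODEL and TRACKER state from being touched by concurrent videos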
demo.launch(debug=False, show_error=True, max_threads=1)