SkalskiP's picture
generate caption only once
1f28a9c
raw
history blame contribute delete
No virus
4.59 kB
import os
import spaces
from unittest.mock import patch
import gradio as gr
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor
from utils.imports import fixed_get_imports
from utils.models import (
run_captioning,
CAPTIONING_TASK,
run_caption_to_phrase_grounding
)
from utils.video import (
create_directory,
remove_files_older_than,
generate_file_name,
calculate_end_frame_index
)
# Header rendered at the top of the Gradio page: the app title followed by
# badge links to the Colab fine-tuning notebook, the Roboflow blog post and
# the Florence-2 arXiv paper.
MARKDOWN = """
# Florence-2 for Videos 🎬
<div>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-florence-2-on-detection-dataset.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/florence-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://arxiv.org/abs/2311.06242">
<img src="https://img.shields.io/badge/arXiv-2311.06242-b31b1b.svg" alt="arXiv" style="display:inline-block;">
</a>
</div>
"""
# Directory that receives the annotated output videos.
RESULTS = "results"
# Hugging Face Hub checkpoint used for both captioning and phrase grounding.
CHECKPOINT = "microsoft/Florence-2-base-ft"
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint ships its own modelling code (trust_remote_code=True, i.e.
# code downloaded from the Hub is executed). `get_imports` is patched while
# loading. Presumably fixed_get_imports drops an import that the remote code
# declares but this environment does not need/provide — TODO confirm against
# utils.imports.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    MODEL = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT, trust_remote_code=True).to(DEVICE)
    PROCESSOR = AutoProcessor.from_pretrained(
        CHECKPOINT, trust_remote_code=True)

# Boxes and labels are coloured per tracker id (ColorLookup.TRACK), so an
# object keeps its colour across frames.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.TRACK)
LABEL_ANNOTATOR = sv.LabelAnnotator(color_lookup=sv.ColorLookup.TRACK)
# Single module-level tracker shared by all requests; process_video resets it
# at the start of each run.
TRACKER = sv.ByteTrack()

# creating video results directory
create_directory(directory_path=RESULTS)
def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections
) -> np.ndarray:
    """Return a copy of ``input_image`` with ``detections`` drawn on it.

    The caller's array is never modified. Bounding boxes are drawn first and
    labels second, so label pixels end up on top of the box outlines.

    Args:
        input_image: Frame to annotate.
        detections: Tracked detections whose boxes and labels are rendered.

    Returns:
        The annotated copy of the frame.
    """
    annotated = input_image.copy()
    # Order matters: labels are painted over the boxes.
    for annotator in (BOUNDING_BOX_ANNOTATOR, LABEL_ANNOTATOR):
        annotated = annotator.annotate(annotated, detections)
    return annotated
@spaces.GPU
def process_video(
    input_video: str,
    progress=gr.Progress(track_tqdm=True)
) -> str:
    """Caption a video once, then ground and track that caption on every sample.

    Every ``frame_stride``-th frame of ``input_video`` is sampled. The first
    sampled frame is captioned with Florence-2. That single caption is then
    grounded on each sampled frame, the resulting boxes are tracked with the
    shared ByteTrack instance, and the annotated frames are written to a new
    MP4 file under ``RESULTS``.

    Args:
        input_video: Path to the uploaded video. Gradio passes ``None`` when
            the user submits without uploading anything.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            tqdm bar below into the web UI.

    Returns:
        Path of the annotated output video.

    Raises:
        gr.Error: If no input video was provided.
    """
    from itertools import islice

    if not input_video:
        raise gr.Error("Please upload a video before submitting.")

    # cleanup of old video files
    remove_files_older_than(RESULTS, 30)

    # Only every `frame_stride`-th frame is processed. The output FPS is
    # divided by the same factor, so the result keeps roughly the input's
    # duration.
    frame_stride = 4
    video_info = sv.VideoInfo.from_video_path(input_video)
    # Clamp to 1: integer division yields 0 FPS for sources slower than
    # `frame_stride` FPS, and a 0-FPS writer cannot produce a valid video.
    video_info.fps = max(1, video_info.fps // frame_stride)

    total = calculate_end_frame_index(input_video, frame_stride)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total,
        stride=frame_stride
    )
    num_frames = total // frame_stride

    result_file_name = generate_file_name(extension="mp4")
    result_file_path = os.path.join(RESULTS, result_file_name)

    # The tracker is module-level; clear ids/state left over from a prior run.
    TRACKER.reset()
    caption = None
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # islice caps the loop at `num_frames` but also stops cleanly if the
        # decoder yields fewer frames than the container metadata promised.
        # Calling next() by hand would raise StopIteration instead.
        sampled_frames = islice(frame_generator, num_frames)
        for frame in tqdm(sampled_frames, total=num_frames, desc="Processing video..."):
            # Caption only the first sampled frame and reuse it afterwards.
            if caption is None:
                caption = run_captioning(
                    model=MODEL,
                    processor=PROCESSOR,
                    image=frame,
                    device=DEVICE
                )[CAPTIONING_TASK]
            detections = run_caption_to_phrase_grounding(
                model=MODEL,
                processor=PROCESSOR,
                caption=caption,
                image=frame,
                device=DEVICE
            )
            # Give every grounded box a constant confidence and a single
            # integer class id before handing the detections to the tracker.
            detections.confidence = np.ones(len(detections))
            detections.class_id = np.zeros(len(detections), dtype=int)
            detections = TRACKER.update_with_detections(detections)
            frame = annotate_image(
                input_image=frame,
                detections=detections
            )
            sink.write_frame(frame)
    return result_file_path
# Web UI: header, an input/output video pair side by side, and a submit button
# that runs `process_video` on the uploaded clip.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # First row: source clip on the left, annotated result on the right.
    with gr.Row():
        input_video_component = gr.Video(label='Input Video')
        output_video_component = gr.Video(label='Output Video')

    # Second row: the single action button.
    with gr.Row():
        submit_button_component = gr.Button(
            value='Submit', scale=1, variant='primary'
        )

    # Event wiring must live inside the Blocks context.
    submit_button_component.click(
        fn=process_video,
        inputs=[input_video_component],
        outputs=output_video_component,
    )

# max_threads=1 caps Gradio's worker threads at one. Presumably this keeps the
# shared module-level TRACKER from being used by concurrent requests — confirm.
demo.launch(debug=False, show_error=True, max_threads=1)