import os
import json
import tempfile
import time
from pathlib import Path
from typing import Tuple, Optional

import cv2
import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw, ImageFont

from video_highlight_detector import (
    load_model,
    BatchedVideoHighlightDetector,
    get_video_duration_seconds,
    get_fixed_30s_segments
)
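# `video_highlight_detector` is the local helper module shipped alongside this app.
# It loads the SmolVLM2 model/processor and provides the segmenting and
# concatenation utilities used by the callback below.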

def load_examples(json_path: str) -> dict:
    with open(json_path, 'r') as f:
        return json.load(f)

def format_duration(seconds: int) -> str:
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    secs = seconds % 60
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"

def create_ui(examples_path: str):
    examples_data = load_examples(examples_path)

    with gr.Blocks() as app:
        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/SmolVLM2-highlight-generator.png"
        gr.Image(value=img_url, height=300, show_label=False)
        gr.Markdown("Upload a video and get an automated highlight reel!")

        with gr.Row():
            gr.Markdown("## Example Results")

        with gr.Row():
            for example in examples_data["examples"]:
                with gr.Column():
                    gr.Video(
                        value=example["original"]["url"],
                        label=f"Original ({format_duration(example['original']['duration_seconds'])})",
                        interactive=False
                    )
                    gr.Markdown(f"### {example['title']}")

                with gr.Column():
                    gr.Video(
                        value=example["highlights"]["url"],
                        label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
                        interactive=False
                    )
                    with gr.Accordion("Chain of thought details", open=False):
                        gr.Markdown(f"### Summary:\n{example['analysis']['video_description']}")
                        gr.Markdown(f"### Highlights to search for:\n{example['analysis']['highlight_types']}")

        gr.Markdown("## Try It Yourself!")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 30 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Highlight Video",
                    visible=False,
                    interactive=False,
                )
                status = gr.Markdown()
                analysis_accordion = gr.Accordion(
                    "Chain of thought details",
                    open=True,
                    visible=False
                )
                with analysis_accordion:
                    video_description = gr.Markdown("", elem_id="video_desc")
                    highlight_types = gr.Markdown("", elem_id="highlight_types")
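
        # NOTE: Gradio treats generator callbacks as streaming updates. Each `yield`
        # below pushes an intermediate state to [status, video_description,
        # highlight_types, output_video, analysis_accordion], in the same order as the
        # `outputs` list wired up in `process_btn.click` further down.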
        def on_process(video):
            # Track the model so the `finally` block can free it even if loading
            # fails or we bail out early.
            model = None

            # Clear all components when starting new processing
            yield [
                "",  # Clear status
                "",  # Clear video description
                "",  # Clear highlight types
                gr.update(value=None, visible=False),  # Clear video
                gr.update(visible=False)  # Hide accordion
            ]

            if not video:
                yield [
                    "Please upload a video",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]
                return

            try:
                duration = get_video_duration_seconds(video)
                if duration > 1800:  # 30 minutes
                    yield [
                        "Video must be shorter than 30 minutes",
                        "",
                        "",
                        gr.update(visible=False),
                        gr.update(visible=False)
                    ]
                    return

                yield [
                    "Loading model...",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]

                model, processor = load_model()
                detector = BatchedVideoHighlightDetector(
                    model,
                    processor,
                    batch_size=8
                )

                # Make the accordion visible as soon as analysis starts
                yield [
                    "Analyzing video content...",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=True)
                ]

                video_desc = detector.analyze_video_content(video)
                formatted_desc = f"### Summary:\n {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"

                yield [
                    "Determining highlight types...",
                    formatted_desc,
                    "",
                    gr.update(visible=False),
                    gr.update(visible=True)
                ]

                highlights = detector.determine_highlights(video_desc)
                formatted_highlights = f"### Highlights to search for:\n {highlights[:500] + '...' if len(highlights) > 500 else highlights}"

                # Get all fixed 30-second segments
                segments = get_fixed_30s_segments(video)
                total_segments = len(segments)
                kept_segments = []

                # Process segments in batches with direct UI updates
                for i in range(0, len(segments), detector.batch_size):
                    batch_segments = segments[i:i + detector.batch_size]

                    # Update progress
                    progress = int((i / total_segments) * 100)
                    yield [
                        f"Processing segments... {progress}% complete",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(visible=False),
                        gr.update(visible=True)
                    ]

                    # Process batch
                    keep_flags = detector._process_segment_batch(
                        video_path=video,
                        segments=batch_segments,
                        highlight_types=highlights,
                        total_segments=total_segments,
                        segments_processed=i
                    )

                    # Keep track of segments to include
                    for segment, keep in zip(batch_segments, keep_flags):
                        if keep:
                            kept_segments.append(segment)

                # Create the final highlight video
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                    temp_output = tmp_file.name

                detector._concatenate_scenes(video, kept_segments, temp_output)

                yield [
                    "Processing complete!",
                    formatted_desc,
                    formatted_highlights,
                    gr.update(value=temp_output, visible=True),
                    gr.update(visible=True)
                ]
            except Exception as e:
                yield [
                    f"Error processing video: {str(e)}",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]
            finally:
                # Free GPU memory between runs
                if model is not None:
                    del model
                    torch.cuda.empty_cache()

        process_btn.click(
            on_process,
            inputs=[input_video],
            outputs=[
                status,
                video_description,
                highlight_types,
                output_video,
                analysis_accordion
            ],
            queue=True,
        )

    return app


if __name__ == "__main__":
    # Initialize CUDA
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    zero = torch.Tensor([0]).to(device)

    app = create_ui("video_spec.json")
    app.launch()