import os
import json
import gradio as gr
import tempfile
import torch
import spaces
from pathlib import Path
from transformers import AutoProcessor, AutoModelForVision2Seq
import subprocess
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_examples(json_path: str) -> dict:
    with open(json_path, 'r') as f:
        return json.load(f)

def format_duration(seconds: int) -> str:
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    secs = seconds % 60
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"

def get_video_duration_seconds(video_path: str) -> float:
    """Use ffprobe to get video duration in seconds."""
    cmd = [
        "ffprobe",
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        video_path
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    info = json.loads(result.stdout)
    return float(info["format"]["duration"])
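
# With -print_format json -show_format, ffprobe prints something like
# {"format": {"duration": "123.456", ...}}; the duration field is a string,
# hence the float() conversion above.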

class VideoHighlightDetector:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        batch_size: int = 8
    ):
        self.device = device
        self.batch_size = batch_size

        # Initialize model and processor
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).to(device)
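
    # Note: attn_implementation="flash_attention_2" requires the flash-attn package
    # and an Ampere-or-newer GPU (the A100 this Space runs on qualifies); remove the
    # argument to fall back to the default attention implementation.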

    def analyze_video_content(self, video_path: str) -> str:
        """Analyze video content to determine its type and description."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."}
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        # Decode only the newly generated tokens so the returned description
        # does not include the prompt text echoed back from the chat template.
        return self.processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    def determine_highlights(self, video_description: str) -> str:
        """Determine what constitutes highlights based on video description."""
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": f"""Based on this video description:
{video_description}
List which rare segments should be included in a best-of-the-best highlight reel."""}]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
        # Decode only the newly generated tokens (see analyze_video_content).
        return self.processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    def process_segment(self, video_path: str, highlight_types: str) -> bool:
        """Process a video segment and determine if it contains highlights."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": f"""Do you see any of the following types of highlight moments in this video segment?
Potential highlights to look for:
{highlight_types}
Only answer yes if you see any of those moments and answer no if you don't."""}
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, max_new_tokens=64, do_sample=False)
        # Decode only the generated answer: decoding the full sequence would also
        # include the prompt, which itself contains the word "yes" and would make
        # the check below trivially true.
        response = self.processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).lower()
        return "yes" in response

    def _concatenate_scenes(
        self,
        video_path: str,
        scene_times: list,
        output_path: str
    ):
        """Concatenate selected scenes into final video."""
        if not scene_times:
            logger.warning("No scenes to concatenate, skipping.")
            return

        filter_complex_parts = []
        concat_inputs = []
        for i, (start_sec, end_sec) in enumerate(scene_times):
            filter_complex_parts.append(
                f"[0:v]trim=start={start_sec}:end={end_sec},"
                f"setpts=PTS-STARTPTS[v{i}];"
            )
            filter_complex_parts.append(
                f"[0:a]atrim=start={start_sec}:end={end_sec},"
                f"asetpts=PTS-STARTPTS[a{i}];"
            )
            concat_inputs.append(f"[v{i}][a{i}]")

        concat_filter = f"{''.join(concat_inputs)}concat=n={len(scene_times)}:v=1:a=1[outv][outa]"
        filter_complex = "".join(filter_complex_parts) + concat_filter

        cmd = [
            "ffmpeg",
            "-y",
            "-i", video_path,
            "-filter_complex", filter_complex,
            "-map", "[outv]",
            "-map", "[outa]",
            "-c:v", "libx264",
            "-c:a", "aac",
            output_path
        ]
        logger.info(f"Running ffmpeg command: {' '.join(cmd)}")
        subprocess.run(cmd, check=True)
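
    # For two kept scenes, e.g. scene_times = [(0, 10.0), (30, 40.0)], the generated
    # filter_complex is (a single string, wrapped here for readability):
    #   [0:v]trim=start=0:end=10.0,setpts=PTS-STARTPTS[v0];
    #   [0:a]atrim=start=0:end=10.0,asetpts=PTS-STARTPTS[a0];
    #   [0:v]trim=start=30:end=40.0,setpts=PTS-STARTPTS[v1];
    #   [0:a]atrim=start=30:end=40.0,asetpts=PTS-STARTPTS[a1];
    #   [v0][a0][v1][a1]concat=n=2:v=1:a=1[outv][outa]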

def create_ui(examples_path: str, model_path: str):
    examples_data = load_examples(examples_path)

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Generator")
        gr.Markdown("Upload a video and get an automated highlight reel!")

        with gr.Row():
            gr.Markdown("## Example Results")

        with gr.Row():
            for example in examples_data["examples"]:
                with gr.Column():
                    gr.Video(
                        value=example["original"]["url"],
                        label=f"Original ({format_duration(example['original']['duration_seconds'])})",
                        interactive=False
                    )
                    gr.Markdown(f"### {example['title']}")
                with gr.Column():
                    gr.Video(
                        value=example["highlights"]["url"],
                        label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
                        interactive=False
                    )
                    with gr.Accordion("Chain of thought details", open=False):
                        gr.Markdown(f"### Summary:\n{example['analysis']['video_description']}")
                        gr.Markdown(f"### Highlights to search for:\n{example['analysis']['highlight_types']}")

        gr.Markdown("## Try It Yourself!")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 30 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Highlight Video",
                    visible=False,
                    interactive=False,
                )
                status = gr.Markdown()

                analysis_accordion = gr.Accordion(
                    "Chain of thought details",
                    open=True,
                    visible=False
                )
                with analysis_accordion:
                    video_description = gr.Markdown("", elem_id="video_desc")
                    highlight_types = gr.Markdown("", elem_id="highlight_types")

        def on_process(video):
            # Clear all components when starting new processing
            yield [
                "",  # Clear status
                "",  # Clear video description
                "",  # Clear highlight types
                gr.update(value=None, visible=False),  # Clear video
                gr.update(visible=False)  # Hide accordion
            ]

            if not video:
                yield [
                    "Please upload a video",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]
                return

            try:
                duration = get_video_duration_seconds(video)
                if duration > 1800:  # 30 minutes
                    yield [
                        "Video must be shorter than 30 minutes",
                        "",
                        "",
                        gr.update(visible=False),
                        gr.update(visible=False)
                    ]
                    return

                yield [
                    "Initializing video highlight detector...",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]

                detector = VideoHighlightDetector(
                    model_path=model_path,
                    batch_size=8
                )

                yield [
                    "Analyzing video content...",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=True)
                ]

                video_desc = detector.analyze_video_content(video)
                formatted_desc = f"### Summary:\n {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"

                yield [
                    "Determining highlight types...",
                    formatted_desc,
                    "",
                    gr.update(visible=False),
                    gr.update(visible=True)
                ]

                highlights = detector.determine_highlights(video_desc)
                formatted_highlights = f"### Highlights to search for:\n {highlights[:500] + '...' if len(highlights) > 500 else highlights}"

                # Split video into segments
                temp_dir = "temp_segments"
                os.makedirs(temp_dir, exist_ok=True)

                segment_length = 10.0
                duration = get_video_duration_seconds(video)
                kept_segments = []
                segments_processed = 0
                # Guard against very short videos (duration < segment_length) and
                # round-up overshoot so progress stays within 0-100%.
                total_segments = max(1, int(duration / segment_length))

                for start_time in range(0, int(duration), int(segment_length)):
                    segments_processed += 1
                    progress = min(100, int((segments_processed / total_segments) * 100))
                    yield [
                        f"Processing segments... {progress}% complete",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(visible=False),
                        gr.update(visible=True)
                    ]

                    # Create segment
                    segment_path = f"{temp_dir}/segment_{start_time}.mp4"
                    end_time = min(start_time + segment_length, duration)
                    cmd = [
                        "ffmpeg",
                        "-y",
                        "-i", video,
                        "-ss", str(start_time),
                        "-t", str(segment_length),
                        "-c", "copy",
                        segment_path
                    ]
                    subprocess.run(cmd, check=True)

                    # Process segment
                    if detector.process_segment(segment_path, highlights):
                        kept_segments.append((start_time, end_time))

                    # Clean up segment file
                    os.remove(segment_path)

                # Remove temp directory
                os.rmdir(temp_dir)

                # Create final video
                if kept_segments:
                    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                        temp_output = tmp_file.name

                    detector._concatenate_scenes(video, kept_segments, temp_output)

                    yield [
                        "Processing complete!",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(value=temp_output, visible=True),
                        gr.update(visible=True)
                    ]
                else:
                    yield [
                        "No highlights detected in the video.",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(visible=False),
                        gr.update(visible=True)
                    ]

            except Exception as e:
                logger.exception("Error processing video")
                yield [
                    f"Error processing video: {str(e)}",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]
            finally:
                # Clean up
                torch.cuda.empty_cache()
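
        # on_process is a generator, so each yield streams an intermediate update
        # (status text, partial analysis, visibility changes) to the five output
        # components wired up below while processing continues.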

        process_btn.click(
            on_process,
            inputs=[input_video],
            outputs=[
                status,
                video_description,
                highlight_types,
                output_video,
                analysis_accordion
            ],
            queue=True,
        )

    return app

if __name__ == "__main__":
    # Initialize CUDA
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    MODEL_PATH = os.getenv("MODEL_PATH", "HuggingFaceTB/SmolVLM2-2.2B-Instruct")

    app = create_ui("video_spec.json", MODEL_PATH)
    app.launch()