import os import json import gradio as gr import tempfile import torch import spaces from pathlib import Path from transformers import AutoProcessor, AutoModelForVision2Seq import subprocess import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def load_examples(json_path: str) -> dict: with open(json_path, 'r') as f: return json.load(f) def format_duration(seconds: int) -> str: hours = seconds // 3600 minutes = (seconds % 3600) // 60 secs = seconds % 60 if hours > 0: return f"{hours}:{minutes:02d}:{secs:02d}" return f"{minutes}:{secs:02d}" def get_video_duration_seconds(video_path: str) -> float: """Use ffprobe to get video duration in seconds.""" cmd = [ "ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", video_path ] result = subprocess.run(cmd, capture_output=True, text=True) info = json.loads(result.stdout) return float(info["format"]["duration"]) class VideoHighlightDetector: def __init__( self, model_path: str, device: str = "cuda", batch_size: int = 8 ): self.device = device self.batch_size = batch_size # Initialize model and processor self.processor = AutoProcessor.from_pretrained(model_path) self.model = AutoModelForVision2Seq.from_pretrained( model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ).to(device) def analyze_video_content(self, video_path: str) -> str: """Analyze video content to determine its type and description.""" messages = [ { "role": "user", "content": [ {"type": "video", "path": video_path}, {"type": "text", "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."} ] } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(self.device) outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7) return self.processor.decode(outputs[0], skip_special_tokens=True) def determine_highlights(self, video_description: str) -> str: """Determine what constitutes highlights based on video description.""" messages = [ { "role": "system", "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}] }, { "role": "user", "content": [{"type": "text", "text": f"""Based on this video description: {video_description} List which rare segments should be included in a best of the best highlight."""}] } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(self.device) outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7) return self.processor.decode(outputs[0], skip_special_tokens=True) def process_segment(self, video_path: str, highlight_types: str) -> bool: """Process a video segment and determine if it contains highlights.""" messages = [ { "role": "user", "content": [ {"type": "video", "path": video_path}, {"type": "text", "text": f"""Do you see any of the following types of highlight moments in this video segment? Potential highlights to look for: {highlight_types} Only answer yes if you see any of those moments and answer no if you don't."""} ] } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(self.device) outputs = self.model.generate(**inputs, max_new_tokens=64, do_sample=False) response = self.processor.decode(outputs[0], skip_special_tokens=True).lower() return "yes" in response def _concatenate_scenes( self, video_path: str, scene_times: list, output_path: str ): """Concatenate selected scenes into final video.""" if not scene_times: logger.warning("No scenes to concatenate, skipping.") return filter_complex_parts = [] concat_inputs = [] for i, (start_sec, end_sec) in enumerate(scene_times): filter_complex_parts.append( f"[0:v]trim=start={start_sec}:end={end_sec}," f"setpts=PTS-STARTPTS[v{i}];" ) filter_complex_parts.append( f"[0:a]atrim=start={start_sec}:end={end_sec}," f"asetpts=PTS-STARTPTS[a{i}];" ) concat_inputs.append(f"[v{i}][a{i}]") concat_filter = f"{''.join(concat_inputs)}concat=n={len(scene_times)}:v=1:a=1[outv][outa]" filter_complex = "".join(filter_complex_parts) + concat_filter cmd = [ "ffmpeg", "-y", "-i", video_path, "-filter_complex", filter_complex, "-map", "[outv]", "-map", "[outa]", "-c:v", "libx264", "-c:a", "aac", output_path ] logger.info(f"Running ffmpeg command: {' '.join(cmd)}") subprocess.run(cmd, check=True) def create_ui(examples_path: str, model_path: str): examples_data = load_examples(examples_path) with gr.Blocks() as app: gr.Markdown("# Video Highlight Generator") gr.Markdown("Upload a video and get an automated highlight reel!") with gr.Row(): gr.Markdown("## Example Results") with gr.Row(): for example in examples_data["examples"]: with gr.Column(): gr.Video( value=example["original"]["url"], label=f"Original ({format_duration(example['original']['duration_seconds'])})", interactive=False ) gr.Markdown(f"### {example['title']}") with gr.Column(): gr.Video( value=example["highlights"]["url"], label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})", interactive=False ) with gr.Accordion("Chain of thought details", open=False): gr.Markdown(f"### Summary:\n{example['analysis']['video_description']}") gr.Markdown(f"### Highlights to search for:\n{example['analysis']['highlight_types']}") gr.Markdown("## Try It Yourself!") with gr.Row(): with gr.Column(scale=1): input_video = gr.Video( label="Upload your video (max 30 minutes)", interactive=True ) process_btn = gr.Button("Process Video", variant="primary") with gr.Column(scale=1): output_video = gr.Video( label="Highlight Video", visible=False, interactive=False, ) status = gr.Markdown() analysis_accordion = gr.Accordion( "Chain of thought details", open=True, visible=False ) with analysis_accordion: video_description = gr.Markdown("", elem_id="video_desc") highlight_types = gr.Markdown("", elem_id="highlight_types") @spaces.GPU def on_process(video): # Clear all components when starting new processing yield [ "", # Clear status "", # Clear video description "", # Clear highlight types gr.update(value=None, visible=False), # Clear video gr.update(visible=False) # Hide accordion ] if not video: yield [ "Please upload a video", "", "", gr.update(visible=False), gr.update(visible=False) ] return try: duration = get_video_duration_seconds(video) if duration > 1800: # 30 minutes yield [ "Video must be shorter than 30 minutes", "", "", gr.update(visible=False), gr.update(visible=False) ] return yield [ "Initializing video highlight detector...", "", "", gr.update(visible=False), gr.update(visible=False) ] detector = VideoHighlightDetector( model_path=model_path, batch_size=8 ) yield [ "Analyzing video content...", "", "", gr.update(visible=False), gr.update(visible=True) ] video_desc = detector.analyze_video_content(video) formatted_desc = f"### Summary:\n {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}" yield [ "Determining highlight types...", formatted_desc, "", gr.update(visible=False), gr.update(visible=True) ] highlights = detector.determine_highlights(video_desc) formatted_highlights = f"### Highlights to search for:\n {highlights[:500] + '...' if len(highlights) > 500 else highlights}" # Split video into segments temp_dir = "temp_segments" os.makedirs(temp_dir, exist_ok=True) segment_length = 10.0 duration = get_video_duration_seconds(video) kept_segments = [] segments_processed = 0 total_segments = int(duration / segment_length) for start_time in range(0, int(duration), int(segment_length)): segments_processed += 1 progress = int((segments_processed / total_segments) * 100) yield [ f"Processing segments... {progress}% complete", formatted_desc, formatted_highlights, gr.update(visible=False), gr.update(visible=True) ] # Create segment segment_path = f"{temp_dir}/segment_{start_time}.mp4" end_time = min(start_time + segment_length, duration) cmd = [ "ffmpeg", "-y", "-i", video, "-ss", str(start_time), "-t", str(segment_length), "-c", "copy", segment_path ] subprocess.run(cmd, check=True) # Process segment if detector.process_segment(segment_path, highlights): kept_segments.append((start_time, end_time)) # Clean up segment file os.remove(segment_path) # Remove temp directory os.rmdir(temp_dir) # Create final video if kept_segments: with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file: temp_output = tmp_file.name detector._concatenate_scenes(video, kept_segments, temp_output) yield [ "Processing complete!", formatted_desc, formatted_highlights, gr.update(value=temp_output, visible=True), gr.update(visible=True) ] else: yield [ "No highlights detected in the video.", formatted_desc, formatted_highlights, gr.update(visible=False), gr.update(visible=True) ] except Exception as e: logger.exception("Error processing video") yield [ f"Error processing video: {str(e)}", "", "", gr.update(visible=False), gr.update(visible=False) ] finally: # Clean up torch.cuda.empty_cache() process_btn.click( on_process, inputs=[input_video], outputs=[ status, video_description, highlight_types, output_video, analysis_accordion ], queue=True, ) return app if __name__ == "__main__": # Initialize CUDA device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') MODEL_PATH = os.getenv("MODEL_PATH", "HuggingFaceTB/SmolVLM2-2.2B-Instruct") app = create_ui("video_spec.json", MODEL_PATH) app.launch()