mfarre (HF staff) committed
Commit 141829b · 1 Parent(s): e43a4bd

using transformers to handle the model

Files changed (2):
  1. app.py +226 -61
  2. requirements.txt +1 -1
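
In short, app.py stops importing the custom video_highlight_detector module and loads the model directly with transformers. A minimal sketch of the new loading path, assuming a CUDA device with flash-attention available (model id and dtype taken from the diff below):

    import torch
    from transformers import AutoProcessor, AutoModelForVision2Seq

    model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"  # default MODEL_PATH in the diff
    processor = AutoProcessor.from_pretrained(model_path)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to("cuda")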
app.py CHANGED
@@ -2,23 +2,15 @@ import os
 import json
 import gradio as gr
 import tempfile
-from PIL import Image, ImageDraw, ImageFont
-import cv2
-from typing import Tuple, Optional
-import torch
-from pathlib import Path
-import time
 import torch
 import spaces
-import os
-
-from video_highlight_detector import (
-    load_model,
-    BatchedVideoHighlightDetector,
-    get_video_duration_seconds,
-    get_fixed_30s_segments
-)
+from pathlib import Path
+from transformers import AutoProcessor, AutoModelForVision2Seq
+import subprocess
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 def load_examples(json_path: str) -> dict:
     with open(json_path, 'r') as f:
@@ -32,8 +24,161 @@ def format_duration(seconds: int) -> str:
         return f"{hours}:{minutes:02d}:{secs:02d}"
     return f"{minutes}:{secs:02d}"

+def get_video_duration_seconds(video_path: str) -> float:
+    """Use ffprobe to get video duration in seconds."""
+    cmd = [
+        "ffprobe",
+        "-v", "quiet",
+        "-print_format", "json",
+        "-show_format",
+        video_path
+    ]
+    result = subprocess.run(cmd, capture_output=True, text=True)
+    info = json.loads(result.stdout)
+    return float(info["format"]["duration"])
+
+class VideoHighlightDetector:
+    def __init__(
+        self,
+        model_path: str,
+        device: str = "cuda",
+        batch_size: int = 8
+    ):
+        self.device = device
+        self.batch_size = batch_size
+
+        # Initialize model and processor
+        self.processor = AutoProcessor.from_pretrained(model_path)
+        self.model = AutoModelForVision2Seq.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2"
+        ).to(device)
+
+    def analyze_video_content(self, video_path: str) -> str:
+        """Analyze video content to determine its type and description."""
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "video", "path": video_path},
+                    {"type": "text", "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."}
+                ]
+            }
+        ]
+
+        inputs = self.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(self.device)
+
+        outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
+        return self.processor.decode(outputs[0], skip_special_tokens=True)
+
+    def determine_highlights(self, video_description: str) -> str:
+        """Determine what constitutes highlights based on video description."""
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}]
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": f"""Based on this video description:
+
+{video_description}
+
+List which rare segments should be included in a best of the best highlight."""}]
+            }
+        ]
+
+        inputs = self.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(self.device)
+
+        outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
+        return self.processor.decode(outputs[0], skip_special_tokens=True)
+
+    def process_segment(self, video_path: str, highlight_types: str) -> bool:
+        """Process a video segment and determine if it contains highlights."""
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "video", "path": video_path},
+                    {"type": "text", "text": f"""Do you see any of the following types of highlight moments in this video segment?
+
+Potential highlights to look for:
+{highlight_types}
+
+Only answer yes if you see any of those moments and answer no if you don't."""}
+                ]
+            }
+        ]
+
+        inputs = self.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(self.device)
+
+        outputs = self.model.generate(**inputs, max_new_tokens=64, do_sample=False)
+        response = self.processor.decode(outputs[0], skip_special_tokens=True).lower()
+
+        return "yes" in response
+
+    def _concatenate_scenes(
+        self,
+        video_path: str,
+        scene_times: list,
+        output_path: str
+    ):
+        """Concatenate selected scenes into final video."""
+        if not scene_times:
+            logger.warning("No scenes to concatenate, skipping.")
+            return
+
+        filter_complex_parts = []
+        concat_inputs = []
+        for i, (start_sec, end_sec) in enumerate(scene_times):
+            filter_complex_parts.append(
+                f"[0:v]trim=start={start_sec}:end={end_sec},"
+                f"setpts=PTS-STARTPTS[v{i}];"
+            )
+            filter_complex_parts.append(
+                f"[0:a]atrim=start={start_sec}:end={end_sec},"
+                f"asetpts=PTS-STARTPTS[a{i}];"
+            )
+            concat_inputs.append(f"[v{i}][a{i}]")
+
+        concat_filter = f"{''.join(concat_inputs)}concat=n={len(scene_times)}:v=1:a=1[outv][outa]"
+        filter_complex = "".join(filter_complex_parts) + concat_filter
+
+        cmd = [
+            "ffmpeg",
+            "-y",
+            "-i", video_path,
+            "-filter_complex", filter_complex,
+            "-map", "[outv]",
+            "-map", "[outa]",
+            "-c:v", "libx264",
+            "-c:a", "aac",
+            output_path
+        ]
+
+        logger.info(f"Running ffmpeg command: {' '.join(cmd)}")
+        subprocess.run(cmd, check=True)
+
-def create_ui(examples_path: str):
+def create_ui(examples_path: str, model_path: str):
     examples_data = load_examples(examples_path)

     with gr.Blocks() as app:
@@ -54,7 +199,6 @@ def create_ui(examples_path: str):
             gr.Markdown(f"### {example['title']}")

             with gr.Column():
-
                 gr.Video(
                     value=example["highlights"]["url"],
                     label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
@@ -102,6 +246,7 @@ def create_ui(examples_path: str):
                 gr.update(value=None, visible=False),  # Clear video
                 gr.update(visible=False)  # Hide accordion
             ]
+
             if not video:
                 yield [
                     "Please upload a video",
@@ -124,19 +269,16 @@ def create_ui(examples_path: str):
                 ]
                 return

-            # Make accordion visible as soon as processing starts
             yield [
-                "Loading model...",
+                "Initializing video highlight detector...",
                 "",
                 "",
                 gr.update(visible=False),
                 gr.update(visible=False)
             ]

-            model, processor = load_model()
-            detector = BatchedVideoHighlightDetector(
-                model,
-                processor,
+            detector = VideoHighlightDetector(
+                model_path=model_path,
                 batch_size=8
             )
@@ -161,18 +303,21 @@ def create_ui(examples_path: str):

                 highlights = detector.determine_highlights(video_desc)
                 formatted_highlights = f"### Highlights to search for:\n {highlights[:500] + '...' if len(highlights) > 500 else highlights}"
+
+                # Split video into segments
+                temp_dir = "temp_segments"
+                os.makedirs(temp_dir, exist_ok=True)

-                # Get all segments
-                segments = get_fixed_30s_segments(video)
-                total_segments = len(segments)
+                segment_length = 10.0
+                duration = get_video_duration_seconds(video)
                 kept_segments = []
-
-                # Process segments in batches with direct UI updates
-                for i in range(0, len(segments), detector.batch_size):
-                    batch_segments = segments[i:i + detector.batch_size]
+                segments_processed = 0
+                total_segments = int(duration / segment_length)
+
+                for start_time in range(0, int(duration), int(segment_length)):
+                    segments_processed += 1
+                    progress = int((segments_processed / total_segments) * 100)

-                    # Update progress
-                    progress = int((i / total_segments) * 100)
                     yield [
                         f"Processing segments... {progress}% complete",
                         formatted_desc,
@@ -180,35 +325,56 @@ def create_ui(examples_path: str):
                         gr.update(visible=False),
                         gr.update(visible=True)
                     ]

-                    # Process batch
-                    keep_flags = detector._process_segment_batch(
-                        video_path=video,
-                        segments=batch_segments,
-                        highlight_types=highlights,
-                        total_segments=total_segments,
-                        segments_processed=i
-                    )
-
-                    # Keep track of segments to include
-                    for segment, keep in zip(batch_segments, keep_flags):
-                        if keep:
-                            kept_segments.append(segment)
+                    # Create segment
+                    segment_path = f"{temp_dir}/segment_{start_time}.mp4"
+                    end_time = min(start_time + segment_length, duration)
+
+                    cmd = [
+                        "ffmpeg",
+                        "-y",
+                        "-i", video,
+                        "-ss", str(start_time),
+                        "-t", str(segment_length),
+                        "-c", "copy",
+                        segment_path
+                    ]
+                    subprocess.run(cmd, check=True)
+
+                    # Process segment
+                    if detector.process_segment(segment_path, highlights):
+                        kept_segments.append((start_time, end_time))
+
+                    # Clean up segment file
+                    os.remove(segment_path)
+
+                # Remove temp directory
+                os.rmdir(temp_dir)

                 # Create final video
-                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
-                    temp_output = tmp_file.name
-                detector._concatenate_scenes(video, kept_segments, temp_output)
-
-                yield [
-                    "Processing complete!",
-                    formatted_desc,
-                    formatted_highlights,
-                    gr.update(value=temp_output, visible=True),
-                    gr.update(visible=True)
-                ]
+                if kept_segments:
+                    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
+                        temp_output = tmp_file.name
+                    detector._concatenate_scenes(video, kept_segments, temp_output)
+
+                    yield [
+                        "Processing complete!",
+                        formatted_desc,
+                        formatted_highlights,
+                        gr.update(value=temp_output, visible=True),
+                        gr.update(visible=True)
+                    ]
+                else:
+                    yield [
+                        "No highlights detected in the video.",
+                        formatted_desc,
+                        formatted_highlights,
+                        gr.update(visible=False),
+                        gr.update(visible=True)
+                    ]

             except Exception as e:
+                logger.exception("Error processing video")
                 yield [
                     f"Error processing video: {str(e)}",
                     "",
@@ -217,10 +383,8 @@ def create_ui(examples_path: str):
                     gr.update(visible=False)
                 ]
             finally:
-                if model is not None:
-                    del model
-                torch.cuda.empty_cache()
-
+                # Clean up
+                torch.cuda.empty_cache()

         process_btn.click(
             on_process,
@@ -240,7 +404,8 @@ def create_ui(examples_path: str):
 if __name__ == "__main__":
     # Initialize CUDA
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    zero = torch.Tensor([0]).to(device)
-
-    app = create_ui("video_spec.json")
+
+    MODEL_PATH = os.getenv("MODEL_PATH", "HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+
+    app = create_ui("video_spec.json", MODEL_PATH)
     app.launch()
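
For reference, a rough standalone sketch of how on_process drives the new VideoHighlightDetector (10-second ffmpeg segments, a yes/no highlight check per segment, then concatenation). The class and helpers come from app.py above; the driver script itself and the input path are hypothetical:

    import os
    import subprocess
    import tempfile

    video = "input.mp4"  # hypothetical input path
    detector = VideoHighlightDetector(model_path="HuggingFaceTB/SmolVLM2-2.2B-Instruct", batch_size=8)

    description = detector.analyze_video_content(video)
    highlights = detector.determine_highlights(description)

    kept_segments = []
    segment_length = 10.0
    duration = get_video_duration_seconds(video)
    os.makedirs("temp_segments", exist_ok=True)
    for start in range(0, int(duration), int(segment_length)):
        segment_path = f"temp_segments/segment_{start}.mp4"
        # Cut a segment with stream copy, as in the diff above
        subprocess.run(["ffmpeg", "-y", "-i", video, "-ss", str(start),
                        "-t", str(segment_length), "-c", "copy", segment_path], check=True)
        if detector.process_segment(segment_path, highlights):
            kept_segments.append((start, min(start + segment_length, duration)))
        os.remove(segment_path)
    os.rmdir("temp_segments")

    if kept_segments:
        output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        detector._concatenate_scenes(video, kept_segments, output_path)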
requirements.txt CHANGED
@@ -2,6 +2,6 @@ Pillow
 opencv-python
 num2words
 ffmpeg-python
-transformers
+transformers @ git+https://github.com/huggingface/transformers.git@refs/pull/36126/head
 accelerate>=0.26.0
 decord==0.6.0