mfarre HF staff committed on
Commit
be5d51f
·
1 Parent(s): 8323202
Files changed (2) hide show
  1. app.py +94 -75
  2. video_highlight_detector.py +25 -5
app.py CHANGED
@@ -63,7 +63,6 @@ def create_ui(examples_path: str):
63
  gr.Markdown(f"#Summary: {example['analysis']['video_description']}")
64
  gr.Markdown(f"#Highlights to search for: {example['analysis']['highlight_types']}")
65
 
66
-
67
  gr.Markdown("## Try It Yourself!")
68
  with gr.Row():
69
  with gr.Column(scale=1):
@@ -92,15 +91,20 @@ def create_ui(examples_path: str):
92
  video_description = gr.Markdown("", elem_id="video_desc")
93
  highlight_types = gr.Markdown("", elem_id="highlight_types")
94
 
 
 
 
 
 
95
  @spaces.GPU
96
  def on_process(video):
97
  if not video:
98
  yield [
99
- "Please upload a video", # status
100
- "", # video_description
101
- "", # highlight_types
102
- gr.update(visible=False), # output_video
103
- gr.update(visible=False) # analysis_accordion
104
  ]
105
  return
106
 
@@ -126,7 +130,8 @@ def create_ui(examples_path: str):
126
  ]
127
 
128
  model, processor = load_model()
129
- detector = BatchedVideoHighlightDetector(model, processor, batch_size=8)
 
130
 
131
  yield [
132
  "Analyzing video content...",
@@ -139,7 +144,6 @@ def create_ui(examples_path: str):
139
  video_desc = detector.analyze_video_content(video)
140
  formatted_desc = f"#Summary: {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"
141
 
142
- # Update description as soon as it's available
143
  yield [
144
  "Determining highlight types...",
145
  formatted_desc,
@@ -151,14 +155,22 @@ def create_ui(examples_path: str):
151
  highlights = detector.determine_highlights(video_desc)
152
  formatted_highlights = f"#Highlights to search for: {highlights[:500] + '...' if len(highlights) > 500 else highlights}"
153
 
154
- # Update highlights as soon as they're available
155
- yield [
156
- "Detecting and extracting highlights...",
157
- formatted_desc,
158
- formatted_highlights,
159
- gr.update(visible=False),
160
- gr.update(visible=True)
161
- ]
 
 
 
 
 
 
 
 
162
 
163
  with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
164
  temp_output = tmp_file.name
@@ -195,7 +207,6 @@ def create_ui(examples_path: str):
195
  )
196
 
197
  return app
198
-
199
  # gr.Markdown("## Try It Yourself!")
200
  # with gr.Row():
201
  # with gr.Column(scale=1):
@@ -227,99 +238,107 @@ def create_ui(examples_path: str):
227
  # @spaces.GPU
228
  # def on_process(video):
229
  # if not video:
230
- # return {
231
- # status: "Please upload a video",
232
- # video_description: "",
233
- # highlight_types: "",
234
- # output_video: gr.update(visible=False),
235
- # analysis_accordion: gr.update(visible=False)
236
- # }
 
237
 
238
  # try:
239
  # duration = get_video_duration_seconds(video)
240
  # if duration > 1200: # 20 minutes
241
- # return {
242
- # status: "Video must be shorter than 20 minutes",
243
- # video_description: "",
244
- # highlight_types: "",
245
- # output_video: gr.update(visible=False),
246
- # analysis_accordion: gr.update(visible=False)
247
- # }
 
248
 
249
  # # Make accordion visible as soon as processing starts
250
- # yield {
251
- # status: "Loading model...",
252
- # video_description: "",
253
- # highlight_types: "",
254
- # output_video: gr.update(visible=False),
255
- # analysis_accordion: gr.update(visible=True)
256
- # }
257
 
258
  # model, processor = load_model()
259
  # detector = BatchedVideoHighlightDetector(model, processor, batch_size=8)
260
 
261
- # yield {
262
- # status: "Analyzing video content...",
263
- # video_description: "",
264
- # highlight_types: "",
265
- # output_video: gr.update(visible=False),
266
- # analysis_accordion: gr.update(visible=True)
267
- # }
268
 
269
  # video_desc = detector.analyze_video_content(video)
270
  # formatted_desc = f"#Summary: {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"
271
 
272
  # # Update description as soon as it's available
273
- # yield {
274
- # status: "Determining highlight types...",
275
- # video_description: formatted_desc,
276
- # highlight_types: "",
277
- # output_video: gr.update(visible=False),
278
- # analysis_accordion: gr.update(visible=True)
279
- # }
280
 
281
  # highlights = detector.determine_highlights(video_desc)
282
  # formatted_highlights = f"#Highlights to search for: {highlights[:500] + '...' if len(highlights) > 500 else highlights}"
283
 
284
  # # Update highlights as soon as they're available
285
- # yield {
286
- # status: "Detecting and extracting highlights...",
287
- # video_description: formatted_desc,
288
- # highlight_types: formatted_highlights,
289
- # output_video: gr.update(visible=False),
290
- # analysis_accordion: gr.update(visible=True)
291
- # }
292
 
293
  # with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
294
  # temp_output = tmp_file.name
295
  # detector.create_highlight_video(video, temp_output)
296
 
297
- # return {
298
- # status: "Processing complete!",
299
- # video_description: formatted_desc,
300
- # highlight_types: formatted_highlights,
301
- # output_video: gr.update(value=temp_output, visible=True),
302
- # analysis_accordion: gr.update(visible=True)
303
- # }
304
 
305
  # except Exception as e:
306
- # return {
307
- # status: f"Error processing video: {str(e)}",
308
- # video_description: "",
309
- # highlight_types: "",
310
- # output_video: gr.update(visible=False),
311
- # analysis_accordion: gr.update(visible=False)
312
- # }
313
 
314
  # process_btn.click(
315
  # on_process,
316
  # inputs=[input_video],
317
- # outputs=[status, video_description, highlight_types, output_video, analysis_accordion]
 
 
 
 
 
 
 
318
  # )
319
 
320
  # return app
321
 
322
-
323
  if __name__ == "__main__":
324
  # Initialize CUDA
325
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
63
  gr.Markdown(f"#Summary: {example['analysis']['video_description']}")
64
  gr.Markdown(f"#Highlights to search for: {example['analysis']['highlight_types']}")
65
 
 
66
  gr.Markdown("## Try It Yourself!")
67
  with gr.Row():
68
  with gr.Column(scale=1):
 
91
  video_description = gr.Markdown("", elem_id="video_desc")
92
  highlight_types = gr.Markdown("", elem_id="highlight_types")
93
 
94
def progress_callback(current, total):
    """Format a status message reporting segment-processing progress.

    Args:
        current: Number of segments processed so far.
        total: Total number of segments to process.

    Returns:
        A string of the form ``"Processing segments... N% complete"`` where
        ``N`` is a whole percentage clamped to the range 0-100.
    """
    # Nothing to measure against: report 0% instead of dividing by zero.
    if not total or total < 0:
        return "Processing segments... 0% complete"

    # Multiply before dividing so whole-number inputs stay exact; the float
    # form int((29 / 100) * 100) truncates 28.999999999999996 down to 28.
    percentage = int(current * 100 // total)

    # Keep the display sane if a caller overshoots or passes a negative count.
    percentage = max(0, min(100, percentage))
    return f"Processing segments... {percentage}% complete"
98
+
99
  @spaces.GPU
100
  def on_process(video):
101
  if not video:
102
  yield [
103
+ "Please upload a video",
104
+ "",
105
+ "",
106
+ gr.update(visible=False),
107
+ gr.update(visible=False)
108
  ]
109
  return
110
 
 
130
  ]
131
 
132
  model, processor = load_model()
133
+ detector = BatchedVideoHighlightDetector(model, processor, batch_size=8, progress_callback=lambda current, total: print(f"Progress: {current}/{total}")
134
+ )
135
 
136
  yield [
137
  "Analyzing video content...",
 
144
  video_desc = detector.analyze_video_content(video)
145
  formatted_desc = f"#Summary: {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"
146
 
 
147
  yield [
148
  "Determining highlight types...",
149
  formatted_desc,
 
155
  highlights = detector.determine_highlights(video_desc)
156
  formatted_highlights = f"#Highlights to search for: {highlights[:500] + '...' if len(highlights) > 500 else highlights}"
157
 
158
+ # Get total number of segments for progress tracking
159
+ segments = get_fixed_30s_segments(video)
160
+ total_segments = len(segments)
161
+
162
+ # Process segments in batches with progress updates
163
+ for i in range(0, total_segments, detector.batch_size):
164
+ current_batch = i + detector.batch_size
165
+ progress_msg = progress_callback(min(current_batch, total_segments), total_segments)
166
+
167
+ yield [
168
+ progress_msg,
169
+ formatted_desc,
170
+ formatted_highlights,
171
+ gr.update(visible=False),
172
+ gr.update(visible=True)
173
+ ]
174
 
175
  with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
176
  temp_output = tmp_file.name
 
207
  )
208
 
209
  return app
 
210
  # gr.Markdown("## Try It Yourself!")
211
  # with gr.Row():
212
  # with gr.Column(scale=1):
 
238
  # @spaces.GPU
239
  # def on_process(video):
240
  # if not video:
241
+ # yield [
242
+ # "Please upload a video", # status
243
+ # "", # video_description
244
+ # "", # highlight_types
245
+ # gr.update(visible=False), # output_video
246
+ # gr.update(visible=False) # analysis_accordion
247
+ # ]
248
+ # return
249
 
250
  # try:
251
  # duration = get_video_duration_seconds(video)
252
  # if duration > 1200: # 20 minutes
253
+ # yield [
254
+ # "Video must be shorter than 20 minutes",
255
+ # "",
256
+ # "",
257
+ # gr.update(visible=False),
258
+ # gr.update(visible=False)
259
+ # ]
260
+ # return
261
 
262
  # # Make accordion visible as soon as processing starts
263
+ # yield [
264
+ # "Loading model...",
265
+ # "",
266
+ # "",
267
+ # gr.update(visible=False),
268
+ # gr.update(visible=True)
269
+ # ]
270
 
271
  # model, processor = load_model()
272
  # detector = BatchedVideoHighlightDetector(model, processor, batch_size=8)
273
 
274
+ # yield [
275
+ # "Analyzing video content...",
276
+ # "",
277
+ # "",
278
+ # gr.update(visible=False),
279
+ # gr.update(visible=True)
280
+ # ]
281
 
282
  # video_desc = detector.analyze_video_content(video)
283
  # formatted_desc = f"#Summary: {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"
284
 
285
  # # Update description as soon as it's available
286
+ # yield [
287
+ # "Determining highlight types...",
288
+ # formatted_desc,
289
+ # "",
290
+ # gr.update(visible=False),
291
+ # gr.update(visible=True)
292
+ # ]
293
 
294
  # highlights = detector.determine_highlights(video_desc)
295
  # formatted_highlights = f"#Highlights to search for: {highlights[:500] + '...' if len(highlights) > 500 else highlights}"
296
 
297
  # # Update highlights as soon as they're available
298
+ # yield [
299
+ # "Detecting and extracting highlights...",
300
+ # formatted_desc,
301
+ # formatted_highlights,
302
+ # gr.update(visible=False),
303
+ # gr.update(visible=True)
304
+ # ]
305
 
306
  # with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
307
  # temp_output = tmp_file.name
308
  # detector.create_highlight_video(video, temp_output)
309
 
310
+ # yield [
311
+ # "Processing complete!",
312
+ # formatted_desc,
313
+ # formatted_highlights,
314
+ # gr.update(value=temp_output, visible=True),
315
+ # gr.update(visible=True)
316
+ # ]
317
 
318
  # except Exception as e:
319
+ # yield [
320
+ # f"Error processing video: {str(e)}",
321
+ # "",
322
+ # "",
323
+ # gr.update(visible=False),
324
+ # gr.update(visible=False)
325
+ # ]
326
 
327
  # process_btn.click(
328
  # on_process,
329
  # inputs=[input_video],
330
+ # outputs=[
331
+ # status,
332
+ # video_description,
333
+ # highlight_types,
334
+ # output_video,
335
+ # analysis_accordion
336
+ # ],
337
+ # queue=True,
338
  # )
339
 
340
  # return app
341
 
 
342
  if __name__ == "__main__":
343
  # Initialize CUDA
344
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
video_highlight_detector.py CHANGED
@@ -317,7 +317,8 @@ class BatchedVideoHighlightDetector:
317
  device="cuda",
318
  batch_size=8,
319
  max_frames_per_segment=32,
320
- target_fps=1.0
 
321
  ):
322
  self.model = model
323
  self.processor = processor
@@ -325,6 +326,7 @@ class BatchedVideoHighlightDetector:
325
  self.batch_size = batch_size
326
  self.max_frames_per_segment = max_frames_per_segment
327
  self.target_fps = target_fps
 
328
 
329
  def _extract_frames_batch(
330
  self,
@@ -466,10 +468,13 @@ class BatchedVideoHighlightDetector:
466
  self,
467
  video_path: str,
468
  segments: List[Tuple[float, float]],
469
- highlight_types: str
 
 
470
  ) -> List[bool]:
471
  """
472
  Process a batch of segments and return which ones contain highlights.
 
473
  """
474
  # Extract frames for all segments in batch
475
  frame_batches = self._extract_frames_batch(video_path, segments)
@@ -493,12 +498,17 @@ class BatchedVideoHighlightDetector:
493
  for output in outputs
494
  ]
495
 
 
 
 
 
496
  # Check for "yes" in responses
497
  return ["yes" in response for response in responses]
498
 
499
  def create_highlight_video(self, video_path: str, output_path: str) -> List[Tuple[float, float]]:
500
  """
501
  Main function that executes the batched highlight detection pipeline.
 
502
  """
503
  # Step 1: Analyze video content
504
  logger.info("Step 1: Analyzing video content...")
@@ -511,15 +521,25 @@ class BatchedVideoHighlightDetector:
511
  logger.info(f"Looking for highlights: {highlight_types}")
512
 
513
  # Step 3: Get all segments
514
- segments = self._get_fixed_30s_segments(video_path)
 
 
515
 
516
  # Step 4: Process segments in batches
517
  logger.info("Step 3: Detecting highlight segments in batches...")
518
  kept_segments = []
519
 
520
- for i in tqdm(range(0, len(segments), self.batch_size)):
521
  batch_segments = segments[i:i + self.batch_size]
522
- keep_flags = self._process_segment_batch(video_path, batch_segments, highlight_types)
 
 
 
 
 
 
 
 
523
 
524
  for segment, keep in zip(batch_segments, keep_flags):
525
  if keep:
 
317
  device="cuda",
318
  batch_size=8,
319
  max_frames_per_segment=32,
320
+ target_fps=1.0,
321
+ progress_callback=None
322
  ):
323
  self.model = model
324
  self.processor = processor
 
326
  self.batch_size = batch_size
327
  self.max_frames_per_segment = max_frames_per_segment
328
  self.target_fps = target_fps
329
+ self.progress_callback = progress_callback
330
 
331
  def _extract_frames_batch(
332
  self,
 
468
  self,
469
  video_path: str,
470
  segments: List[Tuple[float, float]],
471
+ highlight_types: str,
472
+ total_segments: int,
473
+ segments_processed: int
474
  ) -> List[bool]:
475
  """
476
  Process a batch of segments and return which ones contain highlights.
477
+ Now includes progress tracking.
478
  """
479
  # Extract frames for all segments in batch
480
  frame_batches = self._extract_frames_batch(video_path, segments)
 
498
  for output in outputs
499
  ]
500
 
501
+ # Update progress if callback is provided
502
+ if self.progress_callback:
503
+ self.progress_callback(segments_processed + len(segments), total_segments)
504
+
505
  # Check for "yes" in responses
506
  return ["yes" in response for response in responses]
507
 
508
  def create_highlight_video(self, video_path: str, output_path: str) -> List[Tuple[float, float]]:
509
  """
510
  Main function that executes the batched highlight detection pipeline.
511
+ Now includes progress tracking.
512
  """
513
  # Step 1: Analyze video content
514
  logger.info("Step 1: Analyzing video content...")
 
521
  logger.info(f"Looking for highlights: {highlight_types}")
522
 
523
  # Step 3: Get all segments
524
+ segments = get_fixed_30s_segments(video_path)
525
+ total_segments = len(segments)
526
+ segments_processed = 0
527
 
528
  # Step 4: Process segments in batches
529
  logger.info("Step 3: Detecting highlight segments in batches...")
530
  kept_segments = []
531
 
532
+ for i in range(0, len(segments), self.batch_size):
533
  batch_segments = segments[i:i + self.batch_size]
534
+ keep_flags = self._process_segment_batch(
535
+ video_path,
536
+ batch_segments,
537
+ highlight_types,
538
+ total_segments,
539
+ segments_processed
540
+ )
541
+
542
+ segments_processed += len(batch_segments)
543
 
544
  for segment, keep in zip(batch_segments, keep_flags):
545
  if keep: