mfarre HF staff committed on
Commit
880de81
·
1 Parent(s): 1d8a145

initial test

Files changed (5)
  1. app.py +192 -0
  2. modeling_smolvlm.py +297 -0
  3. requirements.txt +4 -0
  4. video_highlight_detector.py +785 -0
  5. video_spec.json +62 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ import os
+ import json
+ import gradio as gr
+ import tempfile
+ from PIL import Image, ImageDraw, ImageFont
+ import cv2
+ from typing import Tuple, Optional
+ import torch
+ import spaces
+ from pathlib import Path
+ import time
+
+ # Import your highlight detection code
+ from video_highlight_detector import (
+     load_model,
+     BatchedVideoHighlightDetector,
+     get_video_duration_seconds
+ )
+
+ def load_examples(json_path: str) -> dict:
+     """Load pre-computed examples from JSON file"""
+     with open(json_path, 'r') as f:
+         return json.load(f)
+
+ def format_duration(seconds: int) -> str:
+     """Convert seconds to MM:SS or HH:MM:SS format"""
+     hours = seconds // 3600
+     minutes = (seconds % 3600) // 60
+     secs = seconds % 60
+     if hours > 0:
+         return f"{hours}:{minutes:02d}:{secs:02d}"
+     return f"{minutes}:{secs:02d}"
+
+ def add_watermark(video_path: str, output_path: str):
+     """Add watermark to video using ffmpeg"""
+     watermark_text = "🤗 SmolVLM2 Highlight"
+     command = f"""ffmpeg -i {video_path} -vf \
+ "drawtext=text='{watermark_text}':fontcolor=white:fontsize=24:box=1:boxcolor=black@0.5:\
+ boxborderw=5:x=w-tw-10:y=h-th-10" \
+ -codec:a copy {output_path}"""
+     os.system(command)
+
+ def process_video(
+     video_path: str,
+     progress = gr.Progress()
+ ) -> Tuple[str, str, str, str]:
+     """
+     Process video and return paths to:
+     - Processed video with watermark
+     - Video description
+     - Highlight types
+     - Error message (if any)
+     """
+     try:
+         # Check video duration
+         duration = get_video_duration_seconds(video_path)
+         if duration > 1200:  # 20 minutes
+             return None, None, None, "Video must be shorter than 20 minutes"
+
+         # Load model (could be cached)
+         progress(0.1, desc="Loading model...")
+         model, processor = load_model()
+         detector = BatchedVideoHighlightDetector(model, processor)
+
+         # Analyze video content
+         progress(0.2, desc="Analyzing video content...")
+         video_description = detector.analyze_video_content(video_path)
+
+         # Determine highlights
+         progress(0.3, desc="Determining highlight types...")
+         highlight_types = detector.determine_highlights(video_description)
+
+         # Create highlight video
+         progress(0.4, desc="Detecting and extracting highlights...")
+         with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
+             temp_output = tmp_file.name
+
+         detector.create_highlight_video(video_path, temp_output)
+
+         # Add watermark
+         progress(0.9, desc="Adding watermark...")
+         output_path = temp_output.replace('.mp4', '_watermark.mp4')
+         add_watermark(temp_output, output_path)
+
+         # Cleanup
+         os.unlink(temp_output)
+
+         # Truncate description and highlights if too long
+         video_description = video_description[:500] + "..." if len(video_description) > 500 else video_description
+         highlight_types = highlight_types[:500] + "..." if len(highlight_types) > 500 else highlight_types
+
+         return output_path, video_description, highlight_types, None
+
+     except Exception as e:
+         return None, None, None, f"Error processing video: {str(e)}"
+
+
+ def create_ui(examples_path: str):
+     """Create the Gradio interface with optional thumbnails"""
+     examples_data = load_examples(examples_path)
+
+     with gr.Blocks() as app:
+         gr.Markdown("# Video Highlight Generator")
+         gr.Markdown("Upload a video (max 20 minutes) and get an automated highlight reel!")
+
+         # Pre-computed examples section
+         with gr.Row():
+             gr.Markdown("## Example Results")
+
+         for example in examples_data["examples"]:
+             with gr.Row():
+                 with gr.Column():
+                     # Use thumbnail if available, otherwise default to video
+                     video_component = gr.Video(
+                         example["original"]["url"],
+                         label=f"Original ({format_duration(example['original']['duration_seconds'])})",
+                         thumbnail=example["original"].get("thumbnail_url", None)
+                     )
+                     gr.Markdown(example["title"])
+
+                 with gr.Column():
+                     gr.Video(
+                         example["highlights"]["url"],
+                         label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
+                         thumbnail=example["highlights"].get("thumbnail_url", None)
+                     )
+                     with gr.Accordion("Analysis", open=False):
+                         gr.Markdown(example["analysis"]["video_description"])
+                         gr.Markdown(example["analysis"]["highlight_types"])
+
+         # Upload section
+         gr.Markdown("## Try It Yourself!")
+         with gr.Row():
+             input_video = gr.Video(
+                 label="Upload your video (max 20 minutes)",
+                 source="upload"
+             )
+
+         # Results section (initially hidden)
+         with gr.Row(visible=False) as results_row:
+             with gr.Column():
+                 video_description = gr.Markdown(label="Video Analysis")
+             with gr.Column():
+                 highlight_types = gr.Markdown(label="Detected Highlights")
+
+         with gr.Row(visible=False) as output_row:
+             output_video = gr.Video(label="Highlight Video")
+             download_btn = gr.Button("Download Highlights")
+
+         # Error message
+         error_msg = gr.Markdown(visible=False)
+
+         # Process video when uploaded
+         def on_upload(video):
+             results_row.visible = False
+             output_row.visible = False
+             error_msg.visible = False
+
+             if not video:
+                 error_msg.visible = True
+                 error_msg.value = "Please upload a video"
+                 return None, None, None, error_msg
+
+             output_path, desc, highlights, err = process_video(video)
+
+             if err:
+                 error_msg.visible = True
+                 error_msg.value = err
+                 return None, None, None, error_msg
+
+             results_row.visible = True
+             output_row.visible = True
+             return output_path, desc, highlights, ""
+
+         input_video.change(
+             on_upload,
+             inputs=[input_video],
+             outputs=[output_video, video_description, highlight_types, error_msg]
+         )
+
+         # Download button
+         download_btn.click(
+             lambda x: x,
+             inputs=[output_video],
+             outputs=[output_video]
+         )
+
+     return app
+
+ if __name__ == "__main__":
+     app = create_ui("video_spec.json")
+     app.launch()
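Note that `on_upload` above toggles `.visible` on components directly, which Gradio Blocks generally ignores outside of event outputs. A minimal sketch of an alternative wiring (not the app's actual code; behavior depends on the Gradio version, and `gr.update` is the stock helper for this) would route visibility through the event instead:

    def on_upload(video):
        # Return gr.update(...) objects so visibility changes travel through
        # the event outputs; names mirror the components defined in create_ui.
        if not video:
            return (None, None, None,
                    gr.update(visible=True, value="Please upload a video"),
                    gr.update(visible=False), gr.update(visible=False))
        output_path, desc, highlights, err = process_video(video)
        if err:
            return (None, None, None,
                    gr.update(visible=True, value=err),
                    gr.update(visible=False), gr.update(visible=False))
        return (output_path, desc, highlights,
                gr.update(visible=False),
                gr.update(visible=True), gr.update(visible=True))

    input_video.change(
        on_upload,
        inputs=[input_video],
        outputs=[output_video, video_description, highlight_types,
                 error_msg, results_row, output_row],
    )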
modeling_smolvlm.py ADDED
@@ -0,0 +1,297 @@
+ import torch
+ from torch import nn
+ from transformers import Idefics3Model, Idefics3ForConditionalGeneration
+ from typing import Dict, Any, List, Optional, Union, Tuple
+ from transformers.cache_utils import Cache, DynamicCache
+
+ from transformers.utils import add_start_docstrings_to_model_forward, logging
+ from transformers.models.idefics3.modeling_idefics3 import IDEFICS3_INPUTS_DOCSTRING, Idefics3BaseModelOutputWithPast
+
+ logger = logging.get_logger(__name__)
+
+ class SmolVLMModel(Idefics3Model):
+     """
+     A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
+     in forward. Instead, we override inputs_merger here with custom logic.
+     """
+     def inputs_merger(
+         self,
+         input_ids: torch.LongTensor,
+         inputs_embeds: torch.Tensor,
+         image_hidden_states: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Merge text embeddings with image embeddings out-of-place (no in-place indexing).
+
+         The shapes are something like:
+           - input_ids:           (B, T)
+           - inputs_embeds:       (B, T, D)
+           - image_hidden_states: (N, S, D) where N is total images across the batch,
+             S is #patches (or #slots) per image, D is embedding dim.
+
+         Logic:
+           1) For each sample in the batch, find <image> tokens in the text.
+           2) If zero <image> tokens => text-only. Concatenate a zero-length slice
+              from image_hidden_states but do NOT advance the offset. This ensures
+              the model's image encoder is still in the computation graph, but we
+              skip "consuming" any image block for a text-only sample.
+           3) If there are <image> tokens, they appear in multiples of S for each image
+              (because each image is S embeddings). We chunk those positions into groups
+              of S. For each chunk => we consume one block from image_hidden_states[offset]
+              (which is shape (S, D)), and place each row into the text in place of a token.
+
+         Returns:
+             A tensor of (B, T, D).
+         """
+
+         ##############################################
+         # 1) Basic shape checks
+         ##############################################
+         #old_merger_outputs = self.inputs_merger_old(input_ids, inputs_embeds, image_hidden_states)
+         B, T, D_text = inputs_embeds.shape
+         N, S, D_img = image_hidden_states.shape
+         if D_text != D_img:
+             raise ValueError(
+                 f"Text embedding dim {D_text} != image embedding dim {D_img}"
+             )
+
+         ##############################################
+         # 2) We'll track how many images we've used so far across the entire batch
+         ##############################################
+         image_offset = 0
+
+         # We'll store one merged tensor per batch sample
+         merged_outputs: List[torch.Tensor] = []
+
+         ##############################################
+         # 3) Iterate through each sample
+         ##############################################
+         for b_idx, (cur_ids, cur_embeds) in enumerate(zip(input_ids, inputs_embeds)):
+             # Find positions of <image> tokens in the text
+             image_positions = (cur_ids == self.image_token_id).nonzero(as_tuple=True)[0]
+             num_image_tokens = len(image_positions)
+
+             # If no <image> => text-only
+             if num_image_tokens == 0:
+                 # We do not consume any row from image_hidden_states;
+                 # but we do a zero-length slice so the image encoder is in the graph.
+                 empty_slice = image_hidden_states[0][:0, :]  # shape (0, D)
+                 # Concatenate text plus that empty slice.
+                 # NOTE: this is important for DeepSpeed.
+                 merged_text_only = torch.cat([cur_embeds, empty_slice], dim=0)
+                 merged_outputs.append(merged_text_only)
+                 continue
+
+             # Otherwise, we have at least one <image> token.
+             # Typically, if each image is S embeddings, we expect the total # of <image> tokens
+             # in this sample to be multiple of S => each group of S tokens = 1 image
+             if num_image_tokens % S != 0:
+                 raise ValueError(
+                     f"Sample {b_idx} has {num_image_tokens} <image> tokens, not a multiple of S={S}. "
+                     "Cannot map them to blocks of shape (S, D)."
+                 )
+
+             # We'll chunk image_positions into groups of size S
+             positions_list = image_positions.tolist()
+             # Example: if num_image_tokens=162 and S=81 => we have 2 images => 2 chunks each of length 81
+             chunks = [
+                 positions_list[i : i + S]
+                 for i in range(0, num_image_tokens, S)
+             ]
+
+             # We'll build a list of segments: text, then image row(s), text, etc.
+             segments = []
+             text_start = 0
+
+             # For each chunk (each chunk => 1 image)
+             for chunk in chunks:
+                 # image_hidden_states[image_offset] => shape (S, D)
+                 cur_block = image_hidden_states[image_offset]
+                 image_offset += 1
+
+                 # We'll iterate over the S positions in ascending order
+                 for i_s, pos in enumerate(chunk):
+                     # Add text from [text_start..pos)
+                     if pos > text_start:
+                         segments.append(cur_embeds[text_start:pos])
+                     # Then add one row from cur_block => shape (1, D)
+                     row_of_block = cur_block[i_s : i_s + 1, :]
+                     segments.append(row_of_block)
+                     # skip the <image> token
+                     text_start = pos + 1
+
+             # leftover text after the final <image> token
+             if text_start < T:
+                 segments.append(cur_embeds[text_start:])
+
+             # cat them into a single (T_b, D) tensor
+             merged_sample = torch.cat(segments, dim=0)
+             merged_outputs.append(merged_sample)
+
+         merged_outputs = torch.stack(merged_outputs)
+         #assert (old_merger_outputs==merged_outputs).all()
+         return merged_outputs
+
+
+     @add_start_docstrings_to_model_forward(
+         """
+         Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+         the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+         max_num_images is the maximum number of images among the batch_size samples in the batch.
+         Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+         For efficiency, we only pass through the vision_model's forward the real images by
+         discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+         image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+         """,
+         IDEFICS3_INPUTS_DOCSTRING,
+     )
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         pixel_attention_mask: Optional[torch.BoolTensor] = None,
+         image_hidden_states: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if self.training and self.text_model.gradient_checkpointing and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+             )
+             use_cache = False
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         past_seen_tokens = 0
+         if use_cache:
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             past_seen_tokens = past_key_values.get_seq_length()
+
+         if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
+             raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
+
+         # START VISUAL INPUTS INTEGRATION
+         if pixel_values is not None and image_hidden_states is not None:
+             raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+         elif pixel_values is not None:
+             batch_size, num_images, num_channels, height, width = pixel_values.shape
+             pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+             pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+             # Remove padding images - padding images are full 0.
+             nb_values_per_image = pixel_values.shape[1:].numel()
+             real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+
+             if not any(real_images_inds):
+                 # no images, leave one empty image.
+                 real_images_inds[0] = True
+
+             pixel_values = pixel_values[real_images_inds].contiguous()
+
+             # Handle the vision attention mask
+             if pixel_attention_mask is None:
+                 pixel_attention_mask = torch.ones(
+                     size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+                     dtype=torch.bool,
+                     device=pixel_values.device,
+                 )
+             else:
+                 # Remove padding images from the mask
+                 pixel_attention_mask = pixel_attention_mask.view(
+                     batch_size * num_images, *pixel_attention_mask.shape[2:]
+                 )
+                 pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+             patch_size = self.config.vision_config.patch_size
+             patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+             patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+             patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+             # Get sequence from the vision encoder
+             image_hidden_states = self.vision_model(
+                 pixel_values=pixel_values,
+                 patch_attention_mask=patch_attention_mask,
+             ).last_hidden_state
+
+             # Modality projection & resampling
+             image_hidden_states = self.connector(image_hidden_states)
+
+         elif image_hidden_states is not None:
+             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+         if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
+             # When we generate, we don't want to replace the potential image_token_id that we generated by images
+             # that simply don't exist
+             inputs_embeds = self.inputs_merger(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 image_hidden_states=image_hidden_states,
+             )
+
+         outputs = self.text_model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if not return_dict:
+             return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
+
+         return Idefics3BaseModelOutputWithPast(
+             last_hidden_state=outputs.last_hidden_state,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             image_hidden_states=image_hidden_states,
+         )
+
+
+
+
+ class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
+     """
+     A subclass of Idefics3ForConditionalGeneration that uses SmolVLMModel
+     instead of the default Idefics3Model.
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Instead of the original self.model = Idefics3Model(config),
+         # we point to our custom class.
+         self.model = SmolVLMModel(config)
+
+         # We *keep* the same lm_head from the parent, or re-init if you prefer:
+         self.lm_head = nn.Linear(
+             config.text_config.hidden_size, config.text_config.vocab_size, bias=False
+         )
+
+         # If parent sets up any post_init() logic:
+         self.post_init()
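As a quick sanity check of the chunking idea in inputs_merger, a standalone sketch with toy tensors behaves as described (shapes and the image-token id are made up for illustration; only the single-image case is shown):

    import torch

    # Toy setup: B=1 sample, T=6 tokens, D=4 dims, one image with S=2 slots.
    IMAGE_TOKEN_ID = 99                                   # hypothetical id
    input_ids = torch.tensor([[5, 99, 99, 7, 8, 9]])      # (B, T)
    inputs_embeds = torch.arange(24.).reshape(1, 6, 4)    # (B, T, D)
    image_hidden_states = torch.full((1, 2, 4), -1.0)     # (N, S, D)

    # Each group of S <image> tokens is replaced, row by row, with one (S, D) block.
    positions = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]  # tensor([1, 2])
    segments, text_start = [], 0
    for i_s, pos in enumerate(positions.tolist()):
        if pos > text_start:
            segments.append(inputs_embeds[0, text_start:pos])   # text before the image
        segments.append(image_hidden_states[0, i_s:i_s + 1])    # one image row
        text_start = pos + 1                                    # skip the <image> token
    segments.append(inputs_embeds[0, text_start:])              # leftover text
    merged = torch.cat(segments, dim=0)   # (6, 4): rows 1-2 are now image rows
    print(merged[1:3])                    # all -1.0, i.e. the image block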
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ Pillow
+ opencv-python
+ num2words
+ ffmpeg-python
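Note that the scripts in this commit also import packages not listed here (torch, transformers, gradio, spaces, tqdm, and optionally decord or torchvision), and they shell out to the ffmpeg/ffprobe binaries. A fuller, unpinned requirements sketch (assumed, not part of the commit) might look like:

    Pillow
    opencv-python
    num2words
    ffmpeg-python
    torch
    torchvision
    transformers
    gradio
    spaces
    tqdm
    decord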
video_highlight_detector.py ADDED
@@ -0,0 +1,785 @@
+ import torch
+ from typing import List, Tuple, Dict, Optional
+ from tqdm import tqdm
+ import logging
+ from PIL import Image
+ import datetime
+ from num2words import num2words
+ import subprocess
+ import sys
+ from modeling_smolvlm import SmolVLMForConditionalGeneration
+ from transformers import AutoProcessor, AutoTokenizer
+ import json
+ import math
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(
+     level=logging.INFO
+ )
+
+
+
+ SYSTEM_MESSAGE = (
+     "Carefully watch the video and pay attention to the cause and sequence of events, "
+     "the detail and movement of objects, and the action and pose of persons. "
+     "Based on your observations, answer the question with yes or no."
+     " <end_of_utterance>"
+ )
+
+ FRAME_TIMESTAMP_MESSAGE = "Frame from"
+ DEFAULT_VIDEO_INTRO = (
+     "You are provided the following series of {frame_count} frames "
+     "from a {video_duration} [H:MM:SS] video.\n"
+ )
+
+ # ----------------------------------------------------------------------
+ # Helper functions for resizing, etc.
+ # ----------------------------------------------------------------------
+
+ def round_by_factor(number: float, factor: int) -> int:
+     return round(number / factor) * factor
+
+ def ceil_by_factor(number: float, factor: int) -> int:
+     return math.ceil(number / factor) * factor
+
+ def floor_by_factor(number: float, factor: int) -> int:
+     return math.floor(number / factor) * factor
+
+ def smart_resize(
+     height: int,
+     width: int,
+     factor: int,
+     min_pixels: int,
+     max_pixels: int,
+     max_ratio: float,
+ ) -> Tuple[int, int]:
+     """
+     Rescale (height, width) so that:
+       - aspect ratio <= max_ratio
+       - total area in [min_pixels, max_pixels]
+       - each dimension is multiple of factor
+     """
+     ratio = max(height, width) / min(height, width)
+     if ratio > max_ratio:
+         raise ValueError(f"Aspect ratio {ratio:.2f} > {max_ratio}")
+
+     h_ = max(factor, round_by_factor(height, factor))
+     w_ = max(factor, round_by_factor(width, factor))
+     area = h_ * w_
+
+     if area > max_pixels:
+         scale = math.sqrt((height * width) / max_pixels)
+         h_ = floor_by_factor(height / scale, factor)
+         w_ = floor_by_factor(width / scale, factor)
+     elif area < min_pixels:
+         scale = math.sqrt(min_pixels / (height * width))
+         h_ = ceil_by_factor(height * scale, factor)
+         w_ = ceil_by_factor(width * scale, factor)
+     return h_, w_
+
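As a concrete illustration of smart_resize with the defaults that SmartVideoFrameExtractor passes below, a 1920x1080 frame is scaled down until its area hits max_pixels exactly:

    # ratio = 1920/1080 ~ 1.78 <= 2.0; rounded area 1080*1920 > 384*384*4 = 589824,
    # so scale = sqrt(1080*1920 / 589824) = 1.875 and floor_by_factor gives (576, 1024).
    print(smart_resize(1080, 1920, factor=8, min_pixels=384 * 384,
                       max_pixels=384 * 384 * 4, max_ratio=2.0))   # (576, 1024)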
+ def _smart_nframes(
+     total_frames: int,
+     video_fps: float,
+     frame_factor: int = 1,
+     target_fps: float = 2.0,
+     min_frames: int = 4,
+     max_frames: int = 32
+ ) -> int:
+     """
+     Decide how many frames to pick from a range based on target FPS.
+     Result is clamped to [min_frames, max_frames] and must be multiple of frame_factor.
+     """
+     minf = ceil_by_factor(min_frames, frame_factor)
+     maxf = floor_by_factor(min(max_frames, total_frames), frame_factor)
+     val = total_frames / video_fps * target_fps
+     val = min(max(val, minf), maxf)
+     nframes = round_by_factor(val, frame_factor)
+
+     if not (frame_factor <= nframes <= total_frames):
+         raise ValueError(f"Invalid nframes={nframes}, out of range.")
+     return int(nframes)
+
+
+ def get_video_duration_seconds(video_path: str) -> float:
+     """
+     Use ffprobe to retrieve the total duration of a video (in seconds).
+     """
+     cmd = [
+         "ffprobe",
+         "-v", "quiet",
+         "-print_format", "json",
+         "-show_format",
+         video_path
+     ]
+     result = subprocess.run(cmd, capture_output=True, text=True)
+     info = json.loads(result.stdout)
+     return float(info["format"]["duration"])
+
+ def get_fixed_30s_segments(video_path: str) -> list:
+     """
+     Produce a list of (start_sec, end_sec) tuples in fixed-size blocks
+     (block_size, 10 seconds here despite the function name) covering the entire video.
+     """
+     duration = get_video_duration_seconds(video_path)
+     segments = []
+     start = 0.0
+     block_size = 10.0
+
+     while start < duration:
+         end = min(start + block_size, duration)
+         segments.append((start, end))
+         start = end
+
+     return segments
+
+
+
+ class SmartVideoFrameExtractor:
+     """
+     This class extracts frames from a specific portion of a video
+     (defined by start_frame and end_frame or start_sec and end_sec).
+     """
+     def __init__(
+         self,
+         frame_factor: int = 1,
+         min_pixels: int = 384 * 384,
+         max_pixels: int = 384 * 384 * 4,
+         max_ratio: float = 2.0
+     ):
+         self.frame_factor = frame_factor
+         self.min_pixels = min_pixels
+         self.max_pixels = max_pixels
+         self.max_ratio = max_ratio
+
+         try:
+             import decord
+             self.reader = "decord"
+             decord.bridge.set_bridge("torch")
+         except ImportError:
+             self.reader = "torchvision"
+             logger.info("Decord not found, falling back to torchvision")
+
+     def extract_frames(
+         self,
+         video_path: str,
+         start_sec: float,
+         end_sec: float,
+         target_fps: float = 1.0,
+         min_frames: int = 4,
+         max_frames: int = 32
+     ) -> Tuple[List[Image.Image], List[str], float]:
+         """Extract frames from [start_sec, end_sec] using decord or torchvision."""
+         if self.reader == "decord":
+             return self._extract_frames_decord(
+                 video_path, start_sec, end_sec, target_fps, min_frames, max_frames
+             )
+         else:
+             return self._extract_frames_torchvision(
+                 video_path, start_sec, end_sec, target_fps, min_frames, max_frames
+             )
+
+     def _extract_frames_decord(
+         self,
+         video_path: str,
+         start_sec: float,
+         end_sec: float,
+         target_fps: float,
+         min_frames: int,
+         max_frames: int
+     ) -> Tuple[List[Image.Image], List[str], float]:
+         """Extract frames with decord from a certain segment."""
+         import decord
+         from decord import VideoReader
+
+         vr = VideoReader(video_path)
+         total_frames = len(vr)
+         video_fps = vr.get_avg_fps()
+
+         # Convert start/end times to frame indices
+         start_frame = int(start_sec * video_fps)
+         end_frame = min(int(end_sec * video_fps), total_frames - 1)
+         if start_frame >= end_frame:
+             return [], [], 0.0
+
+         working_frames = end_frame - start_frame + 1
+         nframes = _smart_nframes(
+             working_frames,
+             video_fps,
+             self.frame_factor,
+             target_fps,
+             min_frames,
+             max_frames
+         )
+         indices = torch.linspace(start_frame, end_frame, nframes).round().long()
+
+         frames_tensor = vr.get_batch(indices).cpu()  # NHWC
+         frames = []
+         timestamps = []
+
+         for i, frame_idx in enumerate(indices):
+             frame = frames_tensor[i].numpy()
+             pil_image = Image.fromarray(frame).convert("RGB")
+
+             # Compute timestamp
+             sec = frame_idx.item() / video_fps
+             mm = int(sec // 60)
+             ss = int(sec % 60)
+             timestamps.append(f"{mm:02d}:{ss:02d}")
+
+             # Resize
+             w, h = pil_image.size
+             rh, rw = smart_resize(
+                 h, w,
+                 factor=8,
+                 min_pixels=self.min_pixels,
+                 max_pixels=self.max_pixels,
+                 max_ratio=self.max_ratio
+             )
+             pil_image = pil_image.resize((rw, rh), Image.Resampling.LANCZOS)
+             frames.append(pil_image)
+
+         return frames, timestamps, end_sec - start_sec
+
+     def _extract_frames_torchvision(
+         self,
+         video_path: str,
+         start_sec: float,
+         end_sec: float,
+         target_fps: float,
+         min_frames: int,
+         max_frames: int
+     ) -> Tuple[List[Image.Image], List[str], float]:
+         """Extract frames with torchvision from a certain segment."""
+         from torchvision import io
+
+         # Read entire video (beware of memory usage on large videos!)
+         vid, _, info = io.read_video(
+             video_path,
+             start_pts=0,
+             end_pts=None,
+             pts_unit="sec",
+             output_format="TCHW"
+         )
+
+         total_frames = vid.size(0)
+         video_fps = info["video_fps"]
+
+         # Convert start/end times to frame indices
+         start_frame = int(start_sec * video_fps)
+         end_frame = min(int(end_sec * video_fps), total_frames - 1)
+         if start_frame >= end_frame:
+             return [], [], 0.0
+
+         working_frames = end_frame - start_frame + 1
+         nframes = _smart_nframes(
+             working_frames,
+             video_fps,
+             self.frame_factor,
+             target_fps,
+             min_frames,
+             max_frames
+         )
+         indices = torch.linspace(start_frame, end_frame, nframes).round().long()
+
+         frames = []
+         timestamps = []
+         for idx in indices:
+             frame = vid[idx].permute(1, 2, 0).numpy()
+             pil_image = Image.fromarray(frame).convert("RGB")
+
+             sec = idx.item() / video_fps
+             mm = int(sec // 60)
+             ss = int(sec % 60)
+             timestamps.append(f"{mm:02d}:{ss:02d}")
+
+             w, h = pil_image.size
+             rh, rw = smart_resize(
+                 h, w,
+                 factor=8,
+                 min_pixels=self.min_pixels,
+                 max_pixels=self.max_pixels,
+                 max_ratio=self.max_ratio
+             )
+             pil_image = pil_image.resize((rw, rh), Image.Resampling.LANCZOS)
+             frames.append(pil_image)
+
+         return frames, timestamps, end_sec - start_sec
+
+
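For reference, analyze_video_content further down drives this extractor roughly as in the following standalone sketch (the video path is a placeholder; the three-element return matches the extractor's actual behavior):

    extractor = SmartVideoFrameExtractor()
    duration = get_video_duration_seconds("input.mp4")   # placeholder path
    frames, timestamps, clip_len = extractor.extract_frames(
        "input.mp4", start_sec=0, end_sec=duration,
        target_fps=0.2, max_frames=32,
    )
    print(len(frames), timestamps[:3], clip_len)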
+ class BatchedVideoHighlightDetector:
+     """
+     Optimized version of video highlight detection that processes multiple segments
+     in parallel using batched inference.
+     """
+     def __init__(
+         self,
+         model,
+         processor,
+         device="cuda",
+         batch_size=8,
+         max_frames_per_segment=32,
+         target_fps=1.0
+     ):
+         self.model = model
+         self.processor = processor
+         self.device = device
+         self.batch_size = batch_size
+         self.max_frames_per_segment = max_frames_per_segment
+         self.target_fps = target_fps
+
+     def _extract_frames_batch(
+         self,
+         video_path: str,
+         segments: List[Tuple[float, float]]
+     ) -> List[Tuple[List[Image.Image], List[str], float]]:
+         """
+         Extract frames from multiple segments in parallel using decord's batch capabilities.
+         """
+         import decord
+         from decord import VideoReader
+         decord.bridge.set_bridge("torch")
+
+         # Open video once for all segments
+         vr = VideoReader(video_path)
+         video_fps = vr.get_avg_fps()
+         results = []
+
+         for start_sec, end_sec in segments:
+             # Convert time to frame indices
+             start_frame = int(start_sec * video_fps)
+             end_frame = min(int(end_sec * video_fps), len(vr) - 1)
+
+             # Calculate number of frames to sample
+             segment_duration = end_sec - start_sec
+             desired_frames = min(
+                 int(segment_duration * self.target_fps),
+                 self.max_frames_per_segment
+             )
+
+             # Generate frame indices
+             indices = torch.linspace(start_frame, end_frame, desired_frames).round().long()
+
+             # Extract frames
+             frames_tensor = vr.get_batch(indices).cpu()  # NHWC format
+
+             # Convert to PIL and generate timestamps
+             frames = []
+             timestamps = []
+             for i, frame_idx in enumerate(indices):
+                 frame = frames_tensor[i].numpy()
+                 pil_image = Image.fromarray(frame).convert("RGB")
+
+                 # Resize maintaining aspect ratio
+                 w, h = pil_image.size
+                 scale = min(384 / w, 384 / h)
+                 new_w = int(w * scale)
+                 new_h = int(h * scale)
+                 pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+
+                 frames.append(pil_image)
+
+                 # Generate timestamp
+                 sec = frame_idx.item() / video_fps
+                 mm = int(sec // 60)
+                 ss = int(sec % 60)
+                 timestamps.append(f"{mm:02d}:{ss:02d}")
+
+             results.append((frames, timestamps, segment_duration))
+
+         return results
+
+     def _prepare_batch_inputs(
+         self,
+         frame_batches: List[Tuple[List[Image.Image], List[str], float]],
+         highlight_types: str
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Convert a batch of frame sequences into model inputs.
+         """
+         conversations = []
+         all_frames = []
+
+         for frames, timestamps, duration in frame_batches:
+             # Build conversation for each segment
+             conversation = [
+                 {
+                     "role": "system",
+                     "content": [{
+                         "type": "text",
+                         "text": "You are a helpful assistant that analyzes videos for specific moments of interest."
+                     }]
+                 },
+                 {
+                     "role": "user",
+                     "content": []
+                 }
+             ]
+
+             # Add video intro
+             conversation[1]["content"].append({
+                 "type": "text",
+                 "text": f"You are provided the following series of {num2words(len(frames))} frames from a {str(datetime.timedelta(seconds=duration))} [H:MM:SS] video.\n"
+             })
+
+             # Add frames with timestamps
+             for ts, frame in zip(timestamps, frames):
+                 conversation[1]["content"].extend([
+                     {
+                         "type": "text",
+                         "text": f"Frame from {ts}:"
+                     },
+                     {
+                         "type": "image"
+                     }
+                 ])
+
+             # Add highlight check question
+             conversation[1]["content"].append({
+                 "type": "text",
+                 "text": f"""Do you see any of the following types of highlight moments in these frames?
+
+ Potential highlights to look for:
+ {highlight_types}
+
+ Only answer yes if you see any of those moments and answer no if you don't."""
+             })
+
+             conversations.append(conversation)
+             all_frames.extend(frames)
+
+         # Convert to model inputs using processor
+         prompts = [
+             self.processor.apply_chat_template(conv, add_generation_prompt=True)
+             for conv in conversations
+         ]
+
+         # Create batched inputs
+         model_inputs = self.processor(
+             text=prompts,
+             images=all_frames,
+             return_tensors="pt",
+             padding=True
+         ).to(self.device)
+
+         return model_inputs
+
+     def _process_segment_batch(
+         self,
+         video_path: str,
+         segments: List[Tuple[float, float]],
+         highlight_types: str
+     ) -> List[bool]:
+         """
+         Process a batch of segments and return which ones contain highlights.
+         """
+         # Extract frames for all segments in batch
+         frame_batches = self._extract_frames_batch(video_path, segments)
+
+         # Prepare model inputs
+         model_inputs = self._prepare_batch_inputs(frame_batches, highlight_types)
+
+         # Generate responses for entire batch
+         outputs = self.model.generate(
+             **model_inputs,
+             max_new_tokens=256,
+             num_beams=5,
+             temperature=0.7,
+             do_sample=True,
+             use_cache=True
+         )
+
+         # Process responses
+         responses = [
+             self.processor.decode(output, skip_special_tokens=True).lower().split("assistant:")[1]
+             for output in outputs
+         ]
+
+         # Check for "yes" in responses
+         return ["yes" in response for response in responses]
+
+     def create_highlight_video(self, video_path: str, output_path: str) -> List[Tuple[float, float]]:
+         """
+         Main function that executes the batched highlight detection pipeline.
+         """
+         # Step 1: Analyze video content
+         logger.info("Step 1: Analyzing video content...")
+         video_description = self.analyze_video_content(video_path)
+         logger.info(f"Video description: {video_description}")
+
+         # Step 2: Determine highlight types
+         logger.info("Step 2: Determining highlight types...")
+         highlight_types = self.determine_highlights(video_description)
+         logger.info(f"Looking for highlights: {highlight_types}")
+
+         # Step 3: Get all segments
+         segments = self._get_fixed_30s_segments(video_path)
+
+         # Step 4: Process segments in batches
+         logger.info("Step 3: Detecting highlight segments in batches...")
+         kept_segments = []
+
+         for i in tqdm(range(0, len(segments), self.batch_size)):
+             batch_segments = segments[i:i + self.batch_size]
+             keep_flags = self._process_segment_batch(video_path, batch_segments, highlight_types)
+
+             for segment, keep in zip(batch_segments, keep_flags):
+                 if keep:
+                     kept_segments.append(segment)
+                     logger.info(f"\tKeeping segment {segment}")
+
+         # Step 5: Create final video
+         if kept_segments:
+             logger.info(f"Creating highlight video with {len(kept_segments)} segments...")
+             self._concatenate_scenes(video_path, kept_segments, output_path)
+         else:
+             logger.info("No highlights detected")
+
+         return kept_segments
+
+
+     def analyze_video_content(self, video_path: str, sample_rate: float = 0.2) -> str:
+         """
+         Step 1: Sample frames from the full video and get a general description
+         """
+         extractor = SmartVideoFrameExtractor()
+         duration = get_video_duration_seconds(video_path)
+
+         # Sample frames from entire video
+         frames, timestamps, duration_seconds = extractor.extract_frames(
+             video_path,
+             start_sec=0,
+             end_sec=duration,
+             target_fps=sample_rate,
+             max_frames=32  # Limit total frames to not overwhelm model
+         )
+
+         # Build conversation asking for video description
+         system_message = "You are a helpful assistant that can understand videos. Describe what type of video this is and what's happening in it."
+         conversation = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": system_message}]
+             },
+             {
+                 "role": "user",
+                 "content": []
+             }
+         ]
+
+         # Add video intro using DEFAULT_VIDEO_INTRO
+         conversation[1]["content"].append({
+             "type": "text",
+             "text": DEFAULT_VIDEO_INTRO.format(
+                 frame_count=num2words(len(frames)),
+                 video_duration=str(datetime.timedelta(seconds=duration_seconds))
+             )
+         })
+
+         # Add frames with timestamps
+         for ts, frame in zip(timestamps, frames):
+             conversation[1]["content"].extend([
+                 {
+                     "type": "text",
+                     "text": f"{FRAME_TIMESTAMP_MESSAGE} {ts}:"
+                 },
+                 {
+                     "type": "image"
+                 }
+             ])
+
+         # Add question
+         conversation[1]["content"].append({
+             "type": "text",
+             "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."
+         })
+
+         # Get model response
+         prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+         model_inputs = self.processor(
+             text=prompt,
+             images=frames,
+             return_tensors="pt"
+         ).to(self.model.device)
+
+         outputs = self.model.generate(
+             **model_inputs,
+             max_new_tokens=512,
+             num_beams=5,
+             temperature=0.7,
+             do_sample=True,
+             use_cache=True
+         )
+         return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant:")[1]
+
+     def determine_highlights(self, video_description: str) -> str:
+         """
+         Step 2: Based on video description, determine what would constitute highlights
+         """
+         conversation = [{
+             "role": "system",
+             "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels. You understand that the most engaging highlights are brief and focus only on exceptional moments that are statistically rare or particularly dramatic. For sports content, you typically select only 3-5 of the most remarkable moments that would make viewers say 'I can't believe that happened!'"}]
+         }, {
+             "role": "user",
+             "content": [{
+                 "type": "text",
+                 "text": f"""Here is a description of a video:
+
+ {video_description}
+
+ Based on this description, list which rare segments should be included in a best of the best highlight."""
+             }]
+         }]
+         # Based on this description, what unique segments should be included in a highlight video? list moments that cannot be missed and their description, nothing else."""
+
+         # Based on this description, what unique segments should be included in a highlight video? list moments that cannot be missed."""
+
+         prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+         model_inputs = self.processor(text=prompt, return_tensors="pt").to(self.model.device)
+
+         outputs = self.model.generate(
+             **model_inputs,
+             max_new_tokens=256,
+             num_beams=5,
+             temperature=0.7,
+             do_sample=True
+         )
+         return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant:")[1]
+
+
+     def _get_fixed_30s_segments(self, video_path: str) -> List[Tuple[float, float]]:
+         """Helper to get fixed-size video segments (10-second blocks here, despite the name)"""
+         duration = self._get_video_duration_seconds(video_path)
+         segments = []
+         start = 0.0
+         block_size = 10.0
+
+         while start < duration:
+             end = min(start + block_size, duration)
+             segments.append((start, end))
+             start = end
+
+         return segments
+
+     def _get_video_duration_seconds(self, video_path: str) -> float:
+         """Helper to get video duration"""
+         import json
+         import subprocess
+
+         cmd = [
+             "ffprobe",
+             "-v", "quiet",
+             "-print_format", "json",
+             "-show_format",
+             video_path
+         ]
+         result = subprocess.run(cmd, capture_output=True, text=True)
+         info = json.loads(result.stdout)
+         return float(info["format"]["duration"])
+
+     def _concatenate_scenes(
+         self,
+         video_path: str,
+         scene_times: List[Tuple[float, float]],
+         output_path: str
+     ):
+         """
+         Concatenate selected (start_sec, end_sec) scenes from 'video_path' into 'output_path'
+         using a complex ffmpeg filter instead of multiple intermediate files.
+         """
+
+         if not scene_times:
+             logger.warning("No scenes to concatenate, skipping.")
+             return
+
+         # Build the filter_complex string
+         # For each scene i, we create two filter chains: one for video [vN] and one for audio [aN].
+         # Then we feed them into the concat filter.
+         filter_complex_parts = []
+         concat_inputs = []
+         for i, (start_sec, end_sec) in enumerate(scene_times):
+             filter_complex_parts.append(
+                 f"[0:v]trim=start={start_sec}:end={end_sec},"
+                 f"setpts=PTS-STARTPTS[v{i}];"
+             )
+             filter_complex_parts.append(
+                 f"[0:a]atrim=start={start_sec}:end={end_sec},"
+                 f"asetpts=PTS-STARTPTS[a{i}];"
+             )
+             concat_inputs.append(f"[v{i}][a{i}]")
+
+         # Now build the actual concat invocation.
+         # n = number of segments to concat, v=1 video stream, a=1 audio stream
+         concat_filter = f"{''.join(concat_inputs)}concat=n={len(scene_times)}:v=1:a=1[outv][outa]"
+         filter_complex = "".join(filter_complex_parts) + concat_filter
+
+         # Build the ffmpeg command
+         cmd = [
+             "ffmpeg",
+             "-y",  # overwrite
+             "-i", video_path,
+             "-filter_complex", filter_complex,
+             "-map", "[outv]",
+             "-map", "[outa]",
+             "-c:v", "libx264",  # or any codec of your choice
+             "-c:a", "aac",      # or any audio codec of your choice
+             output_path
+         ]
+
+         logger.info(f"Running ffmpeg command: {' '.join(cmd)}")
+         subprocess.run(cmd, check=True)
+         logger.info(f"Final video saved to: {output_path}")
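For two kept segments, say (0.0, 10.0) and (30.0, 40.0), the filter graph assembled above expands to the following single -filter_complex argument (line breaks added here only for readability):

    [0:v]trim=start=0.0:end=10.0,setpts=PTS-STARTPTS[v0];
    [0:a]atrim=start=0.0:end=10.0,asetpts=PTS-STARTPTS[a0];
    [0:v]trim=start=30.0:end=40.0,setpts=PTS-STARTPTS[v1];
    [0:a]atrim=start=30.0:end=40.0,asetpts=PTS-STARTPTS[a1];
    [v0][a0][v1][a1]concat=n=2:v=1:a=1[outv][outa]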
+
+
+
+ def load_model(
+     checkpoint_path: Optional[str] = None,
+     base_model_id: str = "HuggingFaceTB/SmolVLM-2.2B-Instruct",
+     device: str = "cuda"
+ ):
+     """Load the model and processor."""
+     # For demonstration, we set the target size
+     video_target_size = 384
+
+     processor = AutoProcessor.from_pretrained(base_model_id)
+     # Configure the image processor
+     processor.image_processor.size = {"longest_edge": video_target_size}
+     processor.image_processor.do_resize = True
+     processor.image_processor.do_image_splitting = False
+
+     if checkpoint_path:
+         model = SmolVLMForConditionalGeneration.from_pretrained(
+             checkpoint_path,
+             torch_dtype=torch.bfloat16,
+             device_map=device
+         )
+     else:
+         model = SmolVLMForConditionalGeneration.from_pretrained(
+             base_model_id,
+             torch_dtype=torch.bfloat16,
+             device_map=device
+         )
+
+     return model, processor
+
+
+ def main():
+     checkpoint_path = "/fsx/miquel/smolvlmvideo/checkpoints/final-visionUnfrozen-balanced/checkpoint-6550"
+     base_model_id = "HuggingFaceTB/SmolVLM-2.2B-Instruct"
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model, processor = load_model(checkpoint_path, base_model_id, device)
+     detector = BatchedVideoHighlightDetector(model, processor, device=device)
+
+     if len(sys.argv) < 3:
+         print("Usage: python video_highlight_detector.py <input_video> <output_video>")
+         sys.exit(1)
+
+     video_path = sys.argv[1]
+     output_path = sys.argv[2]
+
+     # Create highlight video
+     highlight_segments = detector.create_highlight_video(video_path, output_path)
+     print(f"Created highlight video with {len(highlight_segments)} segments")
+
+
+ if __name__ == "__main__":
+     main()
video_spec.json ADDED
@@ -0,0 +1,62 @@
+ {
+     "version": "1.0",
+     "examples": [
+         {
+             "id": "Example 1",
+             "title": "Football Match Highlights",
+             "description": "Champions League semifinal match",
+             "original": {
+                 "url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/barcamadrid.mp4",
+                 "duration_seconds": 6114
+             },
+             "highlights": {
+                 "url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/barcamadridhighlights.mp4",
+                 "duration_seconds": 130,
+                 "thumbnail_url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/thumbnail_barcamadrid.png"
+
+             },
+             "analysis": {
+                 "video_description": "This is a high-stakes football match between Barcelona and Madrid.",
+                 "highlight_types": "- Goals scored\n- Player interactions\n- Vibe at the stadium"
+             }
+         },
+         {
+             "id": "Example 2",
+             "title": "Football Match Highlights",
+             "description": "Champions League semifinal match",
+             "original": {
+                 "url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/barcamadrid.mp4",
+                 "duration_seconds": 6114
+             },
+             "highlights": {
+                 "url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/barcamadridhighlights.mp4",
+                 "duration_seconds": 130,
+                 "thumbnail_url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/thumbnail_barcamadrid.png"
+
+             },
+             "analysis": {
+                 "video_description": "This is a high-stakes football match between Barcelona and Madrid.",
+                 "highlight_types": "- Goals scored\n- Player interactions\n- Vibe at the stadium"
+             }
+         },
+         {
+             "id": "Example 3",
+             "title": "Football Match Highlights",
+             "description": "Champions League semifinal match",
+             "original": {
+                 "url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/barcamadrid.mp4",
+                 "duration_seconds": 6114
+             },
+             "highlights": {
+                 "url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/barcamadridhighlights.mp4",
+                 "duration_seconds": 130,
+                 "thumbnail_url": "https://huggingface.co/datasets/mfarre/servedfiles/resolve/main/thumbnail_barcamadrid.png"
+
+             },
+             "analysis": {
+                 "video_description": "This is a high-stakes football match between Barcelona and Madrid.",
+                 "highlight_types": "- Goals scored\n- Player interactions\n- Vibe at the stadium"
+             }
+         }
+     ]
+ }