er1t0 committed on
Commit 9b87d5a
1 Parent(s): 7976ee8

shift to ffmpeg

Files changed (4):
  1. app.py +108 -98
  2. myapp2.py +204 -0
  3. packages.txt +1 -0
  4. requirements.txt +2 -1
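
In short, the commit replaces app.py's OpenCV-based chunked read/write pipeline with a single ffmpeg-python pass: the clip is probed for its dimensions and frame rate, decoded to raw RGB frames over a pipe, segmented, and the frames are piped back into ffmpeg to encode the output MP4. A minimal sketch of that decode/encode round trip (illustrative only; "input.mp4"/"output.mp4" and the fixed 30 fps are placeholders, and it assumes the ffmpeg binary plus the ffmpeg-python package are available):

# Sketch of the raw-frame round trip that the new process_video() builds on.
import ffmpeg
import numpy as np

probe = ffmpeg.probe("input.mp4")
info = next(s for s in probe["streams"] if s["codec_type"] == "video")
width, height = int(info["width"]), int(info["height"])

# Decode every frame to raw RGB bytes over a pipe and view them as a numpy array.
raw, _ = (
    ffmpeg.input("input.mp4")
    .output("pipe:", format="rawvideo", pix_fmt="rgb24")
    .run(capture_stdout=True)
)
frames = np.frombuffer(raw, np.uint8).reshape([-1, height, width, 3])

# Pipe the (possibly modified) raw RGB frames back into ffmpeg to encode an MP4.
writer = (
    ffmpeg.input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}", r=30)
    .output("output.mp4", pix_fmt="yuv420p")
    .overwrite_output()
    .run_async(pipe_stdin=True)
)
for frame in frames:
    writer.stdin.write(frame.tobytes())
writer.stdin.close()
writer.wait()
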
app.py CHANGED
@@ -1,4 +1,3 @@
-
  import os
  import torch
  import numpy as np
@@ -10,6 +9,7 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
  import cv2
  import traceback
  import matplotlib.pyplot as plt
+ import ffmpeg
  from utils import load_model_without_flash_attn
  
  
@@ -62,7 +62,7 @@ def apply_color_mask(frame, mask, obj_id):
      return frame * (1 - mask) + colored_mask * 255
  
  def run_florence(image, text_input):
-     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+     with torch.amp.autocast(dtype=torch.bfloat16):
          task_prompt = '<OPEN_VOCABULARY_DETECTION>'
          prompt = task_prompt + text_input
          inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
@@ -89,125 +89,135 @@ def remove_directory_contents(directory):
          for name in dirs:
              os.rmdir(os.path.join(root, name))
  
- def process_video(video_path, prompt, chunk_size=30):
+
+ def process_video(video_path, prompt):
      try:
-         video = cv2.VideoCapture(video_path)
-         if not video.isOpened():
-             raise ValueError("Unable to open video file")
+         # Get video info
+         probe = ffmpeg.probe(video_path)
+         video_info = next(s for s in probe['streams'] if s['codec_type'] == 'video')
+         width = int(video_info['width'])
+         height = int(video_info['height'])
+         num_frames = int(video_info['nb_frames'])
+         fps = eval(video_info['r_frame_rate'])
+
+         print(f"Video info: {width}x{height}, {num_frames} frames, {fps} fps")
+
+         # Read frames
+         out, _ = (
+             ffmpeg
+             .input(video_path)
+             .output('pipe:', format='rawvideo', pix_fmt='rgb24')
+             .run(capture_stdout=True)
+         )
+         frames = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
+
+         print(f"Read {len(frames)} frames")
+
+         # Florence detection on first frame
+         first_frame = Image.fromarray(frames[0])
+         mask_box = run_florence(first_frame, prompt)
+         print("Original mask box:", mask_box)
  
-         fps = video.get(cv2.CAP_PROP_FPS)
-         frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+         # Convert mask_box to numpy array
+         mask_box = np.array(mask_box)
+         print("Reshaped mask box:", mask_box)
  
-         # Process video in chunks
-         all_segmented_frames = []
-         for chunk_start in range(0, frame_count, chunk_size):
-             chunk_end = min(chunk_start + chunk_size, frame_count)
-
-             frames = []
-             video.set(cv2.CAP_PROP_POS_FRAMES, chunk_start)
-             for _ in range(chunk_end - chunk_start):
-                 ret, frame = video.read()
-                 if not ret:
-                     break
-                 frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-
-             if not frames:
-                 print(f"No frames extracted for chunk starting at {chunk_start}")
-                 continue
-
-             # Florence detection on first frame of the chunk
-             first_frame = Image.fromarray(frames[0])
-             mask_box = run_florence(first_frame, prompt)
-             print("Original mask box:", mask_box)
-
-             # Convert mask_box to numpy array and ensure it's in the correct format
-             mask_box = np.array(mask_box)
-             print("Reshaped mask box:", mask_box)
-
-             # SAM2 segmentation on first frame
-             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                 image_predictor.set_image(first_frame)
-                 masks, _, _ = image_predictor.predict(
-                     point_coords=None,
-                     point_labels=None,
-                     box=mask_box[None, :],
-                     multimask_output=False,
-                 )
-             print("masks.shape",masks.shape)
-
-             mask = masks.squeeze().astype(bool)
-             print("Mask shape:", mask.shape)
-             print("Frame shape:", frames[0].shape)
-
-             # SAM2 video propagation
-             temp_dir = f"temp_frames_{chunk_start}"
-             os.makedirs(temp_dir, exist_ok=True)
-             for i, frame in enumerate(frames):
-                 cv2.imwrite(os.path.join(temp_dir, f"{i:04d}.jpg"), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-
-             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                 inference_state = video_predictor.init_state(video_path=temp_dir)
-                 _, _, _ = video_predictor.add_new_mask(
-                     inference_state=inference_state,
-                     frame_idx=0,
-                     obj_id=1,
-                     mask=mask
-                 )
-
-                 video_segments = {}
-                 for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
-                     video_segments[out_frame_idx] = {
-                         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-                         for i, out_obj_id in enumerate(out_obj_ids)
-                     }
-
-             print('segmenting for main vid done')
-
-             # Apply segmentation masks to frames
-             for i, frame in enumerate(frames):
-                 if i in video_segments:
-                     for out_obj_id, mask in video_segments[i].items():
-                         frame = apply_color_mask(frame, mask, out_obj_id)
-                     all_segmented_frames.append(frame.astype(np.uint8))
-                 else:
-                     all_segmented_frames.append(frame)
+         # SAM2 segmentation on first frame
+         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+             image_predictor.set_image(first_frame)
+             masks, _, _ = image_predictor.predict(
+                 point_coords=None,
+                 point_labels=None,
+                 box=mask_box[None, :],
+                 multimask_output=False,
+             )
+         print("masks.shape", masks.shape)
+
+         mask = masks.squeeze().astype(bool)
+         print("Mask shape:", mask.shape)
+         print("Frame shape:", frames[0].shape)
+
+         # SAM2 video propagation
+         temp_dir = "temp_frames"
+         os.makedirs(temp_dir, exist_ok=True)
+         for i, frame in enumerate(frames):
+             Image.fromarray(frame).save(os.path.join(temp_dir, f"{i:04d}.jpg"))
+
+         print(f"Saved {len(frames)} temporary frames")
+
+         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+             inference_state = video_predictor.init_state(video_path=temp_dir)
+             _, _, _ = video_predictor.add_new_mask(
+                 inference_state=inference_state,
+                 frame_idx=0,
+                 obj_id=1,
+                 mask=mask
+             )
  
-             # Clean up temporary files
-             remove_directory_contents(temp_dir)
-             os.rmdir(temp_dir)
+             video_segments = {}
+             for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
+                 video_segments[out_frame_idx] = {
+                     out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                     for i, out_obj_id in enumerate(out_obj_ids)
+                 }
+
+         print('Segmenting for main vid done')
+         print(f"Number of segmented frames: {len(video_segments)}")
  
-         video.release()
+         # Apply segmentation masks to frames
+         all_segmented_frames = []
+         for i, frame in enumerate(frames):
+             if i in video_segments:
+                 for out_obj_id, mask in video_segments[i].items():
+                     frame = apply_color_mask(frame, mask, out_obj_id)
+                 all_segmented_frames.append(frame.astype(np.uint8))
+             else:
+                 all_segmented_frames.append(frame)
  
-         if not all_segmented_frames:
-             raise ValueError("No frames were processed successfully")
-
-         # Create video from segmented frames
+         print(f"Applied masks to {len(all_segmented_frames)} frames")
+
+         # Clean up temporary files
+         remove_directory_contents(temp_dir)
+         os.rmdir(temp_dir)
+
+         # Write output video using ffmpeg
          output_path = "segmented_video.mp4"
-         out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
-                               (all_segmented_frames[0].shape[1], all_segmented_frames[0].shape[0]))
+         process = (
+             ffmpeg
+             .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=fps)
+             .output(output_path, pix_fmt='yuv420p')
+             .overwrite_output()
+             .run_async(pipe_stdin=True)
+         )
+
         for frame in all_segmented_frames:
-             out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-         out.release()
+             process.stdin.write(frame.tobytes())
+
+         process.stdin.close()
+         process.wait()
  
+         if not os.path.exists(output_path):
+             raise ValueError(f"Output video file was not created: {output_path}")
+
+         print(f"Successfully created output video: {output_path}")
          return output_path
  
      except Exception as e:
          print(f"Error in process_video: {str(e)}")
          print(traceback.format_exc())  # This will print the full stack trace
          return None
-
- def segment_video(video_file, prompt, chunk_size):
+
+ def segment_video(video_file, prompt):
      if video_file is None:
          return None
-     output_video = process_video(video_file, prompt, int(chunk_size))
+     output_video = process_video(video_file, prompt)
      return output_video
  
  demo = gr.Interface(
      fn=segment_video,
      inputs=[
          gr.Video(label="Upload Video"),
-         gr.Textbox(label="Enter prompt (e.g., 'a gymnast')"),
-         gr.Slider(minimum=10, maximum=100, step=10, value=30, label="Chunk Size (frames)")
+         gr.Textbox(label="Enter prompt (e.g., 'a gymnast')")
      ],
      outputs=gr.Video(label="Segmented Video"),
      title="Video Object Segmentation with Florence and SAM2",
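
A note on the new probe step: ffprobe reports r_frame_rate as a fraction string such as "30000/1001", which the new code converts with eval(). An alternative sketch (not part of the commit) that parses the same string with fractions.Fraction:

# Illustrative alternative to `fps = eval(video_info['r_frame_rate'])`.
from fractions import Fraction

def parse_frame_rate(rate_str):
    """Convert an ffprobe rate string like '30000/1001' to a float fps."""
    return float(Fraction(rate_str))

print(parse_frame_rate("30000/1001"))  # -> 29.97002997002997
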
myapp2.py ADDED
@@ -0,0 +1,204 @@
+ import os
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ import cv2
+ import traceback
+ import matplotlib.pyplot as plt
+
+ # CUDA optimizations
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+ if torch.cuda.get_device_properties(0).major >= 8:
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+ # Initialize models
+ sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
+ model_cfg = "sam2_hiera_l.yaml"
+
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
+ image_predictor = SAM2ImagePredictor(sam2_model)
+
+ model_id = 'microsoft/Florence-2-large'
+ florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16).eval().cuda()
+ florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+ def apply_color_mask(frame, mask, obj_id):
+     cmap = plt.get_cmap("tab10")
+     color = np.array(cmap(obj_id % 10)[:3])  # Use modulo 10 to cycle through colors
+
+     # Ensure mask has the correct shape
+     if mask.ndim == 4:
+         mask = mask.squeeze()  # Remove singleton dimensions
+     if mask.ndim == 3 and mask.shape[0] == 1:
+         mask = mask[0]  # Take the first channel if it's a single-channel 3D array
+
+     # Reshape mask to match frame dimensions
+     mask = cv2.resize(mask.astype(np.float32), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)
+
+     # Expand dimensions of mask and color for broadcasting
+     mask = np.expand_dims(mask, axis=2)
+     color = color.reshape(1, 1, 3)
+
+     colored_mask = mask * color
+     return frame * (1 - mask) + colored_mask * 255
+
+ def run_florence(image, text_input):
+     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+         task_prompt = '<OPEN_VOCABULARY_DETECTION>'
+         prompt = task_prompt + text_input
+         inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
+         generated_ids = florence_model.generate(
+             input_ids=inputs["input_ids"].cuda(),
+             pixel_values=inputs["pixel_values"].cuda(),
+             max_new_tokens=1024,
+             early_stopping=False,
+             do_sample=False,
+             num_beams=3,
+         )
+         generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         parsed_answer = florence_processor.post_process_generation(
+             generated_text,
+             task=task_prompt,
+             image_size=(image.width, image.height)
+         )
+         return parsed_answer[task_prompt]['bboxes'][0]
+
+ def remove_directory_contents(directory):
+     for root, dirs, files in os.walk(directory, topdown=False):
+         for name in files:
+             os.remove(os.path.join(root, name))
+         for name in dirs:
+             os.rmdir(os.path.join(root, name))
+
+ def process_video(video_path, prompt, chunk_size=30):
+     try:
+         video = cv2.VideoCapture(video_path)
+         if not video.isOpened():
+             raise ValueError("Unable to open video file")
+
+         fps = video.get(cv2.CAP_PROP_FPS)
+         frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+
+         # Process video in chunks
+         all_segmented_frames = []
+         for chunk_start in range(0, frame_count, chunk_size):
+             chunk_end = min(chunk_start + chunk_size, frame_count)
+
+             frames = []
+             video.set(cv2.CAP_PROP_POS_FRAMES, chunk_start)
+             for _ in range(chunk_end - chunk_start):
+                 ret, frame = video.read()
+                 if not ret:
+                     break
+                 frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+             if not frames:
+                 print(f"No frames extracted for chunk starting at {chunk_start}")
+                 continue
+
+             # Florence detection on first frame of the chunk
+             first_frame = Image.fromarray(frames[0])
+             mask_box = run_florence(first_frame, prompt)
+             print("Original mask box:", mask_box)
+
+             # Convert mask_box to numpy array and ensure it's in the correct format
+             mask_box = np.array(mask_box)
+             print("Reshaped mask box:", mask_box)
+
+             # SAM2 segmentation on first frame
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 image_predictor.set_image(first_frame)
+                 masks, _, _ = image_predictor.predict(
+                     point_coords=None,
+                     point_labels=None,
+                     box=mask_box[None, :],
+                     multimask_output=False,
+                 )
+             print("masks.shape",masks.shape)
+
+             mask = masks.squeeze().astype(bool)
+             print("Mask shape:", mask.shape)
+             print("Frame shape:", frames[0].shape)
+
+             # SAM2 video propagation
+             temp_dir = f"temp_frames_{chunk_start}"
+             os.makedirs(temp_dir, exist_ok=True)
+             for i, frame in enumerate(frames):
+                 cv2.imwrite(os.path.join(temp_dir, f"{i:04d}.jpg"), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 inference_state = video_predictor.init_state(video_path=temp_dir)
+                 _, _, _ = video_predictor.add_new_mask(
+                     inference_state=inference_state,
+                     frame_idx=0,
+                     obj_id=1,
+                     mask=mask
+                 )
+
+                 video_segments = {}
+                 for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
+                     video_segments[out_frame_idx] = {
+                         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                         for i, out_obj_id in enumerate(out_obj_ids)
+                     }
+
+             print('segmenting for main vid done')
+
+             # Apply segmentation masks to frames
+             for i, frame in enumerate(frames):
+                 if i in video_segments:
+                     for out_obj_id, mask in video_segments[i].items():
+                         frame = apply_color_mask(frame, mask, out_obj_id)
+                     all_segmented_frames.append(frame.astype(np.uint8))
+                 else:
+                     all_segmented_frames.append(frame)
+
+             # Clean up temporary files
+             remove_directory_contents(temp_dir)
+             os.rmdir(temp_dir)
+
+         video.release()
+
+         if not all_segmented_frames:
+             raise ValueError("No frames were processed successfully")
+
+         # Create video from segmented frames
+         output_path = "segmented_video.mp4"
+         out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
+                               (all_segmented_frames[0].shape[1], all_segmented_frames[0].shape[0]))
+         for frame in all_segmented_frames:
+             out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+         out.release()
+
+         return output_path
+
+     except Exception as e:
+         print(f"Error in process_video: {str(e)}")
+         print(traceback.format_exc())  # This will print the full stack trace
+         return None
+
+ def segment_video(video_file, prompt, chunk_size):
+     if video_file is None:
+         return None
+     output_video = process_video(video_file, prompt, int(chunk_size))
+     return output_video
+
+ demo = gr.Interface(
+     fn=segment_video,
+     inputs=[
+         gr.Video(label="Upload Video"),
+         gr.Textbox(label="Enter prompt (e.g., 'a gymnast')"),
+         gr.Slider(minimum=10, maximum=100, step=10, value=30, label="Chunk Size (frames)")
+     ],
+     outputs=gr.Video(label="Segmented Video"),
+     title="Video Object Segmentation with Florence and SAM2",
+     description="Upload a video and provide a text prompt to segment a specific object throughout the video."
+ )
+
+ demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt CHANGED
@@ -8,4 +8,5 @@ opencv-python
  matplotlib
  einops
  timm
- pytest
+ pytest
+ ffmpeg-python
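
For context: packages.txt supplies the system-level ffmpeg binary on the Space, while ffmpeg-python in requirements.txt is only the Python binding that shells out to ffmpeg/ffprobe, so both entries are needed. A small sanity check one could run at startup (illustrative; "sample.mp4" is a placeholder file):

# Quick environment check: ffmpeg-python wraps the ffmpeg/ffprobe executables,
# so they must be present on PATH at runtime.
import shutil
import ffmpeg

assert shutil.which("ffmpeg"), "ffmpeg binary not found on PATH"
assert shutil.which("ffprobe"), "ffprobe binary not found on PATH"
print(ffmpeg.probe("sample.mp4")["format"]["format_name"])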