EXCAI committed on
Commit 5b2a969 · 1 Parent(s): 9bd5727
Files changed (4)
  1. .gitmodules +3 -0
  2. app.py +153 -70
  3. demo.py +88 -21
  4. models/pipelines.py +327 -122
.gitmodules CHANGED
@@ -1,3 +1,6 @@
1
  [submodule "submodules/MoGe"]
2
  path = submodules/MoGe
3
  url = https://github.com/microsoft/MoGe.git
 
 
 
 
1
  [submodule "submodules/MoGe"]
2
  path = submodules/MoGe
3
  url = https://github.com/microsoft/MoGe.git
4
+ [submodule "submodules/vggt"]
5
+ path = submodules/vggt
6
+ url = https://github.com/facebookresearch/vggt.git
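Note: this registers a second Git submodule alongside MoGe. After checking out this commit, the new submodule still has to be fetched, e.g. with git submodule update --init --recursive (or restricted to submodules/vggt).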
app.py CHANGED
@@ -16,6 +16,7 @@ sys.path.append(project_root)
16
 
17
  try:
18
  sys.path.append(os.path.join(project_root, "submodules/MoGe"))
 
19
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
  except:
21
  print("Warning: MoGe not found, motion transfer will not be applied")
@@ -27,6 +28,8 @@ hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_f
27
 
28
  from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
29
  from submodules.MoGe.moge.model import MoGeModel
 
 
30
 
31
  # Parse command line arguments
32
  parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
@@ -47,6 +50,7 @@ os.makedirs("outputs", exist_ok=True)
47
  # Create project tmp directory instead of using system temp
48
  os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
49
  os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
 
50
  def load_media(media_path, max_frames=49, transform=None):
51
  """Load video or image frames and convert to tensor
52
 
@@ -69,22 +73,52 @@ def load_media(media_path, max_frames=49, transform=None):
69
  is_video = ext in ['.mp4', '.avi', '.mov']
70
 
71
  if is_video:
72
- frames = load_video(media_path)
73
- fps = len(frames) / VideoFileClip(media_path).duration
74
  else:
75
  # Handle image as single frame
76
  image = load_image(media_path)
77
  frames = [image]
78
  fps = 8 # Default fps for images
79
-
80
- # Ensure we have exactly max_frames
81
- if len(frames) > max_frames:
82
- frames = frames[:max_frames]
83
- elif len(frames) < max_frames:
84
- last_frame = frames[-1]
85
  while len(frames) < max_frames:
86
- frames.append(last_frame.copy())
87
-
88
  # Convert frames to tensor
89
  video_tensor = torch.stack([transform(frame) for frame in frames])
90
 
@@ -131,6 +165,7 @@ def save_uploaded_file(file):
131
 
132
  das_pipeline = None
133
  moge_model = None
 
134
 
135
  @spaces.GPU
136
  def get_das_pipeline():
@@ -147,6 +182,13 @@ def get_moge_model():
147
  moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
148
  return moge_model
149
 
150
 
151
  def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
152
  """Process video motion transfer task"""
@@ -154,19 +196,20 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
154
  # Save uploaded files
155
  input_video_path = save_uploaded_file(source)
156
  if input_video_path is None:
157
- return None
158
 
159
  print(f"DEBUG: Repaint option: {mt_repaint_option}")
160
  print(f"DEBUG: Repaint image: {mt_repaint_image}")
161
 
162
-
163
  das = get_das_pipeline()
164
  video_tensor, fps, is_video = load_media(input_video_path)
 
 
165
  if not is_video:
166
  tracking_method = "moge"
167
  print("Image input detected, using MoGe for tracking video generation.")
168
  else:
169
- tracking_method = "spatracker"
170
 
171
  repaint_img_tensor = None
172
  if mt_repaint_image is not None:
@@ -180,7 +223,9 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
180
  prompt=prompt,
181
  depth_path=None
182
  )
 
183
  tracking_tensor = None
 
184
  if tracking_method == "moge":
185
  moge = get_moge_model()
186
  infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
@@ -195,32 +240,31 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
195
 
196
  pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
197
 
198
- _, tracking_tensor = das.visualize_tracking_moge(
199
  pred_tracks.cpu().numpy(),
200
  infer_result["mask"].cpu().numpy()
201
  )
202
  print('Export tracking video via MoGe')
203
  else:
204
- pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
205
-
206
- _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
207
- print('Export tracking video via SpaTracker')
208
 
209
  output_path = das.apply_tracking(
210
  video_tensor=video_tensor,
211
- fps=8,
212
  tracking_tensor=tracking_tensor,
213
  img_cond_tensor=repaint_img_tensor,
214
  prompt=prompt,
215
  checkpoint_path=DEFAULT_MODEL_PATH
216
  )
217
 
218
- return output_path
219
  except Exception as e:
220
  import traceback
221
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
222
- return None
223
-
224
 
225
  def process_camera_control(source, prompt, camera_motion, tracking_method):
226
  """Process camera control task"""
@@ -228,17 +272,18 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
228
  # Save uploaded files
229
  input_media_path = save_uploaded_file(source)
230
  if input_media_path is None:
231
- return None
232
 
233
  print(f"DEBUG: Camera motion: '{camera_motion}'")
234
  print(f"DEBUG: Tracking method: '{tracking_method}'")
235
 
236
  das = get_das_pipeline()
237
-
238
  video_tensor, fps, is_video = load_media(input_media_path)
239
- if not is_video and tracking_method == "spatracker":
 
 
240
  tracking_method = "moge"
241
- print("Image input detected with spatracker selected, switching to MoGe")
242
 
243
  cam_motion = CameraMotionGenerator(camera_motion)
244
  repaint_img_tensor = None
@@ -267,32 +312,54 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
267
  )
268
  print('Export tracking video via MoGe')
269
  else:
270
-
271
- pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
272
  if camera_motion:
273
  poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
274
- pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
 
275
  print("Camera motion applied")
276
-
277
- _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
278
- print('Export tracking video via SpaTracker')
279
-
280
 
281
  output_path = das.apply_tracking(
282
  video_tensor=video_tensor,
283
- fps=8,
284
  tracking_tensor=tracking_tensor,
285
  img_cond_tensor=repaint_img_tensor,
286
  prompt=prompt,
287
  checkpoint_path=DEFAULT_MODEL_PATH
288
  )
289
 
290
- return output_path
291
  except Exception as e:
292
  import traceback
293
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
294
- return None
295
-
296
 
297
  def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
298
  """Process object manipulation task"""
@@ -300,21 +367,21 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
300
  # Save uploaded files
301
  input_image_path = save_uploaded_file(source)
302
  if input_image_path is None:
303
- return None
304
 
305
  object_mask_path = save_uploaded_file(object_mask)
306
  if object_mask_path is None:
307
  print("Object mask not provided")
308
- return None
309
-
310
 
311
  das = get_das_pipeline()
312
  video_tensor, fps, is_video = load_media(input_image_path)
313
- if not is_video and tracking_method == "spatracker":
 
 
314
  tracking_method = "moge"
315
- print("Image input detected with spatracker selected, switching to MoGe")
316
 
317
-
318
  mask_image = Image.open(object_mask_path).convert('L')
319
  mask_image = transforms.Resize((480, 720))(mask_image)
320
  mask = torch.from_numpy(np.array(mask_image) > 127)
@@ -322,10 +389,10 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
322
  motion_generator = ObjectMotionGenerator(device=das.device)
323
  repaint_img_tensor = None
324
  tracking_tensor = None
 
325
  if tracking_method == "moge":
326
  moge = get_moge_model()
327
 
328
-
329
  infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
330
  H, W = infer_result["points"].shape[0:2]
331
  pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
@@ -342,7 +409,6 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
342
  poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
343
  pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
344
 
345
-
346
  cam_motion = CameraMotionGenerator(None)
347
  cam_motion.set_intr(infer_result["intrinsics"])
348
  pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
@@ -353,9 +419,27 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
353
  )
354
  print('Export tracking video via MoGe')
355
  else:
 
 
356
 
357
- pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
358
 
359
 
360
  pred_tracks = motion_generator.apply_motion(
361
  pred_tracks=pred_tracks.squeeze(),
@@ -363,30 +447,27 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
363
  motion_type=object_motion,
364
  distance=50,
365
  num_frames=49,
366
- tracking_method="spatracker"
367
- ).unsqueeze(0)
368
  print(f"Object motion '{object_motion}' applied using provided mask")
369
 
370
-
371
- _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
372
- print('Export tracking video via SpaTracker')
373
-
374
 
375
  output_path = das.apply_tracking(
376
  video_tensor=video_tensor,
377
- fps=8,
378
  tracking_tensor=tracking_tensor,
379
  img_cond_tensor=repaint_img_tensor,
380
  prompt=prompt,
381
  checkpoint_path=DEFAULT_MODEL_PATH
382
  )
383
 
384
- return output_path
385
  except Exception as e:
386
  import traceback
387
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
388
- return None
389
-
390
 
391
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
392
  """Process mesh animation task"""
@@ -394,15 +475,16 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
394
  # Save uploaded files
395
  input_video_path = save_uploaded_file(source)
396
  if input_video_path is None:
397
- return None
398
 
399
  tracking_video_path = save_uploaded_file(tracking_video)
400
  if tracking_video_path is None:
401
- return None
402
-
403
 
404
  das = get_das_pipeline()
405
  video_tensor, fps, is_video = load_media(input_video_path)
 
 
406
  tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)
407
  repaint_img_tensor = None
408
  if ma_repaint_image is not None:
@@ -420,18 +502,18 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
420
 
421
  output_path = das.apply_tracking(
422
  video_tensor=video_tensor,
423
- fps=8,
424
  tracking_tensor=tracking_tensor,
425
  img_cond_tensor=repaint_img_tensor,
426
  prompt=prompt,
427
  checkpoint_path=DEFAULT_MODEL_PATH
428
  )
429
 
430
- return output_path
431
  except Exception as e:
432
  import traceback
433
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
434
- return None
435
 
436
  # Create Gradio interface with updated layout
437
  with gr.Blocks(title="Diffusion as Shader") as demo:
@@ -444,6 +526,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
444
 
445
  with right_column:
446
  output_video = gr.Video(label="Generated Video")
 
447
 
448
  with left_column:
449
  source = gr.File(label="Source", file_types=["image", "video"])
@@ -479,7 +562,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
479
  source, common_prompt,
480
  mt_repaint_option, mt_repaint_image
481
  ],
482
- outputs=[output_video]
483
  )
484
 
485
  # Camera Control tab
@@ -597,8 +680,8 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
597
 
598
  cc_tracking_method = gr.Radio(
599
  label="Tracking Method",
600
- choices=["spatracker", "moge"],
601
- value="moge"
602
  )
603
 
604
  # Add run button for Camera Control tab
@@ -611,7 +694,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
611
  source, common_prompt,
612
  cc_camera_motion, cc_tracking_method
613
  ],
614
- outputs=[output_video]
615
  )
616
 
617
  # Object Manipulation tab
@@ -629,8 +712,8 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
629
  )
630
  om_tracking_method = gr.Radio(
631
  label="Tracking Method",
632
- choices=["spatracker", "moge"],
633
- value="moge"
634
  )
635
 
636
  # Add run button for Object Manipulation tab
@@ -643,7 +726,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
643
  source, common_prompt,
644
  om_object_motion, om_object_mask, om_tracking_method
645
  ],
646
- outputs=[output_video]
647
  )
648
 
649
  # Animating meshes to video tab
@@ -683,7 +766,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
683
  source, common_prompt,
684
  ma_tracking_video, ma_repaint_option, ma_repaint_image
685
  ],
686
- outputs=[output_video]
687
  )
688
 
689
  # Launch interface
 
16
 
17
  try:
18
  sys.path.append(os.path.join(project_root, "submodules/MoGe"))
19
+ sys.path.append(os.path.join(project_root, "submodules/vggt"))
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
  except:
22
  print("Warning: MoGe not found, motion transfer will not be applied")
 
28
 
29
  from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
30
  from submodules.MoGe.moge.model import MoGeModel
31
+ from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
32
+ from submodules.vggt.vggt.models.vggt import VGGT
33
 
34
  # Parse command line arguments
35
  parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
 
50
  # Create project tmp directory instead of using system temp
51
  os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
52
  os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
53
+
54
  def load_media(media_path, max_frames=49, transform=None):
55
  """Load video or image frames and convert to tensor
56
 
 
73
  is_video = ext in ['.mp4', '.avi', '.mov']
74
 
75
  if is_video:
76
+ # Load video file info
77
+ video_clip = VideoFileClip(media_path)
78
+ duration = video_clip.duration
79
+ original_fps = video_clip.fps
80
+
81
+ # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
82
+ if duration > 6.0:
83
+ sampling_fps = 8 # 8 frames per second
84
+ frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
85
+ fps = sampling_fps
86
+ # Cases 2 and 3: Video shorter than 6 seconds
87
+ else:
88
+ # Load all frames
89
+ frames = load_video(media_path)
90
+
91
+ # Case 2: Total frames less than max_frames, need interpolation
92
+ if len(frames) < max_frames:
93
+ fps = len(frames) / duration # Keep original fps
94
+
95
+ # Evenly interpolate to max_frames
96
+ indices = np.linspace(0, len(frames) - 1, max_frames)
97
+ new_frames = []
98
+ for i in indices:
99
+ idx = int(i)
100
+ new_frames.append(frames[idx])
101
+ frames = new_frames
102
+ # Case 3: Total frames more than max_frames but video less than 6 seconds
103
+ else:
104
+ # Evenly sample to max_frames
105
+ indices = np.linspace(0, len(frames) - 1, max_frames)
106
+ new_frames = []
107
+ for i in indices:
108
+ idx = int(i)
109
+ new_frames.append(frames[idx])
110
+ frames = new_frames
111
+ fps = max_frames / duration # New fps to maintain duration
112
  else:
113
  # Handle image as single frame
114
  image = load_image(media_path)
115
  frames = [image]
116
  fps = 8 # Default fps for images
117
+
118
+ # Duplicate frame to max_frames
 
 
 
 
119
  while len(frames) < max_frames:
120
+ frames.append(frames[0].copy())
121
+
122
  # Convert frames to tensor
123
  video_tensor = torch.stack([transform(frame) for frame in frames])
124
 
 
165
 
166
  das_pipeline = None
167
  moge_model = None
168
+ vggt_model = None
169
 
170
  @spaces.GPU
171
  def get_das_pipeline():
 
182
  moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
183
  return moge_model
184
 
185
+ @spaces.GPU
186
+ def get_vggt_model():
187
+ global vggt_model
188
+ if vggt_model is None:
189
+ das = get_das_pipeline()
190
+ vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
191
+ return vggt_model
192
 
193
  def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
194
  """Process video motion transfer task"""
 
196
  # Save uploaded files
197
  input_video_path = save_uploaded_file(source)
198
  if input_video_path is None:
199
+ return None, None
200
 
201
  print(f"DEBUG: Repaint option: {mt_repaint_option}")
202
  print(f"DEBUG: Repaint image: {mt_repaint_image}")
203
 
 
204
  das = get_das_pipeline()
205
  video_tensor, fps, is_video = load_media(input_video_path)
206
+ das.fps = fps # set das.fps to the fps returned by load_media
207
+
208
  if not is_video:
209
  tracking_method = "moge"
210
  print("Image input detected, using MoGe for tracking video generation.")
211
  else:
212
+ tracking_method = "cotracker"
213
 
214
  repaint_img_tensor = None
215
  if mt_repaint_image is not None:
 
223
  prompt=prompt,
224
  depth_path=None
225
  )
226
+
227
  tracking_tensor = None
228
+ tracking_path = None
229
  if tracking_method == "moge":
230
  moge = get_moge_model()
231
  infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
 
240
 
241
  pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
242
 
243
+ tracking_path, tracking_tensor = das.visualize_tracking_moge(
244
  pred_tracks.cpu().numpy(),
245
  infer_result["mask"].cpu().numpy()
246
  )
247
  print('Export tracking video via MoGe')
248
  else:
249
+ # Use CoTracker
250
+ pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
251
+ tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
252
+ print('Export tracking video via cotracker')
253
 
254
  output_path = das.apply_tracking(
255
  video_tensor=video_tensor,
256
+ fps=fps, # use the fps returned by load_media
257
  tracking_tensor=tracking_tensor,
258
  img_cond_tensor=repaint_img_tensor,
259
  prompt=prompt,
260
  checkpoint_path=DEFAULT_MODEL_PATH
261
  )
262
 
263
+ return tracking_path, output_path
264
  except Exception as e:
265
  import traceback
266
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
267
+ return None, None
 
268
 
269
  def process_camera_control(source, prompt, camera_motion, tracking_method):
270
  """Process camera control task"""
 
272
  # Save uploaded files
273
  input_media_path = save_uploaded_file(source)
274
  if input_media_path is None:
275
+ return None, None
276
 
277
  print(f"DEBUG: Camera motion: '{camera_motion}'")
278
  print(f"DEBUG: Tracking method: '{tracking_method}'")
279
 
280
  das = get_das_pipeline()
 
281
  video_tensor, fps, is_video = load_media(input_media_path)
282
+ das.fps = fps # set das.fps to the fps returned by load_media
283
+
284
+ if not is_video:
285
  tracking_method = "moge"
286
+ print("Image input detected, switching to MoGe")
287
 
288
  cam_motion = CameraMotionGenerator(camera_motion)
289
  repaint_img_tensor = None
 
312
  )
313
  print('Export tracking video via MoGe')
314
  else:
315
+ # Use CoTracker
316
+ pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
317
+
318
+ t, c, h, w = video_tensor.shape
319
+ new_width = 518
320
+ new_height = round(h * (new_width / w) / 14) * 14
321
+ resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
322
+ video_vggt = resize_transform(video_tensor) # [T, C, H, W]
323
+
324
+ if new_height > 518:
325
+ start_y = (new_height - 518) // 2
326
+ video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
327
+
328
+ vggt_model = get_vggt_model()
329
+
330
+ with torch.no_grad():
331
+ with torch.cuda.amp.autocast(dtype=das.dtype):
332
+ video_vggt = video_vggt.unsqueeze(0) # [1, T, C, H, W]
333
+ aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
334
+
335
+ extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
336
+
337
+ cam_motion.set_intr(intr)
338
+ cam_motion.set_extr(extr)
339
+
340
  if camera_motion:
341
  poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
342
+ pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
343
+ pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses) # [T, N, 3]
344
  print("Camera motion applied")
345
+
346
+ tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, None)
347
+ print('Export tracking video via cotracker')
 
348
 
349
  output_path = das.apply_tracking(
350
  video_tensor=video_tensor,
351
+ fps=fps, # use the fps returned by load_media
352
  tracking_tensor=tracking_tensor,
353
  img_cond_tensor=repaint_img_tensor,
354
  prompt=prompt,
355
  checkpoint_path=DEFAULT_MODEL_PATH
356
  )
357
 
358
+ return tracking_path, output_path
359
  except Exception as e:
360
  import traceback
361
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
362
+ return None, None
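The VGGT branch above resizes the clip to a 518-pixel width, snaps the height to a multiple of 14 (the patch size implied by the rounding), and center-crops anything taller than 518 before calling the aggregator. A minimal sketch of that preprocessing as a standalone helper, assuming the same [T, C, H, W] float tensor in [0, 1]; the function name is hypothetical and not part of the commit:

import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode

def prepare_vggt_input(video_tensor: torch.Tensor) -> torch.Tensor:
    """video_tensor: [T, C, H, W] floats in [0, 1]; returns [1, T, C, H', 518] for the aggregator."""
    t, c, h, w = video_tensor.shape
    new_width = 518
    new_height = round(h * (new_width / w) / 14) * 14            # snap height to a multiple of 14
    resize = transforms.Resize((new_height, new_width), interpolation=InterpolationMode.BICUBIC)
    video = resize(video_tensor)
    if new_height > 518:                                          # center-crop overly tall frames
        start_y = (new_height - 518) // 2
        video = video[:, :, start_y:start_y + 518, :]
    return video.unsqueeze(0)                                     # add the batch dimension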
 
363
 
364
  def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
365
  """Process object manipulation task"""
 
367
  # Save uploaded files
368
  input_image_path = save_uploaded_file(source)
369
  if input_image_path is None:
370
+ return None, None
371
 
372
  object_mask_path = save_uploaded_file(object_mask)
373
  if object_mask_path is None:
374
  print("Object mask not provided")
375
+ return None, None
 
376
 
377
  das = get_das_pipeline()
378
  video_tensor, fps, is_video = load_media(input_image_path)
379
+ das.fps = fps # set das.fps to the fps returned by load_media
380
+
381
+ if not is_video:
382
  tracking_method = "moge"
383
+ print("Image input detected, switching to MoGe")
384
 
 
385
  mask_image = Image.open(object_mask_path).convert('L')
386
  mask_image = transforms.Resize((480, 720))(mask_image)
387
  mask = torch.from_numpy(np.array(mask_image) > 127)
 
389
  motion_generator = ObjectMotionGenerator(device=das.device)
390
  repaint_img_tensor = None
391
  tracking_tensor = None
392
+
393
  if tracking_method == "moge":
394
  moge = get_moge_model()
395
 
 
396
  infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
397
  H, W = infer_result["points"].shape[0:2]
398
  pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
 
409
  poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
410
  pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
411
 
 
412
  cam_motion = CameraMotionGenerator(None)
413
  cam_motion.set_intr(infer_result["intrinsics"])
414
  pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
 
419
  )
420
  print('Export tracking video via MoGe')
421
  else:
422
+ # Use CoTracker
423
+ pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
424
 
425
+ t, c, h, w = video_tensor.shape
426
+ new_width = 518
427
+ new_height = round(h * (new_width / w) / 14) * 14
428
+ resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
429
+ video_vggt = resize_transform(video_tensor) # [T, C, H, W]
430
 
431
+ if new_height > 518:
432
+ start_y = (new_height - 518) // 2
433
+ video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
434
+
435
+ vggt_model = get_vggt_model()
436
+
437
+ with torch.no_grad():
438
+ with torch.cuda.amp.autocast(dtype=das.dtype):
439
+ video_vggt = video_vggt.unsqueeze(0) # [1, T, C, H, W]
440
+ aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
441
+
442
+ extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
443
 
444
  pred_tracks = motion_generator.apply_motion(
445
  pred_tracks=pred_tracks.squeeze(),
 
447
  motion_type=object_motion,
448
  distance=50,
449
  num_frames=49,
450
+ tracking_method="cotracker"
451
+ )
452
  print(f"Object motion '{object_motion}' applied using provided mask")
453
 
454
+ tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), None)
455
+ print('Export tracking video via cotracker')
 
 
456
 
457
  output_path = das.apply_tracking(
458
  video_tensor=video_tensor,
459
+ fps=fps, # use the fps returned by load_media
460
  tracking_tensor=tracking_tensor,
461
  img_cond_tensor=repaint_img_tensor,
462
  prompt=prompt,
463
  checkpoint_path=DEFAULT_MODEL_PATH
464
  )
465
 
466
+ return tracking_path, output_path
467
  except Exception as e:
468
  import traceback
469
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
470
+ return None, None
 
471
 
472
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
473
  """Process mesh animation task"""
 
475
  # Save uploaded files
476
  input_video_path = save_uploaded_file(source)
477
  if input_video_path is None:
478
+ return None, None
479
 
480
  tracking_video_path = save_uploaded_file(tracking_video)
481
  if tracking_video_path is None:
482
+ return None, None
 
483
 
484
  das = get_das_pipeline()
485
  video_tensor, fps, is_video = load_media(input_video_path)
486
+ das.fps = fps # set das.fps to the fps returned by load_media
487
+
488
  tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)
489
  repaint_img_tensor = None
490
  if ma_repaint_image is not None:
 
502
 
503
  output_path = das.apply_tracking(
504
  video_tensor=video_tensor,
505
+ fps=fps, # use the fps returned by load_media
506
  tracking_tensor=tracking_tensor,
507
  img_cond_tensor=repaint_img_tensor,
508
  prompt=prompt,
509
  checkpoint_path=DEFAULT_MODEL_PATH
510
  )
511
 
512
+ return tracking_video_path, output_path
513
  except Exception as e:
514
  import traceback
515
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
516
+ return None, None
517
 
518
  # Create Gradio interface with updated layout
519
  with gr.Blocks(title="Diffusion as Shader") as demo:
 
526
 
527
  with right_column:
528
  output_video = gr.Video(label="Generated Video")
529
+ tracking_video = gr.Video(label="Tracking Video")
530
 
531
  with left_column:
532
  source = gr.File(label="Source", file_types=["image", "video"])
 
562
  source, common_prompt,
563
  mt_repaint_option, mt_repaint_image
564
  ],
565
+ outputs=[tracking_video, output_video]
566
  )
567
 
568
  # Camera Control tab
 
680
 
681
  cc_tracking_method = gr.Radio(
682
  label="Tracking Method",
683
+ choices=["moge", "cotracker"],
684
+ value="cotracker"
685
  )
686
 
687
  # Add run button for Camera Control tab
 
694
  source, common_prompt,
695
  cc_camera_motion, cc_tracking_method
696
  ],
697
+ outputs=[tracking_video, output_video]
698
  )
699
 
700
  # Object Manipulation tab
 
712
  )
713
  om_tracking_method = gr.Radio(
714
  label="Tracking Method",
715
+ choices=["moge", "cotracker"],
716
+ value="cotracker"
717
  )
718
 
719
  # Add run button for Object Manipulation tab
 
726
  source, common_prompt,
727
  om_object_motion, om_object_mask, om_tracking_method
728
  ],
729
+ outputs=[tracking_video, output_video]
730
  )
731
 
732
  # Animating meshes to video tab
 
766
  source, common_prompt,
767
  ma_tracking_video, ma_repaint_option, ma_repaint_image
768
  ],
769
+ outputs=[tracking_video, output_video]
770
  )
771
 
772
  # Launch interface
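The UI changes above add a second gr.Video output and switch every handler to returning a (tracking_path, output_path) pair, matched positionally against outputs=[tracking_video, output_video]. A self-contained sketch of that wiring with a stub handler (the stub and its paths are placeholders, not the real pipeline):

import gradio as gr

def run_stub(source_file):
    # Stand-in for process_motion_transfer and friends:
    # first value feeds the tracking preview, second feeds the generated video.
    return "outputs/tracking_video_cotracker.mp4", "outputs/result.mp4"

with gr.Blocks(title="Diffusion as Shader") as demo:
    src = gr.File(label="Source", file_types=["image", "video"])
    out = gr.Video(label="Generated Video")
    trk = gr.Video(label="Tracking Video")
    gr.Button("Run").click(fn=run_stub, inputs=[src], outputs=[trk, out])

demo.launch()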
demo.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
5
  project_root = os.path.dirname(os.path.abspath(__file__))
6
  try:
7
  sys.path.append(os.path.join(project_root, "submodules/MoGe"))
 
8
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
  except:
10
  print("Warning: MoGe not found, motion transfer will not be applied")
@@ -18,6 +19,8 @@ from diffusers.utils import load_image, load_video
18
 
19
  from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
20
  from submodules.MoGe.moge.model import MoGeModel
 
 
21
 
22
  def load_media(media_path, max_frames=49, transform=None):
23
  """Load video or image frames and convert to tensor
@@ -28,7 +31,7 @@ def load_media(media_path, max_frames=49, transform=None):
28
  transform (callable): Transform to apply to frames
29
 
30
  Returns:
31
- Tuple[torch.Tensor, float]: Video tensor [T,C,H,W] and FPS
32
  """
33
  if transform is None:
34
  transform = transforms.Compose([
@@ -41,22 +44,52 @@ def load_media(media_path, max_frames=49, transform=None):
41
  is_video = ext in ['.mp4', '.avi', '.mov']
42
 
43
  if is_video:
44
- frames = load_video(media_path)
45
- fps = len(frames) / VideoFileClip(media_path).duration
46
  else:
47
  # Handle image as single frame
48
  image = load_image(media_path)
49
  frames = [image]
50
  fps = 8 # Default fps for images
51
-
52
- # Ensure we have exactly max_frames
53
- if len(frames) > max_frames:
54
- frames = frames[:max_frames]
55
- elif len(frames) < max_frames:
56
- last_frame = frames[-1]
57
  while len(frames) < max_frames:
58
- frames.append(last_frame.copy())
59
-
60
  # Convert frames to tensor
61
  video_tensor = torch.stack([transform(frame) for frame in frames])
62
 
@@ -77,8 +110,8 @@ if __name__ == "__main__":
77
  help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
78
  parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
79
  parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
80
- parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge'],
81
- help='Tracking method to use (spatracker or moge)')
82
  args = parser.parse_args()
83
 
84
  # Load input video/image
@@ -89,6 +122,7 @@ if __name__ == "__main__":
89
 
90
  # Initialize pipeline
91
  das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
 
92
  if args.tracking_method == "moge" and args.tracking_path is None:
93
  moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
94
 
@@ -153,7 +187,7 @@ if __name__ == "__main__":
153
  poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
154
  # change pred_tracks into screen coordinate
155
  pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
156
- pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
157
  _, tracking_tensor = das.visualize_tracking_moge(
158
  pred_tracks.cpu().numpy(),
159
  infer_result["mask"].cpu().numpy()
@@ -161,13 +195,44 @@ if __name__ == "__main__":
161
  print('export tracking video via MoGe.')
162
 
163
  else:
164
- # Generate tracking points
165
- pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
166
 
167
  # Apply camera motion if specified
168
  if args.camera_motion:
169
  poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
170
- pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
 
171
  print("Camera motion applied")
172
 
173
  # Apply object motion if specified
@@ -184,7 +249,7 @@ if __name__ == "__main__":
184
  motion_generator = ObjectMotionGenerator(device=das.device)
185
 
186
  pred_tracks = motion_generator.apply_motion(
187
- pred_tracks=pred_tracks.squeeze(),
188
  mask=mask,
189
  motion_type=args.object_motion,
190
  distance=50,
@@ -193,12 +258,14 @@ if __name__ == "__main__":
193
  ).unsqueeze(0)
194
  print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")
195
 
196
- # Generate tracking tensor from modified tracks
197
- _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
 
 
198
 
199
  das.apply_tracking(
200
  video_tensor=video_tensor,
201
- fps=8,
202
  tracking_tensor=tracking_tensor,
203
  img_cond_tensor=repaint_img_tensor,
204
  prompt=args.prompt,
 
5
  project_root = os.path.dirname(os.path.abspath(__file__))
6
  try:
7
  sys.path.append(os.path.join(project_root, "submodules/MoGe"))
8
+ sys.path.append(os.path.join(project_root, "submodules/vggt"))
9
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
10
  except:
11
  print("Warning: MoGe not found, motion transfer will not be applied")
 
19
 
20
  from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
21
  from submodules.MoGe.moge.model import MoGeModel
22
+ from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
23
+ from submodules.vggt.vggt.models.vggt import VGGT
24
 
25
  def load_media(media_path, max_frames=49, transform=None):
26
  """Load video or image frames and convert to tensor
 
31
  transform (callable): Transform to apply to frames
32
 
33
  Returns:
34
+ Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and is_video flag
35
  """
36
  if transform is None:
37
  transform = transforms.Compose([
 
44
  is_video = ext in ['.mp4', '.avi', '.mov']
45
 
46
  if is_video:
47
+ # Load video file info
48
+ video_clip = VideoFileClip(media_path)
49
+ duration = video_clip.duration
50
+ original_fps = video_clip.fps
51
+
52
+ # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
53
+ if duration > 6.0:
54
+ sampling_fps = 8 # 8 frames per second
55
+ frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
56
+ fps = sampling_fps
57
+ # Cases 2 and 3: Video shorter than 6 seconds
58
+ else:
59
+ # Load all frames
60
+ frames = load_video(media_path)
61
+
62
+ # Case 2: Total frames less than max_frames, need interpolation
63
+ if len(frames) < max_frames:
64
+ fps = len(frames) / duration # Keep original fps
65
+
66
+ # Evenly interpolate to max_frames
67
+ indices = np.linspace(0, len(frames) - 1, max_frames)
68
+ new_frames = []
69
+ for i in indices:
70
+ idx = int(i)
71
+ new_frames.append(frames[idx])
72
+ frames = new_frames
73
+ # Case 3: Total frames more than max_frames but video less than 6 seconds
74
+ else:
75
+ # Evenly sample to max_frames
76
+ indices = np.linspace(0, len(frames) - 1, max_frames)
77
+ new_frames = []
78
+ for i in indices:
79
+ idx = int(i)
80
+ new_frames.append(frames[idx])
81
+ frames = new_frames
82
+ fps = max_frames / duration # New fps to maintain duration
83
  else:
84
  # Handle image as single frame
85
  image = load_image(media_path)
86
  frames = [image]
87
  fps = 8 # Default fps for images
88
+
89
+ # Duplicate frame to max_frames
 
 
 
 
90
  while len(frames) < max_frames:
91
+ frames.append(frames[0].copy())
92
+
93
  # Convert frames to tensor
94
  video_tensor = torch.stack([transform(frame) for frame in frames])
95
 
 
110
  help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
111
  parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
112
  parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
113
+ parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge', 'cotracker'],
114
+ help='Tracking method to use (spatracker, cotracker or moge)')
115
  args = parser.parse_args()
116
 
117
  # Load input video/image
 
122
 
123
  # Initialize pipeline
124
  das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
125
+ das.fps = fps
126
  if args.tracking_method == "moge" and args.tracking_path is None:
127
  moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
128
 
 
187
  poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
188
  # change pred_tracks into screen coordinate
189
  pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
190
+ pred_tracks = cam_motion.w2s_moge(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
191
  _, tracking_tensor = das.visualize_tracking_moge(
192
  pred_tracks.cpu().numpy(),
193
  infer_result["mask"].cpu().numpy()
 
195
  print('export tracking video via MoGe.')
196
 
197
  else:
198
+
199
+ if args.tracking_method == "cotracker":
200
+ pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor) # T N 3, T N
201
+ else:
202
+ pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor) # T N 3, T N, B N
203
+
204
+ # Preprocess video tensor to match VGGT requirements
205
+ t, c, h, w = video_tensor.shape
206
+ new_width = 518
207
+ new_height = round(h * (new_width / w) / 14) * 14
208
+ resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
209
+ video_vggt = resize_transform(video_tensor) # [T, C, H, W]
210
+
211
+ if new_height > 518:
212
+ start_y = (new_height - 518) // 2
213
+ video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
214
+
215
+ # Get extrinsic and intrinsic matrices
216
+ vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
217
+
218
+ with torch.no_grad():
219
+ with torch.cuda.amp.autocast(dtype=das.dtype):
220
+
221
+ video_vggt = video_vggt.unsqueeze(0) # [1, T, C, H, W]
222
+ aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
223
+
224
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
225
+ extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
226
+ depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_vggt, ps_idx)
227
+
228
+ cam_motion.set_intr(intr)
229
+ cam_motion.set_extr(extr)
230
 
231
  # Apply camera motion if specified
232
  if args.camera_motion:
233
  poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
234
+ pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
235
+ pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses) # [T, N, 3]
236
  print("Camera motion applied")
237
 
238
  # Apply object motion if specified
 
249
  motion_generator = ObjectMotionGenerator(device=das.device)
250
 
251
  pred_tracks = motion_generator.apply_motion(
252
+ pred_tracks=pred_tracks,
253
  mask=mask,
254
  motion_type=args.object_motion,
255
  distance=50,
 
258
  ).unsqueeze(0)
259
  print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")
260
 
261
+ if args.tracking_method == "cotracker":
262
+ _, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
263
+ else:
264
+ _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
265
 
266
  das.apply_tracking(
267
  video_tensor=video_tensor,
268
+ fps=fps,
269
  tracking_tensor=tracking_tensor,
270
  img_cond_tensor=repaint_img_tensor,
271
  prompt=args.prompt,
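Both demo.py and app.py now lift the CoTracker points into world space with cam_motion.s2w_vggt and reproject them under the requested poses with cam_motion.w2s_vggt; those methods are not shown in this diff. The following is only a schematic pinhole unprojection under assumed conventions ([T, 4, 4] world-to-camera extrinsics, [T, 3, 3] intrinsics, row-vector points), not the repository's implementation:

import torch

def screen_to_world(uvd, extr, intr):
    # Schematic unprojection, NOT the repo's s2w_vggt.
    # uvd:  [T, N, 3] pixel coords (u, v) plus depth
    # extr: [T, 4, 4] world-to-camera matrices (assumed); intr: [T, 3, 3]
    T, N, _ = uvd.shape
    ones = torch.ones(T, N, 1, device=uvd.device, dtype=uvd.dtype)
    uv1 = torch.cat([uvd[..., :2], ones], dim=-1)                      # homogeneous pixels
    rays = torch.bmm(uv1, torch.linalg.inv(intr).transpose(1, 2))      # camera rays with z = 1
    pts_cam = rays * uvd[..., 2:3]                                     # scale rays by depth
    pts_cam_h = torch.cat([pts_cam, ones], dim=-1)                     # [T, N, 4]
    cam2world = torch.linalg.inv(extr)
    return torch.bmm(pts_cam_h, cam2world.transpose(1, 2))[..., :3]    # world-space xyz

Reprojection into the new views then mirrors the w2s routine kept in models/pipelines.py below.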
models/pipelines.py CHANGED
@@ -22,9 +22,9 @@ from models.spatracker.utils.visualizer import Visualizer
22
  from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
23
 
24
  from submodules.MoGe.moge.model import MoGeModel
 
25
  from image_gen_aux import DepthPreprocessor
26
  from moviepy.editor import ImageSequenceClip
27
- import spaces
28
 
29
  class DiffusionAsShaderPipeline:
30
  def __init__(self, gpu_id=0, output_dir='outputs'):
@@ -45,6 +45,7 @@ class DiffusionAsShaderPipeline:
45
  # device
46
  self.device = f"cuda:{gpu_id}"
47
  torch.cuda.set_device(gpu_id)
 
48
 
49
  # files
50
  self.output_dir = output_dir
@@ -56,7 +57,6 @@ class DiffusionAsShaderPipeline:
56
  transforms.ToTensor()
57
  ])
58
 
59
- @spaces.GPU(duration=240)
60
  @torch.no_grad()
61
  def _infer(
62
  self,
@@ -65,7 +65,7 @@ class DiffusionAsShaderPipeline:
65
  tracking_tensor: torch.Tensor = None,
66
  image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
67
  output_path: str = "./output.mp4",
68
- num_inference_steps: int = 50,
69
  guidance_scale: float = 6.0,
70
  num_videos_per_prompt: int = 1,
71
  dtype: torch.dtype = torch.bfloat16,
@@ -114,6 +114,8 @@ class DiffusionAsShaderPipeline:
114
  pipe.text_encoder.eval()
115
  pipe.vae.eval()
116
 
 
 
117
  # Process tracking tensor
118
  tracking_maps = tracking_tensor.float() # [T, C, H, W]
119
  tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
@@ -167,60 +169,9 @@ class DiffusionAsShaderPipeline:
167
 
168
  def _set_camera_motion(self, camera_motion):
169
  self.camera_motion = camera_motion
170
-
171
- def _get_intr(self, fov, H=480, W=720):
172
- fov_rad = math.radians(fov)
173
- focal_length = (W / 2) / math.tan(fov_rad / 2)
174
-
175
- cx = W / 2
176
- cy = H / 2
177
-
178
- intr = torch.tensor([
179
- [focal_length, 0, cx],
180
- [0, focal_length, cy],
181
- [0, 0, 1]
182
- ], dtype=torch.float32)
183
-
184
- return intr
185
-
186
- @spaces.GPU
187
- def _apply_poses(self, pts, intr, poses):
188
- """
189
- Args:
190
- pts (torch.Tensor): pointclouds coordinates [T, N, 3]
191
- intr (torch.Tensor): camera intrinsics [T, 3, 3]
192
- poses (numpy.ndarray): camera poses [T, 4, 4]
193
- """
194
- poses = torch.from_numpy(poses).float().to(self.device)
195
-
196
- T, N, _ = pts.shape
197
- ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
198
- pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
199
- pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
200
- pts_cam[:,:, :3] /= pts[:, :, 2:3]
201
-
202
- # to homogeneous
203
- pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
204
-
205
- if poses.shape[0] == 1:
206
- poses = poses.repeat(T, 1, 1)
207
- elif poses.shape[0] != T:
208
- raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
209
-
210
- pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
211
-
212
- pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
213
- pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
214
-
215
- return pts_proj
216
-
217
- def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49):
218
- intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device)
219
- tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0)
220
- return tracking_pts
221
 
222
  ##============= SpatialTracker =============##
223
- @spaces.GPU
224
  def generate_tracking_spatracker(self, video_tensor, density=70):
225
  """Generate tracking video
226
 
@@ -233,7 +184,7 @@ class DiffusionAsShaderPipeline:
233
  print("Loading tracking models...")
234
  # Load tracking model
235
  tracker = SpaTrackerPredictor(
236
- checkpoint=os.path.join(project_root, 'checkpoints/spatracker/spaT_final.pth'),
237
  interp_shape=(384, 576),
238
  seq_length=12
239
  ).to(self.device)
@@ -268,14 +219,13 @@ class DiffusionAsShaderPipeline:
268
  progressive_tracking=False
269
  )
270
 
271
- return pred_tracks, pred_visibility, T_Firsts
272
 
273
  finally:
274
  # Clean up GPU memory
275
  del tracker, self.depth_preprocessor
276
  torch.cuda.empty_cache()
277
 
278
- @spaces.GPU
279
  def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
280
  video = video.unsqueeze(0).to(self.device)
281
  vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
@@ -365,7 +315,6 @@ class DiffusionAsShaderPipeline:
365
  outline=tuple(color),
366
  )
367
 
368
- @spaces.GPU
369
  def visualize_tracking_moge(self, points, mask, save_tracking=True):
370
  """Visualize tracking results from MoGe model
371
 
@@ -399,8 +348,6 @@ class DiffusionAsShaderPipeline:
399
  normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
400
  colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
401
  colors = colors.astype(np.uint8)
402
- # colors = colors * mask[..., None]
403
- # points = points * mask[None, :, :, None]
404
 
405
  points = points.reshape(T, -1, 3)
406
  colors = colors.reshape(-1, 3)
@@ -408,7 +355,7 @@ class DiffusionAsShaderPipeline:
408
  # Initialize list to store frames
409
  frames = []
410
 
411
- for i, pts_i in enumerate(tqdm(points)):
412
  pixels, depths = pts_i[..., :2], pts_i[..., 2]
413
  pixels[..., 0] = pixels[..., 0] * W
414
  pixels[..., 1] = pixels[..., 1] * H
@@ -451,8 +398,178 @@ class DiffusionAsShaderPipeline:
451
  tracking_path = None
452
 
453
  return tracking_path, tracking_video
454
 
455
- @spaces.GPU(duration=240)
456
  def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
457
  """Generate final video with motion transfer
458
 
@@ -478,7 +595,7 @@ class DiffusionAsShaderPipeline:
478
  tracking_tensor=tracking_tensor,
479
  image_tensor=img_cond_tensor,
480
  output_path=final_output,
481
- num_inference_steps=50,
482
  guidance_scale=6.0,
483
  dtype=torch.bfloat16,
484
  fps=self.fps
@@ -493,7 +610,6 @@ class DiffusionAsShaderPipeline:
493
  """
494
  self.object_motion = motion_type
495
 
496
- @spaces.GPU(duration=120)
497
  class FirstFrameRepainter:
498
  def __init__(self, gpu_id=0, output_dir='outputs'):
499
  """Initialize FirstFrameRepainter
@@ -506,8 +622,7 @@ class FirstFrameRepainter:
506
  self.output_dir = output_dir
507
  self.max_depth = 65.0
508
  os.makedirs(output_dir, exist_ok=True)
509
-
510
- @spaces.GPU(duration=120)
511
  def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
512
  """Repaint first frame using Flux
513
 
@@ -599,48 +714,158 @@ class CameraMotionGenerator:
599
  fx = fy = (W / 2) / math.tan(fov_rad / 2)
600
 
601
  self.intr[0, 0] = fx
602
- self.intr[1, 1] = fy
 
 
603
 
604
- def _apply_poses(self, pts, poses):
605
  """
 
 
606
  Args:
607
- pts (torch.Tensor): pointclouds coordinates [T, N, 3]
608
- intr (torch.Tensor): camera intrinsics [T, 3, 3]
609
- poses (numpy.ndarray): camera poses [T, 4, 4]
 
 
 
610
  """
611
- if isinstance(poses, np.ndarray):
612
- poses = torch.from_numpy(poses)
613
-
614
- intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
615
- T, N, _ = pts.shape
616
- ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
617
- pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
618
- pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
619
- pts_cam[:,:, :3] *= pts[:, :, 2:3]
620
-
621
- # to homogeneous
622
- pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
 
 
 
 
 
623
 
624
- if poses.shape[0] == 1:
625
- poses = poses.repeat(T, 1, 1)
626
- elif poses.shape[0] != T:
627
- raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
628
 
629
- poses = poses.to(torch.float).to(self.device)
630
- pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
631
- pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
632
- pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
633
 
634
- return pts_proj
635
 
636
- def w2s(self, pts, poses):
637
  if isinstance(poses, np.ndarray):
638
  poses = torch.from_numpy(poses)
639
  assert poses.shape[0] == self.frame_num
640
  poses = poses.to(torch.float32).to(self.device)
641
  T, N, _ = pts.shape # (T, N, 3)
642
  intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
643
- # Step 1: expand the points to (T, N, 4) by appending a 1 in the last dim (homogeneous coordinates)
644
  ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
645
  points_world_h = torch.cat([pts, ones], dim=-1)
646
  points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
@@ -649,22 +874,21 @@ class CameraMotionGenerator:
649
  points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
650
 
651
  uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
652
-
653
- # Step 5: extract the depth (Z) and concatenate
654
  depth = points_camera[:, :, 2:3] # (T, N, 1)
655
  uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)
656
 
657
- return uvd # screen coordinates + depth (T, N, 3)
658
-
659
- def apply_motion_on_pts(self, pts, camera_motion):
660
- tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
661
- return tracking_pts
662
 
663
  def set_intr(self, K):
664
  if isinstance(K, np.ndarray):
665
  K = torch.from_numpy(K)
666
  self.intr = K.to(self.device)
667
 
 
 
 
 
 
668
  def rot_poses(self, angle, axis='y'):
669
  """Generate a single rotation matrix
670
 
@@ -783,26 +1007,6 @@ class CameraMotionGenerator:
783
  camera_poses = np.concatenate(cam_poses, axis=0)
784
  return torch.from_numpy(camera_poses).to(self.device)
785
 
786
- def rot(self, pts, angle, axis):
787
- """
788
- pts: torch.Tensor, (T, N, 2)
789
- """
790
- rot_mats = self.rot_poses(angle, axis)
791
- pts = self.apply_motion_on_pts(pts, rot_mats)
792
- return pts
793
-
794
- def trans(self, pts, dx, dy, dz):
795
- if pts.shape[-1] != 3:
796
- raise ValueError("points should be in the 3d coordinate.")
797
- trans_mats = self.trans_poses(dx, dy, dz)
798
- pts = self.apply_motion_on_pts(pts, trans_mats)
799
- return pts
800
-
801
- def spiral(self, pts, radius):
802
- spiral_poses = self.spiral_poses(radius)
803
- pts = self.apply_motion_on_pts(pts, spiral_poses)
804
- return pts
805
-
806
  def get_default_motion(self):
807
  """Parse motion parameters and generate corresponding motion matrices
808
 
@@ -820,6 +1024,7 @@ class CameraMotionGenerator:
820
  - if not specified, defaults to 0-49
821
  - frames after end_frame will maintain the final transformation
822
  - for combined transformations, they are applied in sequence
 
823
 
824
  Returns:
825
  torch.Tensor: Motion matrices [num_frames, 4, 4]
 
22
  from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
23
 
24
  from submodules.MoGe.moge.model import MoGeModel
25
+
26
  from image_gen_aux import DepthPreprocessor
27
  from moviepy.editor import ImageSequenceClip
 
28
 
29
  class DiffusionAsShaderPipeline:
30
  def __init__(self, gpu_id=0, output_dir='outputs'):
 
45
  # device
46
  self.device = f"cuda:{gpu_id}"
47
  torch.cuda.set_device(gpu_id)
48
+ self.dtype = torch.bfloat16
49
 
50
  # files
51
  self.output_dir = output_dir
 
57
  transforms.ToTensor()
58
  ])
59
 
 
60
  @torch.no_grad()
61
  def _infer(
62
  self,
 
65
  tracking_tensor: torch.Tensor = None,
66
  image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
67
  output_path: str = "./output.mp4",
68
+ num_inference_steps: int = 25,
69
  guidance_scale: float = 6.0,
70
  num_videos_per_prompt: int = 1,
71
  dtype: torch.dtype = torch.bfloat16,
 
114
  pipe.text_encoder.eval()
115
  pipe.vae.eval()
116
 
117
+ self.dtype = dtype
118
+
119
  # Process tracking tensor
120
  tracking_maps = tracking_tensor.float() # [T, C, H, W]
121
  tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
 
169
 
170
  def _set_camera_motion(self, camera_motion):
171
  self.camera_motion = camera_motion
172
 
173
  ##============= SpatialTracker =============##
174
+
175
  def generate_tracking_spatracker(self, video_tensor, density=70):
176
  """Generate tracking video
177
 
 
184
  print("Loading tracking models...")
185
  # Load tracking model
186
  tracker = SpaTrackerPredictor(
187
+ checkpoint=os.path.join(project_root, 'checkpoints/spaT_final.pth'),
188
  interp_shape=(384, 576),
189
  seq_length=12
190
  ).to(self.device)
 
219
  progressive_tracking=False
220
  )
221
 
222
+ return pred_tracks.squeeze(0), pred_visibility.squeeze(0), T_Firsts
223
 
224
  finally:
225
  # Clean up GPU memory
226
  del tracker, self.depth_preprocessor
227
  torch.cuda.empty_cache()
228
 
 
229
  def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
230
  video = video.unsqueeze(0).to(self.device)
231
  vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
 
315
  outline=tuple(color),
316
  )
317
 
 
318
  def visualize_tracking_moge(self, points, mask, save_tracking=True):
319
  """Visualize tracking results from MoGe model
320
 
 
348
  normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
349
  colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
350
  colors = colors.astype(np.uint8)
 
 
351
 
352
  points = points.reshape(T, -1, 3)
353
  colors = colors.reshape(-1, 3)
 
355
  # Initialize list to store frames
356
  frames = []
357
 
358
+ for i, pts_i in enumerate(tqdm(points, desc="rendering frames")):
359
  pixels, depths = pts_i[..., :2], pts_i[..., 2]
360
  pixels[..., 0] = pixels[..., 0] * W
361
  pixels[..., 1] = pixels[..., 1] * H
 
398
  tracking_path = None
399
 
400
  return tracking_path, tracking_video
401
+
402
+
403
+ ##============= CoTracker =============##
404
+
405
+ def generate_tracking_cotracker(self, video_tensor, density=70):
406
+ """Generate tracking video
407
+
408
+ Args:
409
+ video_tensor (torch.Tensor): Input video tensor
410
+
411
+ Returns:
412
+ tuple: (pred_tracks, pred_visibility)
413
+ - pred_tracks (torch.Tensor): Tracking points with depth [T, N, 3]
414
+ - pred_visibility (torch.Tensor): Visibility mask [T, N, 1]
415
+ """
416
+ # Generate tracking points
417
+ cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(self.device)
418
+
419
+ # Load depth model
420
+ if not hasattr(self, 'depth_preprocessor') or self.depth_preprocessor is None:
421
+ self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
422
+ self.depth_preprocessor.to(self.device)
423
+
424
+ try:
425
+ video = video_tensor.unsqueeze(0).to(self.device)
426
+
427
+ # Process all frames to get depth maps
428
+ video_depths = []
429
+ for i in tqdm(range(video_tensor.shape[0]), desc="estimating depth"):
430
+ frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
431
+ depth = self.depth_preprocessor(Image.fromarray(frame))[0]
432
+ depth_tensor = transforms.ToTensor()(depth) # [1, H, W]
433
+ video_depths.append(depth_tensor)
434
+
435
+ video_depth = torch.stack(video_depths, dim=0).to(self.device) # [T, 1, H, W]
436
+
437
+ # Get tracking points and visibility
438
+ print("tracking...")
439
+ pred_tracks, pred_visibility = cotracker(video, grid_size=density) # B T N 2, B T N 1
440
+
441
+ # Extract dimensions
442
+ B, T, N, _ = pred_tracks.shape
443
+ H, W = video_depth.shape[2], video_depth.shape[3]
444
+
445
+ # Create output tensor with depth
446
+ pred_tracks_with_depth = torch.zeros((B, T, N, 3), device=self.device)
447
+ pred_tracks_with_depth[:, :, :, :2] = pred_tracks # Copy x,y coordinates
448
+
449
+ # Vectorized approach to get depths for all points
450
+ # Reshape pred_tracks to process all batches and frames at once
451
+ flat_tracks = pred_tracks.reshape(B*T, N, 2)
452
+
453
+ # Clamp coordinates to valid image bounds
454
+ x_coords = flat_tracks[:, :, 0].clamp(0, W-1).long() # [B*T, N]
455
+ y_coords = flat_tracks[:, :, 1].clamp(0, H-1).long() # [B*T, N]
456
+
457
+ # Look up the depth of each tracked point from its frame's depth map
458
+ # For each point in the flattened batch, get its depth from the corresponding frame
459
+ depths = torch.zeros((B*T, N), device=self.device)
460
+ for bt in range(B*T):
461
+ t = bt % T # Time index
462
+ depths[bt] = video_depth[t, 0, y_coords[bt], x_coords[bt]]
463
+
464
+ # Reshape depths back to [B, T, N] and assign to output tensor
465
+ pred_tracks_with_depth[:, :, :, 2] = depths.reshape(B, T, N)
466
+
467
+ return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
468
+
469
+ finally:
470
+ del cotracker
471
+ torch.cuda.empty_cache()
472
+
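+ # The returned tracks are in uvz format: CoTracker pixel coordinates plus a
+ # per-point depth sampled from the ZoeDepth map of the corresponding frame;
+ # this is the layout the visualization and camera-motion utilities below consume.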
+ def visualize_tracking_cotracker(self, points, vis_mask=None, save_tracking=True, point_wise=4, video_size=(480, 720)):
+ """Visualize tracking results from CoTracker
+
+ Args:
+ points (torch.Tensor): Points array of shape [T, N, 3]
+ vis_mask (torch.Tensor): Visibility mask of shape [T, N, 1]
+ save_tracking (bool): Whether to save the tracking video to disk
+ point_wise (int): Side length, in pixels, of each rendered point
+ video_size (tuple): Render size (height, width)
+
+ Returns:
+ tuple: (tracking_path, tracking_video)
+ """
+ # Move tensors to CPU and convert to numpy
+ if isinstance(points, torch.Tensor):
+ points = points.detach().cpu().numpy()
+
+ if vis_mask is not None and isinstance(vis_mask, torch.Tensor):
+ vis_mask = vis_mask.detach().cpu().numpy()
+ # Squeeze a trailing channel dimension if present
+ if vis_mask.ndim == 3 and vis_mask.shape[2] == 1:
+ vis_mask = vis_mask.squeeze(-1)
+
+ T, N, _ = points.shape
+ H, W = video_size
+
+ if vis_mask is None:
+ vis_mask = np.ones((T, N), dtype=bool)
+
+ colors = np.zeros((N, 3), dtype=np.uint8)
+
+ first_frame_pts = points[0]
+
+ u_min, u_max = 0, W
+ u_normalized = np.clip((first_frame_pts[:, 0] - u_min) / (u_max - u_min), 0, 1)
+ colors[:, 0] = (u_normalized * 255).astype(np.uint8)
+
+ v_min, v_max = 0, H
+ v_normalized = np.clip((first_frame_pts[:, 1] - v_min) / (v_max - v_min), 0, 1)
+ colors[:, 1] = (v_normalized * 255).astype(np.uint8)
+
+ z_values = first_frame_pts[:, 2]
+ if np.all(z_values == 0):
+ colors[:, 2] = np.random.randint(0, 256, N, dtype=np.uint8)
+ else:
+ inv_z = 1 / (z_values + 1e-10)
+ p2 = np.percentile(inv_z, 2)
+ p98 = np.percentile(inv_z, 98)
+ normalized_z = np.clip((inv_z - p2) / (p98 - p2 + 1e-10), 0, 1)
+ colors[:, 2] = (normalized_z * 255).astype(np.uint8)
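+ # Each track keeps a fixed color derived from its first-frame position: R encodes
+ # normalized u, G encodes normalized v, and B encodes percentile-normalized
+ # inverse depth, so a point's identity is easy to follow across frames.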
+
+ frames = []
+
+ for i in tqdm(range(T), desc="rendering frames"):
+ pts_i = points[i]
+
+ visibility = vis_mask[i]
+
+ pixels, depths = pts_i[visibility, :2], pts_i[visibility, 2]
+ pixels = pixels.astype(int)
+
+ in_frame = self.valid_mask(pixels, W, H)
+ pixels = pixels[in_frame]
+ depths = depths[in_frame]
+ frame_rgb = colors[visibility][in_frame]
+
+ img = Image.fromarray(np.zeros((H, W, 3), dtype=np.uint8), mode="RGB")
+
+ sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
+ sorted_rgb = frame_rgb[sort_index]
+
+ for j in range(sorted_pixels.shape[0]):
+ self.draw_rectangle(
+ img,
+ coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
+ side_length=point_wise,
+ color=sorted_rgb[j],
+ )
+
+ frames.append(np.array(img))
+
+ # Convert frames to video tensor in range [0,1]
+ tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
+
+ tracking_path = None
+ if save_tracking:
+ try:
+ tracking_path = os.path.join(self.output_dir, "tracking_video_cotracker.mp4")
+ # Convert back to uint8 for saving
+ uint8_frames = [frame.astype(np.uint8) for frame in frames]
+ clip = ImageSequenceClip(uint8_frames, fps=self.fps)
+ clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
+ print(f"Video saved to {tracking_path}")
+ except Exception as e:
+ print(f"Warning: Failed to save tracking video: {e}")
+ tracking_path = None
+
+ return tracking_path, tracking_video
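+ # Example sketch (illustrative names; assumes these helpers are methods of the
+ # pipeline object `das` used in app.py, with a video tensor [T, C, H, W] in [0, 1]):
+ #   tracks, visibility = das.generate_tracking_cotracker(video_tensor, density=70)
+ #   tracking_path, tracking_video = das.visualize_tracking_cotracker(tracks, visibility)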
+

  def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
  """Generate final video with motion transfer

  tracking_tensor=tracking_tensor,
  image_tensor=img_cond_tensor,
  output_path=final_output,
+ num_inference_steps=25, # fewer denoising steps trades a little quality for speed
  guidance_scale=6.0,
  dtype=torch.bfloat16,
  fps=self.fps

  """
  self.object_motion = motion_type

  class FirstFrameRepainter:
  def __init__(self, gpu_id=0, output_dir='outputs'):
  """Initialize FirstFrameRepainter

  self.output_dir = output_dir
  self.max_depth = 65.0
  os.makedirs(output_dir, exist_ok=True)
+

  def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
  """Repaint first frame using Flux

  fx = fy = (W / 2) / math.tan(fov_rad / 2)

  self.intr[0, 0] = fx
+ self.intr[1, 1] = fy
+
+ self.extr = torch.eye(4, device=device)
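+ # Default extrinsics are the identity, i.e. the reference camera starts at the
+ # world origin; set_extr() below can overwrite this with an estimated pose.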

+ def s2w_vggt(self, points, extrinsics, intrinsics):
  """
+ Transform points from pixel coordinates to world coordinates
+
  Args:
+ points: Point cloud data of shape [T, N, 3] in uvz format
+ extrinsics: Camera extrinsic matrices [B, T, 3, 4] or [T, 3, 4]
+ intrinsics: Camera intrinsic matrices [B, T, 3, 3] or [T, 3, 3]
+
+ Returns:
+ world_points: Point cloud in world coordinates [T, N, 3]
  """
+ if isinstance(points, torch.Tensor):
+ points = points.detach().cpu().numpy()
+
+ if isinstance(extrinsics, torch.Tensor):
+ extrinsics = extrinsics.detach().cpu().numpy()
+ # Handle batch dimension
+ if extrinsics.ndim == 4: # [B, T, 3, 4]
+ extrinsics = extrinsics[0] # Take first batch
+
+ if isinstance(intrinsics, torch.Tensor):
+ intrinsics = intrinsics.detach().cpu().numpy()
+ # Handle batch dimension
+ if intrinsics.ndim == 4: # [B, T, 3, 3]
+ intrinsics = intrinsics[0] # Take first batch
+
+ T, N, _ = points.shape
+ world_points = np.zeros_like(points)

+ # Extract uvz coordinates
+ uvz = points
+ valid_mask = uvz[..., 2] > 0

+ # Create homogeneous coordinates [u, v, 1]
+ uv_homogeneous = np.concatenate([uvz[..., :2], np.ones((T, N, 1))], axis=-1)
+
+ # Transform from pixel to camera coordinates
+ for i in range(T):
+ K = intrinsics[i]
+ K_inv = np.linalg.inv(K)
+
+ R = extrinsics[i, :, :3]
+ t = extrinsics[i, :, 3]
+
+ R_inv = np.linalg.inv(R)
+
+ valid_indices = np.where(valid_mask[i])[0]
+
+ if len(valid_indices) > 0:
+ valid_uv = uv_homogeneous[i, valid_indices]
+ valid_z = uvz[i, valid_indices, 2]
+
+ valid_xyz_camera = valid_uv @ K_inv.T
+ valid_xyz_camera = valid_xyz_camera * valid_z[:, np.newaxis]
+
+ # Transform from camera to world coordinates: X_world = R^-1 * (X_camera - t)
+ valid_world_points = (valid_xyz_camera - t) @ R_inv.T
+
+ world_points[i, valid_indices] = valid_world_points
+
+ return world_points

+ def w2s_vggt(self, world_points, extrinsics, intrinsics, poses=None):
+ """
+ Project points from world coordinates to camera view
+
+ Args:
+ world_points: Point cloud in world coordinates [T, N, 3]
+ extrinsics: Original camera extrinsic matrices [B, T, 3, 4] or [T, 3, 4]
+ intrinsics: Camera intrinsic matrices [B, T, 3, 3] or [T, 3, 3]
+ poses: Camera pose matrices [T, 4, 4]; if None, the first-frame extrinsics are reused for every frame
+
+ Returns:
+ camera_points: Point cloud in camera coordinates [T, N, 3] in uvz format
+ """
+ if isinstance(world_points, torch.Tensor):
+ world_points = world_points.detach().cpu().numpy()
+
+ if isinstance(extrinsics, torch.Tensor):
+ extrinsics = extrinsics.detach().cpu().numpy()
+ if extrinsics.ndim == 4:
+ extrinsics = extrinsics[0]
+
+ if isinstance(intrinsics, torch.Tensor):
+ intrinsics = intrinsics.detach().cpu().numpy()
+ if intrinsics.ndim == 4:
+ intrinsics = intrinsics[0]
+
+ T, N, _ = world_points.shape
+
+ # If no poses provided, use first frame extrinsics
+ if poses is None:
+ pose1 = np.eye(4)
+ pose1[:3, :3] = extrinsics[0, :, :3]
+ pose1[:3, 3] = extrinsics[0, :, 3]
+
+ camera_poses = np.tile(pose1[np.newaxis, :, :], (T, 1, 1))
+ else:
+ if isinstance(poses, torch.Tensor):
+ camera_poses = poses.cpu().numpy()
+ else:
+ camera_poses = poses
+
+ # Scale translation by 1/5
+ scaled_poses = camera_poses.copy()
+ scaled_poses[:, :3, 3] = camera_poses[:, :3, 3] / 5.0
+ camera_poses = scaled_poses
+
+ # Add homogeneous coordinates
+ ones = np.ones([T, N, 1])
+ world_points_hom = np.concatenate([world_points, ones], axis=-1)
+
+ # Transform points using batch matrix multiplication
+ pts_cam_hom = np.matmul(world_points_hom, np.transpose(camera_poses, (0, 2, 1)))
+ pts_cam = pts_cam_hom[..., :3]
+
+ # Extract depth information
+ depths = pts_cam[..., 2:3]
+ valid_mask = depths[..., 0] > 0
+
+ # Normalize coordinates
+ normalized_pts = pts_cam / (depths + 1e-10)
+
+ # Apply intrinsic matrix for projection
+ pts_pixel = np.matmul(normalized_pts, np.transpose(intrinsics, (0, 2, 1)))
+
+ # Extract pixel coordinates
+ u = pts_pixel[..., 0:1]
+ v = pts_pixel[..., 1:2]
+
+ # Set invalid points to zero
+ u[~valid_mask] = 0
+ v[~valid_mask] = 0
+ depths[~valid_mask] = 0
+
+ # Return points in uvz format
+ result = np.concatenate([u, v, depths], axis=-1)
+
+ return torch.from_numpy(result)
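+ # Sketch of the intended round trip (illustrative names; assumes these methods
+ # belong to the CameraMotionGenerator instance `cmg` and that tracks/extrinsics/
+ # intrinsics come from the CoTracker and VGGT steps above):
+ #   world_pts = cmg.s2w_vggt(tracks_uvz, extrinsics, intrinsics)
+ #   new_uvz = cmg.w2s_vggt(world_pts, extrinsics, intrinsics, poses=cmg.get_default_motion())
+ # which re-renders the same 3D points under the user-specified camera motion.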

+ def w2s_moge(self, pts, poses):
  if isinstance(poses, np.ndarray):
  poses = torch.from_numpy(poses)
  assert poses.shape[0] == self.frame_num
  poses = poses.to(torch.float32).to(self.device)
  T, N, _ = pts.shape # (T, N, 3)
  intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)

  ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
  points_world_h = torch.cat([pts, ones], dim=-1)
  points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))

  points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))

  uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]

  depth = points_camera[:, :, 2:3] # (T, N, 1)
  uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)

+ return uvd
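+ # Counterpart of w2s_vggt intended for the MoGe path: it projects world points
+ # with the fixed per-scene intrinsics in self.intr and the provided per-frame poses.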

  def set_intr(self, K):
  if isinstance(K, np.ndarray):
  K = torch.from_numpy(K)
  self.intr = K.to(self.device)

+ def set_extr(self, extr):
+ if isinstance(extr, np.ndarray):
+ extr = torch.from_numpy(extr)
+ self.extr = extr.to(self.device)
+
  def rot_poses(self, angle, axis='y'):
  """Generate a single rotation matrix

  camera_poses = np.concatenate(cam_poses, axis=0)
  return torch.from_numpy(camera_poses).to(self.device)

  def get_default_motion(self):
  """Parse motion parameters and generate corresponding motion matrices

  - if not specified, defaults to 0-49
  - frames after end_frame will maintain the final transformation
  - for combined transformations, they are applied in sequence
+ - moving left, up, and zooming out correspond to positive values in the video

  Returns:
  torch.Tensor: Motion matrices [num_frames, 4, 4]