EXCAI commited on
Commit
182f943
·
1 Parent(s): 5b2a969

deploy cotracker on cpu

Browse files
Files changed (3) hide show
  1. app.py +60 -6
  2. models/pipelines.py +1 -1
  3. submodules/vggt +1 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
  import gradio as gr
4
  import torch
5
  import argparse
@@ -8,7 +9,8 @@ import numpy as np
8
  import torchvision.transforms as transforms
9
  from moviepy.editor import VideoFileClip
10
  from diffusers.utils import load_image, load_video
11
- import spaces
 
12
 
13
  project_root = os.path.dirname(os.path.abspath(__file__))
14
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
@@ -247,7 +249,7 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
247
  print('Export tracking video via MoGe')
248
  else:
249
  # 使用 cotracker
250
- pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
251
  tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
252
  print('Export tracking video via cotracker')
253
 
@@ -312,8 +314,8 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
312
  )
313
  print('Export tracking video via MoGe')
314
  else:
315
- # 使用 cotracker
316
- pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
317
 
318
  t, c, h, w = video_tensor.shape
319
  new_width = 518
@@ -419,8 +421,8 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
419
  )
420
  print('Export tracking video via MoGe')
421
  else:
422
- # 使用 cotracker
423
- pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
424
 
425
  t, c, h, w = video_tensor.shape
426
  new_width = 518
@@ -515,6 +517,58 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
515
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
516
  return None, None
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  # Create Gradio interface with updated layout
519
  with gr.Blocks(title="Diffusion as Shader") as demo:
520
  gr.Markdown("# Diffusion as Shader Web UI")
 
1
  import os
2
  import sys
3
+ import spaces
4
  import gradio as gr
5
  import torch
6
  import argparse
 
9
  import torchvision.transforms as transforms
10
  from moviepy.editor import VideoFileClip
11
  from diffusers.utils import load_image, load_video
12
+ from tqdm import tqdm
13
+ from image_gen_aux import DepthPreprocessor
14
 
15
  project_root = os.path.dirname(os.path.abspath(__file__))
16
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
 
249
  print('Export tracking video via MoGe')
250
  else:
251
  # 使用 cotracker
252
+ pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
253
  tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
254
  print('Export tracking video via cotracker')
255
 
 
314
  )
315
  print('Export tracking video via MoGe')
316
  else:
317
+ # 使用在CPU上运行的cotracker
318
+ pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
319
 
320
  t, c, h, w = video_tensor.shape
321
  new_width = 518
 
421
  )
422
  print('Export tracking video via MoGe')
423
  else:
424
+ # 使用在CPU上运行的cotracker
425
+ pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
426
 
427
  t, c, h, w = video_tensor.shape
428
  new_width = 518
 
517
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
518
  return None, None
519
 
520
+ def generate_tracking_cotracker(video_tensor, density=30):
521
+ """在CPU上生成跟踪视频,只使用第一帧的深度信息,使用矩阵运算提高效率
522
+
523
+ 参数:
524
+ video_tensor (torch.Tensor): 输入视频张量
525
+ density (int): 跟踪点的密度
526
+
527
+ 返回:
528
+ tuple: (pred_tracks, pred_visibility)
529
+ """
530
+ cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to("cpu")
531
+ depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti").to("cpu")
532
+
533
+ video = video_tensor.unsqueeze(0).to("cpu")
534
+
535
+ # 只处理第一帧以获取深度图
536
+ print("estimating depth for first frame...")
537
+ frame = (video_tensor[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
538
+ depth = depth_preprocessor(Image.fromarray(frame))[0]
539
+ depth_tensor = transforms.ToTensor()(depth) # [1, H, W]
540
+
541
+ # 获取跟踪点和可见性
542
+ print("tracking on CPU...")
543
+ pred_tracks, pred_visibility = cotracker(video, grid_size=density) # B T N 2, B T N 1
544
+
545
+ # 提取维度
546
+ B, T, N, _ = pred_tracks.shape
547
+ H, W = depth_tensor.shape[1], depth_tensor.shape[2]
548
+
549
+ # 创建带深度的输出张量
550
+ pred_tracks_with_depth = torch.zeros((B, T, N, 3), device="cpu")
551
+ pred_tracks_with_depth[:, :, :, :2] = pred_tracks # 复制x,y坐标
552
+
553
+ # 使用矩阵运算一次性处理所有帧和点
554
+ # 重塑pred_tracks为[B*T*N, 2]以便于处理
555
+ flat_tracks = pred_tracks.reshape(-1, 2)
556
+
557
+ # 将坐标限制在有效图像边界内
558
+ x_coords = flat_tracks[:, 0].clamp(0, W-1).long()
559
+ y_coords = flat_tracks[:, 1].clamp(0, H-1).long()
560
+
561
+ # 从第一帧的深度图获取所有点的深度值
562
+ depths = depth_tensor[0, y_coords, x_coords]
563
+
564
+ # 重塑回原始形状并分配给输出张量
565
+ pred_tracks_with_depth[:, :, :, 2] = depths.reshape(B, T, N)
566
+
567
+ del cotracker,depth_preprocessor
568
+
569
+ # 将结果返回
570
+ return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
571
+
572
  # Create Gradio interface with updated layout
573
  with gr.Blocks(title="Diffusion as Shader") as demo:
574
  gr.Markdown("# Diffusion as Shader Web UI")
models/pipelines.py CHANGED
@@ -470,7 +470,7 @@ class DiffusionAsShaderPipeline:
470
  del cotracker
471
  torch.cuda.empty_cache()
472
 
473
- def visualize_tracking_cotracker(self, points, vis_mask=None, save_tracking=True, point_wise=4, video_size=(480, 720)):
474
  """Visualize tracking results from CoTracker
475
 
476
  Args:
 
470
  del cotracker
471
  torch.cuda.empty_cache()
472
 
473
+ def visualize_tracking_cotracker(self, points, vis_mask=None, save_tracking=True, point_wise=10, video_size=(480, 720)):
474
  """Visualize tracking results from CoTracker
475
 
476
  Args:
submodules/vggt ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit b02cc03ceee70821ed1231a530c1992507ef9862