Spaces:
Running
Running
yansong1616
commited on
Rename app_gys.py to app.py
Browse files- app_gys.py → app.py +735 -735
app_gys.py → app.py
RENAMED
@@ -1,735 +1,735 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
import argparse
|
4 |
-
import gradio
|
5 |
-
import os
|
6 |
-
import torch
|
7 |
-
import numpy as np
|
8 |
-
import tempfile
|
9 |
-
import functools
|
10 |
-
import trimesh
|
11 |
-
import copy
|
12 |
-
from scipy.spatial.transform import Rotation
|
13 |
-
|
14 |
-
from dust3r.inference import inference, load_model
|
15 |
-
from dust3r.image_pairs import make_pairs
|
16 |
-
from dust3r.utils.image import load_images, rgb, resize_images
|
17 |
-
from dust3r.utils.device import to_numpy
|
18 |
-
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
19 |
-
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
20 |
-
from sam2.build_sam import build_sam2_video_predictor
|
21 |
-
import matplotlib.pyplot as plt
|
22 |
-
|
23 |
-
import shutil
|
24 |
-
import json
|
25 |
-
from PIL import Image
|
26 |
-
import math
|
27 |
-
import cv2
|
28 |
-
|
29 |
-
plt.ion()
|
30 |
-
|
31 |
-
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
32 |
-
batch_size = 1
|
33 |
-
|
34 |
-
########################## 引入grounding_dino #############################
|
35 |
-
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
36 |
-
def get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text):
|
37 |
-
# init grounding dino model from huggingface
|
38 |
-
model_id = "IDEA-Research/grounding-dino-tiny"
|
39 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
40 |
-
processor = AutoProcessor.from_pretrained(model_id)
|
41 |
-
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
42 |
-
|
43 |
-
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
44 |
-
# VERY important: text queries need to be lowercased + end with a dot
|
45 |
-
|
46 |
-
|
47 |
-
"""
|
48 |
-
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for specific frame
|
49 |
-
"""
|
50 |
-
# prompt grounding dino to get the box coordinates on specific frame
|
51 |
-
frame_names = [
|
52 |
-
p for p in os.listdir(video_dir)
|
53 |
-
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
54 |
-
]
|
55 |
-
# frame_names.sort(key=lambda p: os.path.splitext(p)[0])
|
56 |
-
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
|
57 |
-
image = Image.open(img_path)
|
58 |
-
|
59 |
-
|
60 |
-
# run Grounding DINO on the image
|
61 |
-
inputs = processor(images=image, text=input_text, return_tensors="pt").to(device)
|
62 |
-
with torch.no_grad():
|
63 |
-
outputs = grounding_model(**inputs)
|
64 |
-
|
65 |
-
results = processor.post_process_grounded_object_detection(
|
66 |
-
outputs,
|
67 |
-
inputs.input_ids,
|
68 |
-
box_threshold=0.25,
|
69 |
-
text_threshold=0.3,
|
70 |
-
target_sizes=[image.size[::-1]]
|
71 |
-
)
|
72 |
-
return results[0]["boxes"], results[0]["labels"]
|
73 |
-
|
74 |
-
def get_masks_from_grounded_sam2(h, w, predictor, video_dir, input_text):
|
75 |
-
|
76 |
-
inference_state = predictor.init_state(video_path=video_dir)
|
77 |
-
predictor.reset_state(inference_state)
|
78 |
-
|
79 |
-
ann_frame_idx = 0 # the frame index we interact with
|
80 |
-
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
81 |
-
print("Running Groundding DINO......")
|
82 |
-
input_boxes, OBJECTS = get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text)
|
83 |
-
print("Groundding DINO run over!")
|
84 |
-
if(len(OBJECTS) < 1):
|
85 |
-
raise gradio.Error("The images you input do not contain the target in '{}'".format(input_text))
|
86 |
-
# 给第一个帧输入由grounding_dino输出的boxes作为prompts
|
87 |
-
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes)):
|
88 |
-
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
|
89 |
-
inference_state=inference_state,
|
90 |
-
frame_idx=ann_frame_idx,
|
91 |
-
obj_id=ann_obj_id,
|
92 |
-
box=box,
|
93 |
-
)
|
94 |
-
break #只加入第一个box
|
95 |
-
|
96 |
-
|
97 |
-
# sam2获取所有帧的分割结果
|
98 |
-
video_segments = {} # video_segments contains the per-frame segmentation results
|
99 |
-
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
100 |
-
video_segments[out_frame_idx] = {
|
101 |
-
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
102 |
-
for i, out_obj_id in enumerate(out_obj_ids)
|
103 |
-
}
|
104 |
-
|
105 |
-
resize_mask = resize_mask_to_img(video_segments, w, h)
|
106 |
-
return resize_mask
|
107 |
-
|
108 |
-
|
109 |
-
def handle_uploaded_files(uploaded_files, target_folder):
|
110 |
-
# 创建目标文件夹
|
111 |
-
if not os.path.exists(target_folder):
|
112 |
-
os.makedirs(target_folder)
|
113 |
-
|
114 |
-
# 遍历上传的文件,移动到目标文件夹
|
115 |
-
for file in uploaded_files:
|
116 |
-
file_path = file.name # 文件的临时路径
|
117 |
-
file_name = os.path.basename(file_path) # 文件名
|
118 |
-
target_path = os.path.join(target_folder, file_name)
|
119 |
-
shutil.copy2(file_path, target_path)
|
120 |
-
print("copy images from {} to {}".format(file_path, target_path))
|
121 |
-
|
122 |
-
return target_folder
|
123 |
-
def show_mask(mask, ax, random_color=False):
|
124 |
-
if random_color:
|
125 |
-
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
126 |
-
else:
|
127 |
-
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
128 |
-
h, w = mask.shape[-2:]
|
129 |
-
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
130 |
-
ax.imshow(mask_image)
|
131 |
-
|
132 |
-
def show_mask_sam2(mask, ax, obj_id=None, random_color=False):
|
133 |
-
if random_color:
|
134 |
-
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
135 |
-
else:
|
136 |
-
cmap = plt.get_cmap("tab10")
|
137 |
-
cmap_idx = 0 if obj_id is None else obj_id
|
138 |
-
color = np.array([*cmap(cmap_idx)[:3], 0.6])
|
139 |
-
h, w = mask.shape[-2:]
|
140 |
-
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
141 |
-
ax.imshow(mask_image)
|
142 |
-
def show_points(coords, labels, ax, marker_size=375):
|
143 |
-
pos_points = coords[labels == 1]
|
144 |
-
neg_points = coords[labels == 0]
|
145 |
-
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
|
146 |
-
linewidth=1.25)
|
147 |
-
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
|
148 |
-
linewidth=1.25)
|
149 |
-
|
150 |
-
def show_box(box, ax):
|
151 |
-
x0, y0 = box[0], box[1]
|
152 |
-
w, h = box[2] - box[0], box[3] - box[1]
|
153 |
-
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
def get_args_parser():
|
158 |
-
parser = argparse.ArgumentParser()
|
159 |
-
parser_url = parser.add_mutually_exclusive_group()
|
160 |
-
parser_url.add_argument("--local_network", action='store_true', default=False,
|
161 |
-
help="make app accessible on local network: address will be set to 0.0.0.0")
|
162 |
-
parser_url.add_argument("--server_name", type=str, default=
|
163 |
-
parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
|
164 |
-
parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
|
165 |
-
"If None, will search for an available port starting at 7860."),
|
166 |
-
default=None)
|
167 |
-
parser.add_argument("--weights", type=str, required=True, help="path to the model weights")
|
168 |
-
parser.add_argument("--device", type=str, default='
|
169 |
-
parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
|
170 |
-
return parser
|
171 |
-
|
172 |
-
|
173 |
-
# 将渲染的3D保存到outfile路径
|
174 |
-
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
|
175 |
-
cam_color=None, as_pointcloud=False, transparent_cams=False):
|
176 |
-
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
|
177 |
-
pts3d = to_numpy(pts3d)
|
178 |
-
imgs = to_numpy(imgs)
|
179 |
-
focals = to_numpy(focals)
|
180 |
-
cams2world = to_numpy(cams2world)
|
181 |
-
|
182 |
-
scene = trimesh.Scene()
|
183 |
-
|
184 |
-
# full pointcloud
|
185 |
-
if as_pointcloud:
|
186 |
-
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
|
187 |
-
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
|
188 |
-
pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
|
189 |
-
scene.add_geometry(pct)
|
190 |
-
else:
|
191 |
-
meshes = []
|
192 |
-
for i in range(len(imgs)):
|
193 |
-
meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
|
194 |
-
mesh = trimesh.Trimesh(**cat_meshes(meshes))
|
195 |
-
scene.add_geometry(mesh)
|
196 |
-
|
197 |
-
# add each camera
|
198 |
-
for i, pose_c2w in enumerate(cams2world):
|
199 |
-
if isinstance(cam_color, list):
|
200 |
-
camera_edge_color = cam_color[i]
|
201 |
-
else:
|
202 |
-
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
|
203 |
-
add_scene_cam(scene, pose_c2w, camera_edge_color,
|
204 |
-
None if transparent_cams else imgs[i], focals[i],
|
205 |
-
imsize=imgs[i].shape[1::-1], screen_width=cam_size)
|
206 |
-
|
207 |
-
rot = np.eye(4)
|
208 |
-
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
|
209 |
-
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
|
210 |
-
outfile = os.path.join(outdir, 'scene.glb')
|
211 |
-
print('(exporting 3D scene to', outfile, ')')
|
212 |
-
scene.export(file_obj=outfile)
|
213 |
-
return outfile
|
214 |
-
|
215 |
-
|
216 |
-
def get_3D_model_from_scene(outdir, scene, sam2_masks, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
|
217 |
-
clean_depth=False, transparent_cams=False, cam_size=0.05):
|
218 |
-
"""
|
219 |
-
extract 3D_model (glb file) from a reconstructed scene
|
220 |
-
"""
|
221 |
-
if scene is None:
|
222 |
-
return None
|
223 |
-
# post processes
|
224 |
-
if clean_depth:
|
225 |
-
scene = scene.clean_pointcloud()
|
226 |
-
if mask_sky:
|
227 |
-
scene = scene.mask_sky()
|
228 |
-
|
229 |
-
# get optimized values from scene
|
230 |
-
rgbimg = scene.imgs
|
231 |
-
|
232 |
-
focals = scene.get_focals().cpu()
|
233 |
-
cams2world = scene.get_im_poses().cpu()
|
234 |
-
# 3D pointcloud from depthmap, poses and intrinsics
|
235 |
-
pts3d = to_numpy(scene.get_pts3d())
|
236 |
-
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
|
237 |
-
msk = to_numpy(scene.get_masks())
|
238 |
-
|
239 |
-
assert len(msk) == len(sam2_masks)
|
240 |
-
# 将sam2输出的mask 和 dust3r输出的置信度阈值筛选后的msk取交集
|
241 |
-
for i in range(len(sam2_masks)):
|
242 |
-
msk[i] = np.logical_and(msk[i], sam2_masks[i])
|
243 |
-
|
244 |
-
return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
245 |
-
transparent_cams=transparent_cams, cam_size=cam_size) # 置信度和SAM2 mask的交集
|
246 |
-
|
247 |
-
# 将视频分割成固定帧数
|
248 |
-
def video_to_frames_fix(video_path, output_folder, frame_interval=10, target_fps=6):
|
249 |
-
"""
|
250 |
-
将视频转换为图像帧,并保存为 JPEG 文件。
|
251 |
-
frame_interval
|
252 |
-
target_fps: 目标帧率(每秒保存的帧数)
|
253 |
-
"""
|
254 |
-
|
255 |
-
# 确保输出文件夹存在
|
256 |
-
if not os.path.exists(output_folder):
|
257 |
-
os.makedirs(output_folder)
|
258 |
-
# 打开视频文件
|
259 |
-
cap = cv2.VideoCapture(video_path)
|
260 |
-
# 获取视频总帧数
|
261 |
-
frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
262 |
-
|
263 |
-
# 计算动态帧间隔
|
264 |
-
frame_interval = math.ceil(frames_num / target_fps)
|
265 |
-
print(f"总帧数: {frames_num} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")
|
266 |
-
frame_count = 0
|
267 |
-
saved_frame_count = 0
|
268 |
-
success, frame = cap.read()
|
269 |
-
|
270 |
-
file_list = []
|
271 |
-
# 逐帧读取视频
|
272 |
-
while success:
|
273 |
-
if frame_count % frame_interval == 0:
|
274 |
-
# 每隔 frame_interval 帧保存一次
|
275 |
-
frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
|
276 |
-
cv2.imwrite(frame_filename, frame)
|
277 |
-
file_list.append(frame_filename)
|
278 |
-
saved_frame_count += 1
|
279 |
-
frame_count += 1
|
280 |
-
success, frame = cap.read()
|
281 |
-
|
282 |
-
# 释放视频捕获对象
|
283 |
-
cap.release()
|
284 |
-
print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
|
285 |
-
return file_list
|
286 |
-
|
287 |
-
|
288 |
-
def video_to_frames(video_path, output_folder, frame_interval=10, target_fps = 2):
|
289 |
-
"""
|
290 |
-
将视频转换为图像帧,并保存为 JPEG 文件。
|
291 |
-
frame_interval:保存帧的步长
|
292 |
-
target_fps: 目标帧率(每秒保存的帧数)
|
293 |
-
"""
|
294 |
-
|
295 |
-
# 确保输出文件夹存在
|
296 |
-
if not os.path.exists(output_folder):
|
297 |
-
os.makedirs(output_folder)
|
298 |
-
# 打开视频文件
|
299 |
-
cap = cv2.VideoCapture(video_path)
|
300 |
-
# 获取视频的实际帧率
|
301 |
-
actual_fps = cap.get(cv2.CAP_PROP_FPS)
|
302 |
-
|
303 |
-
# 获取视频总帧数
|
304 |
-
frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
305 |
-
|
306 |
-
# 计算动态帧间隔
|
307 |
-
# frame_interval = math.ceil(actual_fps / target_fps)
|
308 |
-
print(f"实际帧率: {actual_fps} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")
|
309 |
-
frame_count = 0
|
310 |
-
saved_frame_count = 0
|
311 |
-
success, frame = cap.read()
|
312 |
-
|
313 |
-
file_list = []
|
314 |
-
# 逐帧读取视频
|
315 |
-
while success:
|
316 |
-
if frame_count % frame_interval == 0:
|
317 |
-
# 每隔 frame_interval 帧保存一次
|
318 |
-
frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
|
319 |
-
cv2.imwrite(frame_filename, frame)
|
320 |
-
file_list.append(frame_filename)
|
321 |
-
saved_frame_count += 1
|
322 |
-
frame_count += 1
|
323 |
-
success, frame = cap.read()
|
324 |
-
|
325 |
-
# 释放视频捕获对象
|
326 |
-
cap.release()
|
327 |
-
print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
|
328 |
-
return file_list
|
329 |
-
|
330 |
-
def overlay_mask_on_image(image, mask, color=[0, 1, 0], alpha=0.5):
|
331 |
-
"""
|
332 |
-
将mask融合在image上显示。
|
333 |
-
返回融合后的图片 (H, W, 3)
|
334 |
-
"""
|
335 |
-
|
336 |
-
# 创建一个与image相同尺寸的全黑图像
|
337 |
-
mask_colored = np.zeros_like(image)
|
338 |
-
|
339 |
-
# 将mask为True的位置赋值为指定颜色
|
340 |
-
mask_colored[mask] = color
|
341 |
-
|
342 |
-
# 将彩色掩码与原图像叠加
|
343 |
-
overlay = cv2.addWeighted(image, 1 - alpha, mask_colored, alpha, 0)
|
344 |
-
|
345 |
-
return overlay
|
346 |
-
def get_reconstructed_video(sam2, outdir, model, device, image_size, image_mask, video_dir, schedule, niter, min_conf_thr,
|
347 |
-
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
348 |
-
scenegraph_type, winsize, refid, input_text):
|
349 |
-
target_dir = os.path.join(outdir, 'frames_video')
|
350 |
-
file_list = video_to_frames_fix(video_dir, target_dir)
|
351 |
-
scene, outfile, imgs = get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, file_list, schedule, niter, min_conf_thr,
|
352 |
-
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
353 |
-
scenegraph_type, winsize, refid, target_dir, input_text)
|
354 |
-
return scene, outfile, imgs
|
355 |
-
|
356 |
-
def get_reconstructed_image(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
|
357 |
-
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
358 |
-
scenegraph_type, winsize, refid, input_text):
|
359 |
-
target_folder = handle_uploaded_files(filelist, os.path.join(outdir, 'uploaded_images'))
|
360 |
-
scene, outfile, imgs = get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
|
361 |
-
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
362 |
-
scenegraph_type, winsize, refid, target_folder, input_text)
|
363 |
-
return scene, outfile, imgs
|
364 |
-
def get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
|
365 |
-
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
366 |
-
scenegraph_type, winsize, refid, images_folder, input_text=None):
|
367 |
-
"""
|
368 |
-
from a list of images, run dust3rWithSam2 inference, global aligner.
|
369 |
-
then run get_3D_model_from_scene
|
370 |
-
"""
|
371 |
-
imgs = load_images(filelist, size=image_size)
|
372 |
-
img_size = imgs[0]["true_shape"]
|
373 |
-
for img in imgs[1:]:
|
374 |
-
if not np.equal(img["true_shape"], img_size).all():
|
375 |
-
raise gradio.Error("Please ensure that the images you enter are of the same size")
|
376 |
-
|
377 |
-
if len(imgs) == 1:
|
378 |
-
imgs = [imgs[0], copy.deepcopy(imgs[0])]
|
379 |
-
imgs[1]['idx'] = 1
|
380 |
-
if scenegraph_type == "swin":
|
381 |
-
scenegraph_type = scenegraph_type + "-" + str(winsize)
|
382 |
-
elif scenegraph_type == "oneref":
|
383 |
-
scenegraph_type = scenegraph_type + "-" + str(refid)
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
|
390 |
-
output = inference(pairs, model, device, batch_size=batch_size)
|
391 |
-
|
392 |
-
mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
|
393 |
-
scene = global_aligner(output, device=device, mode=mode)
|
394 |
-
lr = 0.01
|
395 |
-
|
396 |
-
if mode == GlobalAlignerMode.PointCloudOptimizer:
|
397 |
-
loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
# also return rgb, depth and confidence imgs
|
402 |
-
# depth is normalized with the max value for all images
|
403 |
-
# we apply the jet colormap on the confidence maps
|
404 |
-
rgbimg = scene.imgs
|
405 |
-
depths = to_numpy(scene.get_depthmaps())
|
406 |
-
confs = to_numpy([c for c in scene.im_conf])
|
407 |
-
cmap = plt.get_cmap('jet')
|
408 |
-
depths_max = max([d.max() for d in depths])
|
409 |
-
depths = [d / depths_max for d in depths]
|
410 |
-
confs_max = max([d.max() for d in confs])
|
411 |
-
confs = [cmap(d / confs_max) for d in confs]
|
412 |
-
|
413 |
-
# TODO 调用SAM2获取masks
|
414 |
-
h, w = rgbimg[0].shape[:-1]
|
415 |
-
masks = None
|
416 |
-
if not input_text or input_text.isspace(): # input_text 为空串
|
417 |
-
masks = get_masks_from_sam2(h, w, sam2, images_folder)
|
418 |
-
else:
|
419 |
-
masks = get_masks_from_grounded_sam2(h, w, sam2, images_folder, input_text) # gd-sam2
|
420 |
-
|
421 |
-
|
422 |
-
imgs = []
|
423 |
-
for i in range(len(rgbimg)):
|
424 |
-
imgs.append(rgbimg[i])
|
425 |
-
imgs.append(rgb(depths[i]))
|
426 |
-
imgs.append(rgb(confs[i]))
|
427 |
-
imgs.append(overlay_mask_on_image(rgbimg[i], masks[i])) # mask融合原图,展示SAM2的分割效果
|
428 |
-
|
429 |
-
# TODO 基于SAM2的mask过滤DUST3R的3D重建模型
|
430 |
-
outfile = get_3D_model_from_scene(outdir, scene, masks, min_conf_thr, as_pointcloud, mask_sky,
|
431 |
-
clean_depth, transparent_cams, cam_size)
|
432 |
-
return scene, outfile, imgs
|
433 |
-
|
434 |
-
|
435 |
-
def resize_mask_to_img(masks, target_width, target_height):
|
436 |
-
frame_mask = []
|
437 |
-
origin_size = masks[0][1].shape # 1表示object id
|
438 |
-
for frame, objects_mask in masks.items(): # 每个frame和该frame对应的分割结果
|
439 |
-
# 每个frame可能包含多个object对应的mask
|
440 |
-
masks = list(objects_mask.values())
|
441 |
-
if not masks: # masks为空,即当前frame不包含object
|
442 |
-
frame_mask.append(np.ones(origin_size, dtype=bool))
|
443 |
-
else: # 将当前frame包含的所有object的mask取并集
|
444 |
-
union_mask = masks[0]
|
445 |
-
for mask in masks[1:]:
|
446 |
-
union_mask = np.logical_or(union_mask, mask)
|
447 |
-
frame_mask.append(union_mask)
|
448 |
-
resized_mask = []
|
449 |
-
for mask in frame_mask:
|
450 |
-
mask_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255)
|
451 |
-
resized_mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
|
452 |
-
resized_mask.append(np.array(resized_mask_image) > 0)
|
453 |
-
|
454 |
-
return resized_mask
|
455 |
-
|
456 |
-
def get_masks_from_sam2(h, w, predictor, video_dir):
|
457 |
-
|
458 |
-
inference_state = predictor.init_state(video_path=video_dir)
|
459 |
-
predictor.reset_state(inference_state)
|
460 |
-
|
461 |
-
|
462 |
-
# 给一个帧添加points
|
463 |
-
points = np.array([[360, 550], [340, 400]], dtype=np.float32)
|
464 |
-
labels = np.array([1, 1], dtype=np.int32)
|
465 |
-
|
466 |
-
|
467 |
-
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
|
468 |
-
inference_state=inference_state,
|
469 |
-
frame_idx=0,
|
470 |
-
obj_id=1,
|
471 |
-
points=points,
|
472 |
-
labels=labels,
|
473 |
-
)
|
474 |
-
|
475 |
-
# sam2获取所有帧的分割结果
|
476 |
-
video_segments = {} # video_segments contains the per-frame segmentation results
|
477 |
-
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
478 |
-
video_segments[out_frame_idx] = {
|
479 |
-
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
480 |
-
for i, out_obj_id in enumerate(out_obj_ids)
|
481 |
-
}
|
482 |
-
|
483 |
-
resize_mask = resize_mask_to_img(video_segments, w, h)
|
484 |
-
return resize_mask
|
485 |
-
|
486 |
-
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
|
487 |
-
num_files = len(inputfiles) if inputfiles is not None else 1
|
488 |
-
max_winsize = max(1, (num_files - 1) // 2)
|
489 |
-
if scenegraph_type == "swin":
|
490 |
-
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
491 |
-
minimum=1, maximum=max_winsize, step=1, visible=True)
|
492 |
-
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
493 |
-
maximum=num_files - 1, step=1, visible=False)
|
494 |
-
elif scenegraph_type == "oneref":
|
495 |
-
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
496 |
-
minimum=1, maximum=max_winsize, step=1, visible=False)
|
497 |
-
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
498 |
-
maximum=num_files - 1, step=1, visible=True)
|
499 |
-
else:
|
500 |
-
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
501 |
-
minimum=1, maximum=max_winsize, step=1, visible=False)
|
502 |
-
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
503 |
-
maximum=num_files - 1, step=1, visible=False)
|
504 |
-
return winsize, refid
|
505 |
-
|
506 |
-
def process_images(imagesList):
|
507 |
-
return None
|
508 |
-
|
509 |
-
def process_videos(video):
|
510 |
-
return None
|
511 |
-
|
512 |
-
def upload_images_listener(image_size, file_list):
|
513 |
-
if len(file_list) == 1:
|
514 |
-
raise gradio.Error("Please enter images from at least two different views.")
|
515 |
-
print("Uploading image[0] to ImageMask:")
|
516 |
-
img_0 = load_images([file_list[0]], image_size)
|
517 |
-
i1 = img_0[0]['img'].squeeze(0)
|
518 |
-
rgb_img = rgb(i1)
|
519 |
-
return rgb_img
|
520 |
-
def upload_video_listener(image_size, video_dir):
|
521 |
-
cap = cv2.VideoCapture(video_dir)
|
522 |
-
success, frame = cap.read() # 第一帧
|
523 |
-
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
524 |
-
Image_frame = Image.fromarray(rgb_frame)
|
525 |
-
resized_frame = resize_images([Image_frame], image_size)
|
526 |
-
i1 = resized_frame[0]['img'].squeeze(0)
|
527 |
-
rgb_img = rgb(i1)
|
528 |
-
return rgb_img
|
529 |
-
|
530 |
-
def main_demo(sam2, tmpdirname, model, device, image_size, server_name, server_port):
|
531 |
-
|
532 |
-
# functools.partial解析:https://blog.csdn.net/wuShiJingZuo/article/details/135018810
|
533 |
-
recon_fun_image_demo = functools.partial(get_reconstructed_image,sam2, tmpdirname, model, device,
|
534 |
-
image_size)
|
535 |
-
|
536 |
-
recon_fun_video_demo = functools.partial(get_reconstructed_video, sam2, tmpdirname, model, device,
|
537 |
-
image_size)
|
538 |
-
|
539 |
-
upload_files_fun = functools.partial(upload_images_listener,image_size)
|
540 |
-
upload_video_fun = functools.partial(upload_video_listener, image_size)
|
541 |
-
with gradio.Blocks() as demo1:
|
542 |
-
scene = gradio.State(None)
|
543 |
-
gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
|
544 |
-
gradio.HTML("""<h2 style="text-align: center;">
|
545 |
-
<a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
|
546 |
-
<a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
|
547 |
-
</h2>""")
|
548 |
-
gradio.HTML("""
|
549 |
-
<div style="text-align: center;">
|
550 |
-
<h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
|
551 |
-
</div>
|
552 |
-
""")
|
553 |
-
gradio.set_static_paths(paths=["static/images/"])
|
554 |
-
project_path = "static/images/project.gif"
|
555 |
-
gradio.HTML(f"""
|
556 |
-
<div align='center' >
|
557 |
-
<img src="/file={project_path}" width='720px'>
|
558 |
-
</div>
|
559 |
-
""")
|
560 |
-
gradio.HTML("<p> \
|
561 |
-
<strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
|
562 |
-
Choose an example below 🔥 🔥 🔥 <br>\
|
563 |
-
Or, upload by yourself: <br>\
|
564 |
-
1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
|
565 |
-
<br> \
|
566 |
-
💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
|
567 |
-
💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
|
568 |
-
💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
|
569 |
-
Note: we only take the first 16 frames for the demo. \
|
570 |
-
</p>")
|
571 |
-
|
572 |
-
with gradio.Column():
|
573 |
-
with gradio.Row():
|
574 |
-
inputfiles = gradio.File(file_count="multiple")
|
575 |
-
with gradio.Column():
|
576 |
-
image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
|
577 |
-
label="prompt (提示图)", transforms=(), width=600, height=450)
|
578 |
-
input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
|
579 |
-
with gradio.Row():
|
580 |
-
schedule = gradio.Dropdown(["linear", "cosine"],
|
581 |
-
value='linear', label="schedule", info="For global alignment!")
|
582 |
-
niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
|
583 |
-
label="num_iterations", info="For global alignment!")
|
584 |
-
scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
|
585 |
-
value='complete', label="Scenegraph",
|
586 |
-
info="Define how to make pairs",
|
587 |
-
interactive=True)
|
588 |
-
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
|
589 |
-
minimum=1, maximum=1, step=1, visible=False)
|
590 |
-
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
|
591 |
-
|
592 |
-
run_btn = gradio.Button("Run")
|
593 |
-
|
594 |
-
with gradio.Row():
|
595 |
-
# adjust the confidence threshold
|
596 |
-
min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
|
597 |
-
# adjust the camera size in the output pointcloud
|
598 |
-
cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
|
599 |
-
with gradio.Row():
|
600 |
-
as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
|
601 |
-
# two post process implemented
|
602 |
-
mask_sky = gradio.Checkbox(value=False, label="Mask sky")
|
603 |
-
clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
|
604 |
-
transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
|
605 |
-
|
606 |
-
outmodel = gradio.Model3D()
|
607 |
-
outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")
|
608 |
-
|
609 |
-
inputfiles.upload(upload_files_fun, inputs=inputfiles, outputs=image_mask)
|
610 |
-
|
611 |
-
run_btn.click(fn=recon_fun_image_demo, # 调用get_reconstructed_image即DUST3R模型
|
612 |
-
inputs=[image_mask, inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
|
613 |
-
mask_sky, clean_depth, transparent_cams, cam_size,
|
614 |
-
scenegraph_type, winsize, refid, input_text],
|
615 |
-
outputs=[scene, outmodel, outgallery])
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
# ## **************************** video *******************************************************
|
621 |
-
with gradio.Blocks() as demo2:
|
622 |
-
gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
|
623 |
-
gradio.HTML("""<h2 style="text-align: center;">
|
624 |
-
<a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
|
625 |
-
<a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
|
626 |
-
</h2>""")
|
627 |
-
gradio.HTML("""
|
628 |
-
<div style="text-align: center;">
|
629 |
-
<h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
|
630 |
-
</div>
|
631 |
-
""")
|
632 |
-
gradio.set_static_paths(paths=["static/images/"])
|
633 |
-
project_path = "static/images/project.gif"
|
634 |
-
gradio.HTML(f"""
|
635 |
-
<div align='center' >
|
636 |
-
<img src="/file={project_path}" width='720px'>
|
637 |
-
</div>
|
638 |
-
""")
|
639 |
-
gradio.HTML("<p> \
|
640 |
-
<strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
|
641 |
-
Choose an example below 🔥 🔥 🔥 <br>\
|
642 |
-
Or, upload by yourself: <br>\
|
643 |
-
1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
|
644 |
-
<br> \
|
645 |
-
💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
|
646 |
-
💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
|
647 |
-
💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
|
648 |
-
Note: we only take the first 16 frames for the demo. \
|
649 |
-
</p>")
|
650 |
-
|
651 |
-
with gradio.Column():
|
652 |
-
with gradio.Row():
|
653 |
-
input_video = gradio.Video(width=600, height=600)
|
654 |
-
with gradio.Column():
|
655 |
-
image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
|
656 |
-
label="prompt (提示图)", transforms=(), width=600, height=450)
|
657 |
-
input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
|
658 |
-
|
659 |
-
|
660 |
-
with gradio.Row():
|
661 |
-
schedule = gradio.Dropdown(["linear", "cosine"],
|
662 |
-
value='linear', label="schedule", info="For global alignment!")
|
663 |
-
niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
|
664 |
-
label="num_iterations", info="For global alignment!")
|
665 |
-
scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
|
666 |
-
value='complete', label="Scenegraph",
|
667 |
-
info="Define how to make pairs",
|
668 |
-
interactive=True)
|
669 |
-
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
|
670 |
-
minimum=1, maximum=1, step=1, visible=False)
|
671 |
-
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
|
672 |
-
|
673 |
-
run_btn = gradio.Button("Run")
|
674 |
-
|
675 |
-
with gradio.Row():
|
676 |
-
# adjust the confidence threshold
|
677 |
-
min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
|
678 |
-
# adjust the camera size in the output pointcloud
|
679 |
-
cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
|
680 |
-
with gradio.Row():
|
681 |
-
as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
|
682 |
-
# two post process implemented
|
683 |
-
mask_sky = gradio.Checkbox(value=False, label="Mask sky")
|
684 |
-
clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
|
685 |
-
transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
|
686 |
-
|
687 |
-
outmodel = gradio.Model3D()
|
688 |
-
outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")
|
689 |
-
|
690 |
-
input_video.upload(upload_video_fun, inputs=input_video, outputs=image_mask)
|
691 |
-
|
692 |
-
run_btn.click(fn=recon_fun_video_demo, # 调用get_reconstructed_scene即DUST3R模型
|
693 |
-
inputs=[image_mask, input_video, schedule, niter, min_conf_thr, as_pointcloud,
|
694 |
-
mask_sky, clean_depth, transparent_cams, cam_size,
|
695 |
-
scenegraph_type, winsize, refid, input_text],
|
696 |
-
outputs=[scene, outmodel, outgallery])
|
697 |
-
|
698 |
-
app = gradio.TabbedInterface([demo1, demo2], ["3d rebuilding by images", "3d rebuilding by video"])
|
699 |
-
app.launch(share=False, server_name=server_name, server_port=server_port)
|
700 |
-
|
701 |
-
|
702 |
-
# TODO 修改bug:
|
703 |
-
#在项目的一次启动中,上传的多组图片在点击run后,会保存在同一个临时文件夹中,
|
704 |
-
# 这样后面再上传其他场景的图片时,不同场景下的图片会存在于一个文件夹中,
|
705 |
-
# 不同场景的图片导致分割与重建错误
|
706 |
-
|
707 |
-
## 目前构思的解决:在文件夹下再基于创建一个文件夹存放不同场景的图片,可以基于时间命名该文件夹
|
708 |
-
|
709 |
-
|
710 |
-
if __name__ == '__main__':
|
711 |
-
parser = get_args_parser()
|
712 |
-
args = parser.parse_args()
|
713 |
-
|
714 |
-
if args.tmp_dir is not None:
|
715 |
-
tmp_path = args.tmp_dir
|
716 |
-
os.makedirs(tmp_path, exist_ok=True)
|
717 |
-
tempfile.tempdir = tmp_path
|
718 |
-
|
719 |
-
if args.server_name is not None:
|
720 |
-
server_name = args.server_name
|
721 |
-
else:
|
722 |
-
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
723 |
-
|
724 |
-
# DUST3R
|
725 |
-
model = load_model(args.weights, args.device)
|
726 |
-
# SAM2
|
727 |
-
# 加载模型
|
728 |
-
sam2_checkpoint = "
|
729 |
-
model_cfg = "sam2_hiera_l.yaml"
|
730 |
-
sam2 = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
731 |
-
# dust3rWithSam2 will write the 3D model inside tmpdirname
|
732 |
-
with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: # DUST3R生成的3D .glb 文件所在的文件夹名称
|
733 |
-
print('Outputing stuff in', tmpdirname)
|
734 |
-
main_demo(sam2, tmpdirname, model, args.device, args.image_size, server_name, args.server_port)
|
735 |
-
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import gradio
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import tempfile
|
9 |
+
import functools
|
10 |
+
import trimesh
|
11 |
+
import copy
|
12 |
+
from scipy.spatial.transform import Rotation
|
13 |
+
|
14 |
+
from dust3r.inference import inference, load_model
|
15 |
+
from dust3r.image_pairs import make_pairs
|
16 |
+
from dust3r.utils.image import load_images, rgb, resize_images
|
17 |
+
from dust3r.utils.device import to_numpy
|
18 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
19 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
20 |
+
from sam2.build_sam import build_sam2_video_predictor
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
|
23 |
+
import shutil
|
24 |
+
import json
|
25 |
+
from PIL import Image
|
26 |
+
import math
|
27 |
+
import cv2
|
28 |
+
|
29 |
+
plt.ion()
|
30 |
+
|
31 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
32 |
+
batch_size = 1
|
33 |
+
|
34 |
+
########################## 引入grounding_dino #############################
|
35 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
36 |
+
def get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text):
|
37 |
+
# init grounding dino model from huggingface
|
38 |
+
model_id = "IDEA-Research/grounding-dino-tiny"
|
39 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
40 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
41 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
42 |
+
|
43 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
44 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
45 |
+
|
46 |
+
|
47 |
+
"""
|
48 |
+
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for specific frame
|
49 |
+
"""
|
50 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
51 |
+
frame_names = [
|
52 |
+
p for p in os.listdir(video_dir)
|
53 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
54 |
+
]
|
55 |
+
# frame_names.sort(key=lambda p: os.path.splitext(p)[0])
|
56 |
+
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
|
57 |
+
image = Image.open(img_path)
|
58 |
+
|
59 |
+
|
60 |
+
# run Grounding DINO on the image
|
61 |
+
inputs = processor(images=image, text=input_text, return_tensors="pt").to(device)
|
62 |
+
with torch.no_grad():
|
63 |
+
outputs = grounding_model(**inputs)
|
64 |
+
|
65 |
+
results = processor.post_process_grounded_object_detection(
|
66 |
+
outputs,
|
67 |
+
inputs.input_ids,
|
68 |
+
box_threshold=0.25,
|
69 |
+
text_threshold=0.3,
|
70 |
+
target_sizes=[image.size[::-1]]
|
71 |
+
)
|
72 |
+
return results[0]["boxes"], results[0]["labels"]
|
73 |
+
|
74 |
+
def get_masks_from_grounded_sam2(h, w, predictor, video_dir, input_text):
|
75 |
+
|
76 |
+
inference_state = predictor.init_state(video_path=video_dir)
|
77 |
+
predictor.reset_state(inference_state)
|
78 |
+
|
79 |
+
ann_frame_idx = 0 # the frame index we interact with
|
80 |
+
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
81 |
+
print("Running Groundding DINO......")
|
82 |
+
input_boxes, OBJECTS = get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text)
|
83 |
+
print("Groundding DINO run over!")
|
84 |
+
if(len(OBJECTS) < 1):
|
85 |
+
raise gradio.Error("The images you input do not contain the target in '{}'".format(input_text))
|
86 |
+
# 给第一个帧输入由grounding_dino输出的boxes作为prompts
|
87 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes)):
|
88 |
+
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
|
89 |
+
inference_state=inference_state,
|
90 |
+
frame_idx=ann_frame_idx,
|
91 |
+
obj_id=ann_obj_id,
|
92 |
+
box=box,
|
93 |
+
)
|
94 |
+
break #只加入第一个box
|
95 |
+
|
96 |
+
|
97 |
+
# sam2获取所有帧的分割结果
|
98 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
99 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
100 |
+
video_segments[out_frame_idx] = {
|
101 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
102 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
103 |
+
}
|
104 |
+
|
105 |
+
resize_mask = resize_mask_to_img(video_segments, w, h)
|
106 |
+
return resize_mask
|
107 |
+
|
108 |
+
|
109 |
+
def handle_uploaded_files(uploaded_files, target_folder):
|
110 |
+
# 创建目标文件夹
|
111 |
+
if not os.path.exists(target_folder):
|
112 |
+
os.makedirs(target_folder)
|
113 |
+
|
114 |
+
# 遍历上传的文件,移动到目标文件夹
|
115 |
+
for file in uploaded_files:
|
116 |
+
file_path = file.name # 文件的临时路径
|
117 |
+
file_name = os.path.basename(file_path) # 文件名
|
118 |
+
target_path = os.path.join(target_folder, file_name)
|
119 |
+
shutil.copy2(file_path, target_path)
|
120 |
+
print("copy images from {} to {}".format(file_path, target_path))
|
121 |
+
|
122 |
+
return target_folder
|
123 |
+
def show_mask(mask, ax, random_color=False):
|
124 |
+
if random_color:
|
125 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
126 |
+
else:
|
127 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
128 |
+
h, w = mask.shape[-2:]
|
129 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
130 |
+
ax.imshow(mask_image)
|
131 |
+
|
132 |
+
def show_mask_sam2(mask, ax, obj_id=None, random_color=False):
|
133 |
+
if random_color:
|
134 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
135 |
+
else:
|
136 |
+
cmap = plt.get_cmap("tab10")
|
137 |
+
cmap_idx = 0 if obj_id is None else obj_id
|
138 |
+
color = np.array([*cmap(cmap_idx)[:3], 0.6])
|
139 |
+
h, w = mask.shape[-2:]
|
140 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
141 |
+
ax.imshow(mask_image)
|
142 |
+
def show_points(coords, labels, ax, marker_size=375):
|
143 |
+
pos_points = coords[labels == 1]
|
144 |
+
neg_points = coords[labels == 0]
|
145 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
|
146 |
+
linewidth=1.25)
|
147 |
+
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
|
148 |
+
linewidth=1.25)
|
149 |
+
|
150 |
+
def show_box(box, ax):
|
151 |
+
x0, y0 = box[0], box[1]
|
152 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
153 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
def get_args_parser():
|
158 |
+
parser = argparse.ArgumentParser()
|
159 |
+
parser_url = parser.add_mutually_exclusive_group()
|
160 |
+
parser_url.add_argument("--local_network", action='store_true', default=False,
|
161 |
+
help="make app accessible on local network: address will be set to 0.0.0.0")
|
162 |
+
parser_url.add_argument("--server_name", type=str, default="0.0.0.0", help="server url, default is 127.0.0.1")
|
163 |
+
parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
|
164 |
+
parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
|
165 |
+
"If None, will search for an available port starting at 7860."),
|
166 |
+
default=None)
|
167 |
+
parser.add_argument("--weights", type=str, required=True, help="path to the model weights")
|
168 |
+
parser.add_argument("--device", type=str, default='cpu', help="pytorch device")
|
169 |
+
parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
|
170 |
+
return parser
|
171 |
+
|
172 |
+
|
173 |
+
# 将渲染的3D保存到outfile路径
|
174 |
+
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
|
175 |
+
cam_color=None, as_pointcloud=False, transparent_cams=False):
|
176 |
+
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
|
177 |
+
pts3d = to_numpy(pts3d)
|
178 |
+
imgs = to_numpy(imgs)
|
179 |
+
focals = to_numpy(focals)
|
180 |
+
cams2world = to_numpy(cams2world)
|
181 |
+
|
182 |
+
scene = trimesh.Scene()
|
183 |
+
|
184 |
+
# full pointcloud
|
185 |
+
if as_pointcloud:
|
186 |
+
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
|
187 |
+
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
|
188 |
+
pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
|
189 |
+
scene.add_geometry(pct)
|
190 |
+
else:
|
191 |
+
meshes = []
|
192 |
+
for i in range(len(imgs)):
|
193 |
+
meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
|
194 |
+
mesh = trimesh.Trimesh(**cat_meshes(meshes))
|
195 |
+
scene.add_geometry(mesh)
|
196 |
+
|
197 |
+
# add each camera
|
198 |
+
for i, pose_c2w in enumerate(cams2world):
|
199 |
+
if isinstance(cam_color, list):
|
200 |
+
camera_edge_color = cam_color[i]
|
201 |
+
else:
|
202 |
+
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
|
203 |
+
add_scene_cam(scene, pose_c2w, camera_edge_color,
|
204 |
+
None if transparent_cams else imgs[i], focals[i],
|
205 |
+
imsize=imgs[i].shape[1::-1], screen_width=cam_size)
|
206 |
+
|
207 |
+
rot = np.eye(4)
|
208 |
+
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
|
209 |
+
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
|
210 |
+
outfile = os.path.join(outdir, 'scene.glb')
|
211 |
+
print('(exporting 3D scene to', outfile, ')')
|
212 |
+
scene.export(file_obj=outfile)
|
213 |
+
return outfile
|
214 |
+
|
215 |
+
|
216 |
+
def get_3D_model_from_scene(outdir, scene, sam2_masks, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
|
217 |
+
clean_depth=False, transparent_cams=False, cam_size=0.05):
|
218 |
+
"""
|
219 |
+
extract 3D_model (glb file) from a reconstructed scene
|
220 |
+
"""
|
221 |
+
if scene is None:
|
222 |
+
return None
|
223 |
+
# post processes
|
224 |
+
if clean_depth:
|
225 |
+
scene = scene.clean_pointcloud()
|
226 |
+
if mask_sky:
|
227 |
+
scene = scene.mask_sky()
|
228 |
+
|
229 |
+
# get optimized values from scene
|
230 |
+
rgbimg = scene.imgs
|
231 |
+
|
232 |
+
focals = scene.get_focals().cpu()
|
233 |
+
cams2world = scene.get_im_poses().cpu()
|
234 |
+
# 3D pointcloud from depthmap, poses and intrinsics
|
235 |
+
pts3d = to_numpy(scene.get_pts3d())
|
236 |
+
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
|
237 |
+
msk = to_numpy(scene.get_masks())
|
238 |
+
|
239 |
+
assert len(msk) == len(sam2_masks)
|
240 |
+
# 将sam2输出的mask 和 dust3r输出的置信度阈值筛选后的msk取交集
|
241 |
+
for i in range(len(sam2_masks)):
|
242 |
+
msk[i] = np.logical_and(msk[i], sam2_masks[i])
|
243 |
+
|
244 |
+
return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
245 |
+
transparent_cams=transparent_cams, cam_size=cam_size) # 置信度和SAM2 mask的交集
|
246 |
+
|
247 |
+
# 将视频分割成固定帧数
|
248 |
+
def video_to_frames_fix(video_path, output_folder, frame_interval=10, target_fps=6):
|
249 |
+
"""
|
250 |
+
将视频转换为图像帧,并保存为 JPEG 文件。
|
251 |
+
frame_interval:保存��的步长
|
252 |
+
target_fps: 目标帧率(每秒保存的帧数)
|
253 |
+
"""
|
254 |
+
|
255 |
+
# 确保输出文件夹存在
|
256 |
+
if not os.path.exists(output_folder):
|
257 |
+
os.makedirs(output_folder)
|
258 |
+
# 打开视频文件
|
259 |
+
cap = cv2.VideoCapture(video_path)
|
260 |
+
# 获取视频总帧数
|
261 |
+
frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
262 |
+
|
263 |
+
# 计算动态帧间隔
|
264 |
+
frame_interval = math.ceil(frames_num / target_fps)
|
265 |
+
print(f"总帧数: {frames_num} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")
|
266 |
+
frame_count = 0
|
267 |
+
saved_frame_count = 0
|
268 |
+
success, frame = cap.read()
|
269 |
+
|
270 |
+
file_list = []
|
271 |
+
# 逐帧读取视频
|
272 |
+
while success:
|
273 |
+
if frame_count % frame_interval == 0:
|
274 |
+
# 每隔 frame_interval 帧保存一次
|
275 |
+
frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
|
276 |
+
cv2.imwrite(frame_filename, frame)
|
277 |
+
file_list.append(frame_filename)
|
278 |
+
saved_frame_count += 1
|
279 |
+
frame_count += 1
|
280 |
+
success, frame = cap.read()
|
281 |
+
|
282 |
+
# 释放视频捕获对象
|
283 |
+
cap.release()
|
284 |
+
print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
|
285 |
+
return file_list
|
286 |
+
|
287 |
+
|
288 |
+
def video_to_frames(video_path, output_folder, frame_interval=10, target_fps = 2):
|
289 |
+
"""
|
290 |
+
将视频转换为图像帧,并保存为 JPEG 文件。
|
291 |
+
frame_interval:保存帧的步长
|
292 |
+
target_fps: 目标帧率(每秒保存的帧数)
|
293 |
+
"""
|
294 |
+
|
295 |
+
# 确保输出文件夹存在
|
296 |
+
if not os.path.exists(output_folder):
|
297 |
+
os.makedirs(output_folder)
|
298 |
+
# 打开视频文件
|
299 |
+
cap = cv2.VideoCapture(video_path)
|
300 |
+
# 获取视频的实际帧率
|
301 |
+
actual_fps = cap.get(cv2.CAP_PROP_FPS)
|
302 |
+
|
303 |
+
# 获取视频总帧数
|
304 |
+
frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
305 |
+
|
306 |
+
# 计算动态帧间隔
|
307 |
+
# frame_interval = math.ceil(actual_fps / target_fps)
|
308 |
+
print(f"实际帧率: {actual_fps} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")
|
309 |
+
frame_count = 0
|
310 |
+
saved_frame_count = 0
|
311 |
+
success, frame = cap.read()
|
312 |
+
|
313 |
+
file_list = []
|
314 |
+
# 逐帧读取视频
|
315 |
+
while success:
|
316 |
+
if frame_count % frame_interval == 0:
|
317 |
+
# 每隔 frame_interval 帧保存一次
|
318 |
+
frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
|
319 |
+
cv2.imwrite(frame_filename, frame)
|
320 |
+
file_list.append(frame_filename)
|
321 |
+
saved_frame_count += 1
|
322 |
+
frame_count += 1
|
323 |
+
success, frame = cap.read()
|
324 |
+
|
325 |
+
# 释放视频捕获对象
|
326 |
+
cap.release()
|
327 |
+
print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
|
328 |
+
return file_list
|
329 |
+
|
330 |
+
def overlay_mask_on_image(image, mask, color=[0, 1, 0], alpha=0.5):
|
331 |
+
"""
|
332 |
+
将mask融合在image上显示。
|
333 |
+
返回融合后的图片 (H, W, 3)
|
334 |
+
"""
|
335 |
+
|
336 |
+
# 创建一个与image相同尺寸的全黑图像
|
337 |
+
mask_colored = np.zeros_like(image)
|
338 |
+
|
339 |
+
# 将mask为True的位置赋值为指定颜色
|
340 |
+
mask_colored[mask] = color
|
341 |
+
|
342 |
+
# 将彩色掩码与原图像叠加
|
343 |
+
overlay = cv2.addWeighted(image, 1 - alpha, mask_colored, alpha, 0)
|
344 |
+
|
345 |
+
return overlay
|
346 |
+
def get_reconstructed_video(sam2, outdir, model, device, image_size, image_mask, video_dir, schedule, niter, min_conf_thr,
|
347 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
348 |
+
scenegraph_type, winsize, refid, input_text):
|
349 |
+
target_dir = os.path.join(outdir, 'frames_video')
|
350 |
+
file_list = video_to_frames_fix(video_dir, target_dir)
|
351 |
+
scene, outfile, imgs = get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, file_list, schedule, niter, min_conf_thr,
|
352 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
353 |
+
scenegraph_type, winsize, refid, target_dir, input_text)
|
354 |
+
return scene, outfile, imgs
|
355 |
+
|
356 |
+
def get_reconstructed_image(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
|
357 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
358 |
+
scenegraph_type, winsize, refid, input_text):
|
359 |
+
target_folder = handle_uploaded_files(filelist, os.path.join(outdir, 'uploaded_images'))
|
360 |
+
scene, outfile, imgs = get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
|
361 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
362 |
+
scenegraph_type, winsize, refid, target_folder, input_text)
|
363 |
+
return scene, outfile, imgs
|
364 |
+
def get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
|
365 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
366 |
+
scenegraph_type, winsize, refid, images_folder, input_text=None):
|
367 |
+
"""
|
368 |
+
from a list of images, run dust3rWithSam2 inference, global aligner.
|
369 |
+
then run get_3D_model_from_scene
|
370 |
+
"""
|
371 |
+
imgs = load_images(filelist, size=image_size)
|
372 |
+
img_size = imgs[0]["true_shape"]
|
373 |
+
for img in imgs[1:]:
|
374 |
+
if not np.equal(img["true_shape"], img_size).all():
|
375 |
+
raise gradio.Error("Please ensure that the images you enter are of the same size")
|
376 |
+
|
377 |
+
if len(imgs) == 1:
|
378 |
+
imgs = [imgs[0], copy.deepcopy(imgs[0])]
|
379 |
+
imgs[1]['idx'] = 1
|
380 |
+
if scenegraph_type == "swin":
|
381 |
+
scenegraph_type = scenegraph_type + "-" + str(winsize)
|
382 |
+
elif scenegraph_type == "oneref":
|
383 |
+
scenegraph_type = scenegraph_type + "-" + str(refid)
|
384 |
+
|
385 |
+
|
386 |
+
|
387 |
+
|
388 |
+
|
389 |
+
pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
|
390 |
+
output = inference(pairs, model, device, batch_size=batch_size)
|
391 |
+
|
392 |
+
mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
|
393 |
+
scene = global_aligner(output, device=device, mode=mode)
|
394 |
+
lr = 0.01
|
395 |
+
|
396 |
+
if mode == GlobalAlignerMode.PointCloudOptimizer:
|
397 |
+
loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
|
398 |
+
|
399 |
+
|
400 |
+
|
401 |
+
# also return rgb, depth and confidence imgs
|
402 |
+
# depth is normalized with the max value for all images
|
403 |
+
# we apply the jet colormap on the confidence maps
|
404 |
+
rgbimg = scene.imgs
|
405 |
+
depths = to_numpy(scene.get_depthmaps())
|
406 |
+
confs = to_numpy([c for c in scene.im_conf])
|
407 |
+
cmap = plt.get_cmap('jet')
|
408 |
+
depths_max = max([d.max() for d in depths])
|
409 |
+
depths = [d / depths_max for d in depths]
|
410 |
+
confs_max = max([d.max() for d in confs])
|
411 |
+
confs = [cmap(d / confs_max) for d in confs]
|
412 |
+
|
413 |
+
# TODO 调用SAM2获取masks
|
414 |
+
h, w = rgbimg[0].shape[:-1]
|
415 |
+
masks = None
|
416 |
+
if not input_text or input_text.isspace(): # input_text 为空串
|
417 |
+
masks = get_masks_from_sam2(h, w, sam2, images_folder)
|
418 |
+
else:
|
419 |
+
masks = get_masks_from_grounded_sam2(h, w, sam2, images_folder, input_text) # gd-sam2
|
420 |
+
|
421 |
+
|
422 |
+
imgs = []
|
423 |
+
for i in range(len(rgbimg)):
|
424 |
+
imgs.append(rgbimg[i])
|
425 |
+
imgs.append(rgb(depths[i]))
|
426 |
+
imgs.append(rgb(confs[i]))
|
427 |
+
imgs.append(overlay_mask_on_image(rgbimg[i], masks[i])) # mask融合原图,展示SAM2的分割效果
|
428 |
+
|
429 |
+
# TODO 基于SAM2的mask过滤DUST3R的3D重建模型
|
430 |
+
outfile = get_3D_model_from_scene(outdir, scene, masks, min_conf_thr, as_pointcloud, mask_sky,
|
431 |
+
clean_depth, transparent_cams, cam_size)
|
432 |
+
return scene, outfile, imgs
|
433 |
+
|
434 |
+
|
435 |
+
def resize_mask_to_img(masks, target_width, target_height):
|
436 |
+
frame_mask = []
|
437 |
+
origin_size = masks[0][1].shape # 1表示object id
|
438 |
+
for frame, objects_mask in masks.items(): # 每个frame和该frame对应的分割结果
|
439 |
+
# 每个frame可能包含多个object对应的mask
|
440 |
+
masks = list(objects_mask.values())
|
441 |
+
if not masks: # masks为空,即当前frame不包含object
|
442 |
+
frame_mask.append(np.ones(origin_size, dtype=bool))
|
443 |
+
else: # 将当前frame包含的所有object的mask取并集
|
444 |
+
union_mask = masks[0]
|
445 |
+
for mask in masks[1:]:
|
446 |
+
union_mask = np.logical_or(union_mask, mask)
|
447 |
+
frame_mask.append(union_mask)
|
448 |
+
resized_mask = []
|
449 |
+
for mask in frame_mask:
|
450 |
+
mask_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255)
|
451 |
+
resized_mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
|
452 |
+
resized_mask.append(np.array(resized_mask_image) > 0)
|
453 |
+
|
454 |
+
return resized_mask
|
455 |
+
|
456 |
+
def get_masks_from_sam2(h, w, predictor, video_dir):
|
457 |
+
|
458 |
+
inference_state = predictor.init_state(video_path=video_dir)
|
459 |
+
predictor.reset_state(inference_state)
|
460 |
+
|
461 |
+
|
462 |
+
# 给一个帧添加points
|
463 |
+
points = np.array([[360, 550], [340, 400]], dtype=np.float32)
|
464 |
+
labels = np.array([1, 1], dtype=np.int32)
|
465 |
+
|
466 |
+
|
467 |
+
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
|
468 |
+
inference_state=inference_state,
|
469 |
+
frame_idx=0,
|
470 |
+
obj_id=1,
|
471 |
+
points=points,
|
472 |
+
labels=labels,
|
473 |
+
)
|
474 |
+
|
475 |
+
# sam2获取所有帧的分割结果
|
476 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
477 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
478 |
+
video_segments[out_frame_idx] = {
|
479 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
480 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
481 |
+
}
|
482 |
+
|
483 |
+
resize_mask = resize_mask_to_img(video_segments, w, h)
|
484 |
+
return resize_mask
|
485 |
+
|
486 |
+
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
|
487 |
+
num_files = len(inputfiles) if inputfiles is not None else 1
|
488 |
+
max_winsize = max(1, (num_files - 1) // 2)
|
489 |
+
if scenegraph_type == "swin":
|
490 |
+
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
491 |
+
minimum=1, maximum=max_winsize, step=1, visible=True)
|
492 |
+
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
493 |
+
maximum=num_files - 1, step=1, visible=False)
|
494 |
+
elif scenegraph_type == "oneref":
|
495 |
+
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
496 |
+
minimum=1, maximum=max_winsize, step=1, visible=False)
|
497 |
+
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
498 |
+
maximum=num_files - 1, step=1, visible=True)
|
499 |
+
else:
|
500 |
+
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
501 |
+
minimum=1, maximum=max_winsize, step=1, visible=False)
|
502 |
+
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
503 |
+
maximum=num_files - 1, step=1, visible=False)
|
504 |
+
return winsize, refid
|
505 |
+
|
506 |
+
def process_images(imagesList):
|
507 |
+
return None
|
508 |
+
|
509 |
+
def process_videos(video):
|
510 |
+
return None
|
511 |
+
|
512 |
+
def upload_images_listener(image_size, file_list):
|
513 |
+
if len(file_list) == 1:
|
514 |
+
raise gradio.Error("Please enter images from at least two different views.")
|
515 |
+
print("Uploading image[0] to ImageMask:")
|
516 |
+
img_0 = load_images([file_list[0]], image_size)
|
517 |
+
i1 = img_0[0]['img'].squeeze(0)
|
518 |
+
rgb_img = rgb(i1)
|
519 |
+
return rgb_img
|
520 |
+
def upload_video_listener(image_size, video_dir):
|
521 |
+
cap = cv2.VideoCapture(video_dir)
|
522 |
+
success, frame = cap.read() # 第一帧
|
523 |
+
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
524 |
+
Image_frame = Image.fromarray(rgb_frame)
|
525 |
+
resized_frame = resize_images([Image_frame], image_size)
|
526 |
+
i1 = resized_frame[0]['img'].squeeze(0)
|
527 |
+
rgb_img = rgb(i1)
|
528 |
+
return rgb_img
|
529 |
+
|
530 |
+
def main_demo(sam2, tmpdirname, model, device, image_size, server_name, server_port):
|
531 |
+
|
532 |
+
# functools.partial解析:https://blog.csdn.net/wuShiJingZuo/article/details/135018810
|
533 |
+
recon_fun_image_demo = functools.partial(get_reconstructed_image,sam2, tmpdirname, model, device,
|
534 |
+
image_size)
|
535 |
+
|
536 |
+
recon_fun_video_demo = functools.partial(get_reconstructed_video, sam2, tmpdirname, model, device,
|
537 |
+
image_size)
|
538 |
+
|
539 |
+
upload_files_fun = functools.partial(upload_images_listener,image_size)
|
540 |
+
upload_video_fun = functools.partial(upload_video_listener, image_size)
|
541 |
+
with gradio.Blocks() as demo1:
|
542 |
+
scene = gradio.State(None)
|
543 |
+
gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
|
544 |
+
gradio.HTML("""<h2 style="text-align: center;">
|
545 |
+
<a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
|
546 |
+
<a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
|
547 |
+
</h2>""")
|
548 |
+
gradio.HTML("""
|
549 |
+
<div style="text-align: center;">
|
550 |
+
<h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
|
551 |
+
</div>
|
552 |
+
""")
|
553 |
+
gradio.set_static_paths(paths=["static/images/"])
|
554 |
+
project_path = "static/images/project.gif"
|
555 |
+
gradio.HTML(f"""
|
556 |
+
<div align='center' >
|
557 |
+
<img src="/file={project_path}" width='720px'>
|
558 |
+
</div>
|
559 |
+
""")
|
560 |
+
gradio.HTML("<p> \
|
561 |
+
<strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
|
562 |
+
Choose an example below 🔥 🔥 🔥 <br>\
|
563 |
+
Or, upload by yourself: <br>\
|
564 |
+
1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
|
565 |
+
<br> \
|
566 |
+
💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
|
567 |
+
💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
|
568 |
+
💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
|
569 |
+
Note: we only take the first 16 frames for the demo. \
|
570 |
+
</p>")
|
571 |
+
|
572 |
+
with gradio.Column():
|
573 |
+
with gradio.Row():
|
574 |
+
inputfiles = gradio.File(file_count="multiple")
|
575 |
+
with gradio.Column():
|
576 |
+
image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
|
577 |
+
label="prompt (提示图)", transforms=(), width=600, height=450)
|
578 |
+
input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
|
579 |
+
with gradio.Row():
|
580 |
+
schedule = gradio.Dropdown(["linear", "cosine"],
|
581 |
+
value='linear', label="schedule", info="For global alignment!")
|
582 |
+
niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
|
583 |
+
label="num_iterations", info="For global alignment!")
|
584 |
+
scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
|
585 |
+
value='complete', label="Scenegraph",
|
586 |
+
info="Define how to make pairs",
|
587 |
+
interactive=True)
|
588 |
+
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
|
589 |
+
minimum=1, maximum=1, step=1, visible=False)
|
590 |
+
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
|
591 |
+
|
592 |
+
run_btn = gradio.Button("Run")
|
593 |
+
|
594 |
+
with gradio.Row():
|
595 |
+
# adjust the confidence threshold
|
596 |
+
min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
|
597 |
+
# adjust the camera size in the output pointcloud
|
598 |
+
cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
|
599 |
+
with gradio.Row():
|
600 |
+
as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
|
601 |
+
# two post process implemented
|
602 |
+
mask_sky = gradio.Checkbox(value=False, label="Mask sky")
|
603 |
+
clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
|
604 |
+
transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
|
605 |
+
|
606 |
+
outmodel = gradio.Model3D()
|
607 |
+
outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")
|
608 |
+
|
609 |
+
inputfiles.upload(upload_files_fun, inputs=inputfiles, outputs=image_mask)
|
610 |
+
|
611 |
+
run_btn.click(fn=recon_fun_image_demo, # 调用get_reconstructed_image即DUST3R模型
|
612 |
+
inputs=[image_mask, inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
|
613 |
+
mask_sky, clean_depth, transparent_cams, cam_size,
|
614 |
+
scenegraph_type, winsize, refid, input_text],
|
615 |
+
outputs=[scene, outmodel, outgallery])
|
616 |
+
|
617 |
+
|
618 |
+
|
619 |
+
|
620 |
+
# ## **************************** video *******************************************************
|
621 |
+
with gradio.Blocks() as demo2:
|
622 |
+
gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
|
623 |
+
gradio.HTML("""<h2 style="text-align: center;">
|
624 |
+
<a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
|
625 |
+
<a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
|
626 |
+
</h2>""")
|
627 |
+
gradio.HTML("""
|
628 |
+
<div style="text-align: center;">
|
629 |
+
<h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
|
630 |
+
</div>
|
631 |
+
""")
|
632 |
+
gradio.set_static_paths(paths=["static/images/"])
|
633 |
+
project_path = "static/images/project.gif"
|
634 |
+
gradio.HTML(f"""
|
635 |
+
<div align='center' >
|
636 |
+
<img src="/file={project_path}" width='720px'>
|
637 |
+
</div>
|
638 |
+
""")
|
639 |
+
gradio.HTML("<p> \
|
640 |
+
<strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
|
641 |
+
Choose an example below 🔥 🔥 🔥 <br>\
|
642 |
+
Or, upload by yourself: <br>\
|
643 |
+
1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
|
644 |
+
<br> \
|
645 |
+
💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
|
646 |
+
💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
|
647 |
+
💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
|
648 |
+
Note: we only take the first 16 frames for the demo. \
|
649 |
+
</p>")
|
650 |
+
|
651 |
+
with gradio.Column():
|
652 |
+
with gradio.Row():
|
653 |
+
input_video = gradio.Video(width=600, height=600)
|
654 |
+
with gradio.Column():
|
655 |
+
image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
|
656 |
+
label="prompt (提示图)", transforms=(), width=600, height=450)
|
657 |
+
input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
|
658 |
+
|
659 |
+
|
660 |
+
with gradio.Row():
|
661 |
+
schedule = gradio.Dropdown(["linear", "cosine"],
|
662 |
+
value='linear', label="schedule", info="For global alignment!")
|
663 |
+
niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
|
664 |
+
label="num_iterations", info="For global alignment!")
|
665 |
+
scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
|
666 |
+
value='complete', label="Scenegraph",
|
667 |
+
info="Define how to make pairs",
|
668 |
+
interactive=True)
|
669 |
+
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
|
670 |
+
minimum=1, maximum=1, step=1, visible=False)
|
671 |
+
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
|
672 |
+
|
673 |
+
run_btn = gradio.Button("Run")
|
674 |
+
|
675 |
+
with gradio.Row():
|
676 |
+
# adjust the confidence threshold
|
677 |
+
min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
|
678 |
+
# adjust the camera size in the output pointcloud
|
679 |
+
cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
|
680 |
+
with gradio.Row():
|
681 |
+
as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
|
682 |
+
# two post process implemented
|
683 |
+
mask_sky = gradio.Checkbox(value=False, label="Mask sky")
|
684 |
+
clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
|
685 |
+
transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
|
686 |
+
|
687 |
+
outmodel = gradio.Model3D()
|
688 |
+
outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")
|
689 |
+
|
690 |
+
input_video.upload(upload_video_fun, inputs=input_video, outputs=image_mask)
|
691 |
+
|
692 |
+
run_btn.click(fn=recon_fun_video_demo, # 调用get_reconstructed_scene即DUST3R模型
|
693 |
+
inputs=[image_mask, input_video, schedule, niter, min_conf_thr, as_pointcloud,
|
694 |
+
mask_sky, clean_depth, transparent_cams, cam_size,
|
695 |
+
scenegraph_type, winsize, refid, input_text],
|
696 |
+
outputs=[scene, outmodel, outgallery])
|
697 |
+
|
698 |
+
app = gradio.TabbedInterface([demo1, demo2], ["3d rebuilding by images", "3d rebuilding by video"])
|
699 |
+
app.launch(share=False, server_name=server_name, server_port=server_port)
|
700 |
+
|
701 |
+
|
702 |
+
# TODO 修改bug:
|
703 |
+
#在项目的一次启动中,上传的多组图片在点击run后,会保存在同一个临时文件夹中,
|
704 |
+
# 这样后面再上传其他场景的图片时,不同场景下的图片会存在于一个文件夹中,
|
705 |
+
# 不同场景的图片导致分割与重建错误
|
706 |
+
|
707 |
+
## 目前构思的解决:在文件夹下再基于创建一个文件夹存放不同场景的图片,可以基于时间命名该文件夹
|
708 |
+
|
709 |
+
|
710 |
+
if __name__ == '__main__':
|
711 |
+
parser = get_args_parser()
|
712 |
+
args = parser.parse_args()
|
713 |
+
|
714 |
+
if args.tmp_dir is not None:
|
715 |
+
tmp_path = args.tmp_dir
|
716 |
+
os.makedirs(tmp_path, exist_ok=True)
|
717 |
+
tempfile.tempdir = tmp_path
|
718 |
+
|
719 |
+
if args.server_name is not None:
|
720 |
+
server_name = args.server_name
|
721 |
+
else:
|
722 |
+
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
723 |
+
|
724 |
+
# DUST3R
|
725 |
+
model = load_model(args.weights, args.device)
|
726 |
+
# SAM2
|
727 |
+
# 加载模型
|
728 |
+
sam2_checkpoint = "./SAM2/checkpoints/sam2_hiera_large.pt"
|
729 |
+
model_cfg = "sam2_hiera_l.yaml"
|
730 |
+
sam2 = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
731 |
+
# dust3rWithSam2 will write the 3D model inside tmpdirname
|
732 |
+
with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: # DUST3R生成的3D .glb 文件所在的文件夹名称
|
733 |
+
print('Outputing stuff in', tmpdirname)
|
734 |
+
main_demo(sam2, tmpdirname, model, args.device, args.image_size, server_name, args.server_port)
|
735 |
+
|