yansong1616 committed
Commit 777e649 (verified)
1 Parent(s): 3cee1b0

Rename app_gys.py to app.py

Files changed (1)
  1. app_gys.py → app.py +735 -735
app_gys.py → app.py RENAMED
@@ -1,735 +1,735 @@
The removed app_gys.py matches the new app.py shown below, except for the following lines (previous values):
- parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
- parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
- sam2_checkpoint = "D:\XMU\mac\hujie\\3D\DUST3RwithSAM2\dust3rWithSam2\SAM2\checkpoints\sam2_hiera_large.pt"
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import argparse
4
+ import gradio
5
+ import os
6
+ import torch
7
+ import numpy as np
8
+ import tempfile
9
+ import functools
10
+ import trimesh
11
+ import copy
12
+ from scipy.spatial.transform import Rotation
13
+
14
+ from dust3r.inference import inference, load_model
15
+ from dust3r.image_pairs import make_pairs
16
+ from dust3r.utils.image import load_images, rgb, resize_images
17
+ from dust3r.utils.device import to_numpy
18
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
19
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
20
+ from sam2.build_sam import build_sam2_video_predictor
21
+ import matplotlib.pyplot as plt
22
+
23
+ import shutil
24
+ import json
25
+ from PIL import Image
26
+ import math
27
+ import cv2
28
+
29
+ plt.ion()
30
+
31
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
32
+ batch_size = 1
33
+
34
+ ########################## import grounding_dino #############################
35
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
36
+ def get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text):
37
+ # init grounding dino model from huggingface
38
+ model_id = "IDEA-Research/grounding-dino-tiny"
39
+ device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ processor = AutoProcessor.from_pretrained(model_id)
41
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
42
+
43
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
44
+ # VERY important: text queries need to be lowercased + end with a dot
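+ # e.g. "a sitting dog." is a valid query, while "A sitting Dog" is not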
45
+
46
+
47
+ """
48
+ Step 2: Prompt Grounding DINO and the SAM image predictor to get the box and mask for a specific frame
49
+ """
50
+ # prompt grounding dino to get the box coordinates on a specific frame
51
+ frame_names = [
52
+ p for p in os.listdir(video_dir)
53
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
54
+ ]
55
+ # frame_names.sort(key=lambda p: os.path.splitext(p)[0])
56
+ img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
57
+ image = Image.open(img_path)
58
+
59
+
60
+ # run Grounding DINO on the image
61
+ inputs = processor(images=image, text=input_text, return_tensors="pt").to(device)
62
+ with torch.no_grad():
63
+ outputs = grounding_model(**inputs)
64
+
65
+ results = processor.post_process_grounded_object_detection(
66
+ outputs,
67
+ inputs.input_ids,
68
+ box_threshold=0.25,
69
+ text_threshold=0.3,
70
+ target_sizes=[image.size[::-1]]
71
+ )
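+ # results[0]["boxes"] holds (x1, y1, x2, y2) boxes scaled to the image size; results[0]["labels"] holds the matched text phrases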
72
+ return results[0]["boxes"], results[0]["labels"]
73
+
74
+ def get_masks_from_grounded_sam2(h, w, predictor, video_dir, input_text):
75
+
76
+ inference_state = predictor.init_state(video_path=video_dir)
77
+ predictor.reset_state(inference_state)
78
+
79
+ ann_frame_idx = 0 # the frame index we interact with
80
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
81
+ print("Running Groundding DINO......")
82
+ input_boxes, OBJECTS = get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text)
83
+ print("Groundding DINO run over!")
84
+ if(len(OBJECTS) < 1):
85
+ raise gradio.Error("The images you input do not contain the target in '{}'".format(input_text))
86
+ # feed the boxes output by grounding_dino to the first frame as prompts
87
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes)):
88
+ _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
89
+ inference_state=inference_state,
90
+ frame_idx=ann_frame_idx,
91
+ obj_id=ann_obj_id,
92
+ box=box,
93
+ )
94
+ break # only use the first box
95
+
96
+
97
+ # sam2 produces segmentation results for every frame
98
+ video_segments = {} # video_segments contains the per-frame segmentation results
99
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
100
+ video_segments[out_frame_idx] = {
101
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
102
+ for i, out_obj_id in enumerate(out_obj_ids)
103
+ }
104
+
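+ # SAM2 masks are predicted at the original frame resolution; resize them to DUST3R's working resolution (w, h)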
105
+ resize_mask = resize_mask_to_img(video_segments, w, h)
106
+ return resize_mask
107
+
108
+
109
+ def handle_uploaded_files(uploaded_files, target_folder):
110
+ # create the target folder
111
+ if not os.path.exists(target_folder):
112
+ os.makedirs(target_folder)
113
+
114
+ # iterate over the uploaded files and copy them to the target folder
115
+ for file in uploaded_files:
116
+ file_path = file.name # temporary path of the uploaded file
117
+ file_name = os.path.basename(file_path) # file name
118
+ target_path = os.path.join(target_folder, file_name)
119
+ shutil.copy2(file_path, target_path)
120
+ print("copy images from {} to {}".format(file_path, target_path))
121
+
122
+ return target_folder
123
+ def show_mask(mask, ax, random_color=False):
124
+ if random_color:
125
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
126
+ else:
127
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
128
+ h, w = mask.shape[-2:]
129
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
130
+ ax.imshow(mask_image)
131
+
132
+ def show_mask_sam2(mask, ax, obj_id=None, random_color=False):
133
+ if random_color:
134
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
135
+ else:
136
+ cmap = plt.get_cmap("tab10")
137
+ cmap_idx = 0 if obj_id is None else obj_id
138
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
139
+ h, w = mask.shape[-2:]
140
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
141
+ ax.imshow(mask_image)
142
+ def show_points(coords, labels, ax, marker_size=375):
143
+ pos_points = coords[labels == 1]
144
+ neg_points = coords[labels == 0]
145
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
146
+ linewidth=1.25)
147
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
148
+ linewidth=1.25)
149
+
150
+ def show_box(box, ax):
151
+ x0, y0 = box[0], box[1]
152
+ w, h = box[2] - box[0], box[3] - box[1]
153
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
154
+
155
+
156
+
157
+ def get_args_parser():
158
+ parser = argparse.ArgumentParser()
159
+ parser_url = parser.add_mutually_exclusive_group()
160
+ parser_url.add_argument("--local_network", action='store_true', default=False,
161
+ help="make app accessible on local network: address will be set to 0.0.0.0")
162
+ parser_url.add_argument("--server_name", type=str, default="0.0.0.0", help="server url, default is 127.0.0.1")
163
+ parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
164
+ parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
165
+ "If None, will search for an available port starting at 7860."),
166
+ default=None)
167
+ parser.add_argument("--weights", type=str, required=True, help="path to the model weights")
168
+ parser.add_argument("--device", type=str, default='cpu', help="pytorch device")
169
+ parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
170
+ return parser
171
+
172
+
173
+ # save the rendered 3D scene to the outfile path
174
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
175
+ cam_color=None, as_pointcloud=False, transparent_cams=False):
176
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
177
+ pts3d = to_numpy(pts3d)
178
+ imgs = to_numpy(imgs)
179
+ focals = to_numpy(focals)
180
+ cams2world = to_numpy(cams2world)
181
+
182
+ scene = trimesh.Scene()
183
+
184
+ # full pointcloud
185
+ if as_pointcloud:
186
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
187
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
188
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
189
+ scene.add_geometry(pct)
190
+ else:
191
+ meshes = []
192
+ for i in range(len(imgs)):
193
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
194
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
195
+ scene.add_geometry(mesh)
196
+
197
+ # add each camera
198
+ for i, pose_c2w in enumerate(cams2world):
199
+ if isinstance(cam_color, list):
200
+ camera_edge_color = cam_color[i]
201
+ else:
202
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
203
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
204
+ None if transparent_cams else imgs[i], focals[i],
205
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
206
+
207
+ rot = np.eye(4)
208
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
209
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
210
+ outfile = os.path.join(outdir, 'scene.glb')
211
+ print('(exporting 3D scene to', outfile, ')')
212
+ scene.export(file_obj=outfile)
213
+ return outfile
214
+
215
+
216
+ def get_3D_model_from_scene(outdir, scene, sam2_masks, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
217
+ clean_depth=False, transparent_cams=False, cam_size=0.05):
218
+ """
219
+ extract 3D_model (glb file) from a reconstructed scene
220
+ """
221
+ if scene is None:
222
+ return None
223
+ # post processes
224
+ if clean_depth:
225
+ scene = scene.clean_pointcloud()
226
+ if mask_sky:
227
+ scene = scene.mask_sky()
228
+
229
+ # get optimized values from scene
230
+ rgbimg = scene.imgs
231
+
232
+ focals = scene.get_focals().cpu()
233
+ cams2world = scene.get_im_poses().cpu()
234
+ # 3D pointcloud from depthmap, poses and intrinsics
235
+ pts3d = to_numpy(scene.get_pts3d())
236
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
237
+ msk = to_numpy(scene.get_masks())
238
+
239
+ assert len(msk) == len(sam2_masks)
240
+ # intersect the SAM2 masks with DUST3R's confidence-thresholded masks
241
+ for i in range(len(sam2_masks)):
242
+ msk[i] = np.logical_and(msk[i], sam2_masks[i])
243
+
244
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
245
+ transparent_cams=transparent_cams, cam_size=cam_size) # intersection of confidence masks and SAM2 masks
246
+
247
+ # split the video into a fixed number of frames
248
+ def video_to_frames_fix(video_path, output_folder, frame_interval=10, target_fps=6):
249
+ """
250
+ Convert a video into image frames and save them as JPEG files.
251
+ frame_interval: step size between saved frames
252
+ target_fps: target frame rate (number of frames saved per second)
253
+ """
254
+
255
+ # make sure the output folder exists
256
+ if not os.path.exists(output_folder):
257
+ os.makedirs(output_folder)
258
+ # open the video file
259
+ cap = cv2.VideoCapture(video_path)
260
+ # get the total number of frames in the video
261
+ frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
262
+
263
+ # compute the dynamic frame interval
264
+ frame_interval = math.ceil(frames_num / target_fps)
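+ # note: with frame_interval = ceil(total_frames / target_fps), roughly target_fps frames are kept in total,
+ # so target_fps here acts as the desired number of saved frames rather than a per-second rate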
265
+ print(f"总帧数: {frames_num} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")
266
+ frame_count = 0
267
+ saved_frame_count = 0
268
+ success, frame = cap.read()
269
+
270
+ file_list = []
271
+ # read the video frame by frame
272
+ while success:
273
+ if frame_count % frame_interval == 0:
274
+ # save one frame every frame_interval frames
275
+ frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
276
+ cv2.imwrite(frame_filename, frame)
277
+ file_list.append(frame_filename)
278
+ saved_frame_count += 1
279
+ frame_count += 1
280
+ success, frame = cap.read()
281
+
282
+ # release the video capture object
283
+ cap.release()
284
+ print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
285
+ return file_list
286
+
287
+
288
+ def video_to_frames(video_path, output_folder, frame_interval=10, target_fps = 2):
289
+ """
290
+ Convert a video into image frames and save them as JPEG files.
291
+ frame_interval: step size between saved frames
292
+ target_fps: target frame rate (number of frames saved per second)
293
+ """
294
+
295
+ # make sure the output folder exists
296
+ if not os.path.exists(output_folder):
297
+ os.makedirs(output_folder)
298
+ # open the video file
299
+ cap = cv2.VideoCapture(video_path)
300
+ # get the actual frame rate of the video
301
+ actual_fps = cap.get(cv2.CAP_PROP_FPS)
302
+
303
+ # get the total number of frames in the video
304
+ frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
305
+
306
+ # compute the dynamic frame interval
307
+ # frame_interval = math.ceil(actual_fps / target_fps)
308
+ print(f"实际帧率: {actual_fps} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")
309
+ frame_count = 0
310
+ saved_frame_count = 0
311
+ success, frame = cap.read()
312
+
313
+ file_list = []
314
+ # read the video frame by frame
315
+ while success:
316
+ if frame_count % frame_interval == 0:
317
+ # save one frame every frame_interval frames
318
+ frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
319
+ cv2.imwrite(frame_filename, frame)
320
+ file_list.append(frame_filename)
321
+ saved_frame_count += 1
322
+ frame_count += 1
323
+ success, frame = cap.read()
324
+
325
+ # release the video capture object
326
+ cap.release()
327
+ print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
328
+ return file_list
329
+
330
+ def overlay_mask_on_image(image, mask, color=[0, 1, 0], alpha=0.5):
331
+ """
332
+ Overlay the mask on the image for display.
333
+ Returns the blended image (H, W, 3).
334
+ """
335
+
336
+ # create an all-black image with the same size as the input image
337
+ mask_colored = np.zeros_like(image)
338
+
339
+ # set the pixels where the mask is True to the given color
340
+ mask_colored[mask] = color
341
+
342
+ # blend the colored mask with the original image
343
+ overlay = cv2.addWeighted(image, 1 - alpha, mask_colored, alpha, 0)
344
+
345
+ return overlay
346
+ def get_reconstructed_video(sam2, outdir, model, device, image_size, image_mask, video_dir, schedule, niter, min_conf_thr,
347
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
348
+ scenegraph_type, winsize, refid, input_text):
349
+ target_dir = os.path.join(outdir, 'frames_video')
350
+ file_list = video_to_frames_fix(video_dir, target_dir)
351
+ scene, outfile, imgs = get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, file_list, schedule, niter, min_conf_thr,
352
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
353
+ scenegraph_type, winsize, refid, target_dir, input_text)
354
+ return scene, outfile, imgs
355
+
356
+ def get_reconstructed_image(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
357
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
358
+ scenegraph_type, winsize, refid, input_text):
359
+ target_folder = handle_uploaded_files(filelist, os.path.join(outdir, 'uploaded_images'))
360
+ scene, outfile, imgs = get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
361
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
362
+ scenegraph_type, winsize, refid, target_folder, input_text)
363
+ return scene, outfile, imgs
364
+ def get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
365
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
366
+ scenegraph_type, winsize, refid, images_folder, input_text=None):
367
+ """
368
+ from a list of images, run dust3rWithSam2 inference, global aligner.
369
+ then run get_3D_model_from_scene
370
+ """
371
+ imgs = load_images(filelist, size=image_size)
372
+ img_size = imgs[0]["true_shape"]
373
+ for img in imgs[1:]:
374
+ if not np.equal(img["true_shape"], img_size).all():
375
+ raise gradio.Error("Please ensure that the images you enter are of the same size")
376
+
377
+ if len(imgs) == 1:
378
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
379
+ imgs[1]['idx'] = 1
380
+ if scenegraph_type == "swin":
381
+ scenegraph_type = scenegraph_type + "-" + str(winsize)
382
+ elif scenegraph_type == "oneref":
383
+ scenegraph_type = scenegraph_type + "-" + str(refid)
384
+
385
+
386
+
387
+
388
+
389
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
390
+ output = inference(pairs, model, device, batch_size=batch_size)
391
+
392
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
393
+ scene = global_aligner(output, device=device, mode=mode)
394
+ lr = 0.01
395
+
396
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
397
+ loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
398
+
399
+
400
+
401
+ # also return rgb, depth and confidence imgs
402
+ # depth is normalized with the max value for all images
403
+ # we apply the jet colormap on the confidence maps
404
+ rgbimg = scene.imgs
405
+ depths = to_numpy(scene.get_depthmaps())
406
+ confs = to_numpy([c for c in scene.im_conf])
407
+ cmap = plt.get_cmap('jet')
408
+ depths_max = max([d.max() for d in depths])
409
+ depths = [d / depths_max for d in depths]
410
+ confs_max = max([d.max() for d in confs])
411
+ confs = [cmap(d / confs_max) for d in confs]
412
+
413
+ # TODO call SAM2 to get the masks
414
+ h, w = rgbimg[0].shape[:-1]
415
+ masks = None
416
+ if not input_text or input_text.isspace(): # input_text is empty
417
+ masks = get_masks_from_sam2(h, w, sam2, images_folder)
418
+ else:
419
+ masks = get_masks_from_grounded_sam2(h, w, sam2, images_folder, input_text) # gd-sam2
420
+
421
+
422
+ imgs = []
423
+ for i in range(len(rgbimg)):
424
+ imgs.append(rgbimg[i])
425
+ imgs.append(rgb(depths[i]))
426
+ imgs.append(rgb(confs[i]))
427
+ imgs.append(overlay_mask_on_image(rgbimg[i], masks[i])) # overlay the mask on the original image to show SAM2's segmentation
428
+
429
+ # TODO filter DUST3R's 3D reconstruction with the SAM2 masks
430
+ outfile = get_3D_model_from_scene(outdir, scene, masks, min_conf_thr, as_pointcloud, mask_sky,
431
+ clean_depth, transparent_cams, cam_size)
432
+ return scene, outfile, imgs
433
+
434
+
435
+ def resize_mask_to_img(masks, target_width, target_height):
436
+ frame_mask = []
437
+ origin_size = masks[0][1].shape # 1 is the object id
438
+ for frame, objects_mask in masks.items(): # each frame and its segmentation results
439
+ # a frame may contain masks for multiple objects
440
+ masks = list(objects_mask.values())
441
+ if not masks: # masks is empty, i.e. the current frame contains no object
442
+ frame_mask.append(np.ones(origin_size, dtype=bool))
443
+ else: # take the union of all object masks in the current frame
444
+ union_mask = masks[0]
445
+ for mask in masks[1:]:
446
+ union_mask = np.logical_or(union_mask, mask)
447
+ frame_mask.append(union_mask)
448
+ resized_mask = []
449
+ for mask in frame_mask:
450
+ mask_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255)
451
+ resized_mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
452
+ resized_mask.append(np.array(resized_mask_image) > 0)
453
+
454
+ return resized_mask
455
+
456
+ def get_masks_from_sam2(h, w, predictor, video_dir):
457
+
458
+ inference_state = predictor.init_state(video_path=video_dir)
459
+ predictor.reset_state(inference_state)
460
+
461
+
462
+ # add point prompts to one frame
463
+ points = np.array([[360, 550], [340, 400]], dtype=np.float32)
464
+ labels = np.array([1, 1], dtype=np.int32)
465
+
466
+
467
+ _, out_obj_ids, out_mask_logits = predictor.add_new_points(
468
+ inference_state=inference_state,
469
+ frame_idx=0,
470
+ obj_id=1,
471
+ points=points,
472
+ labels=labels,
473
+ )
474
+
475
+ # sam2 produces segmentation results for every frame
476
+ video_segments = {} # video_segments contains the per-frame segmentation results
477
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
478
+ video_segments[out_frame_idx] = {
479
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
480
+ for i, out_obj_id in enumerate(out_obj_ids)
481
+ }
482
+
483
+ resize_mask = resize_mask_to_img(video_segments, w, h)
484
+ return resize_mask
485
+
486
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
487
+ num_files = len(inputfiles) if inputfiles is not None else 1
488
+ max_winsize = max(1, (num_files - 1) // 2)
489
+ if scenegraph_type == "swin":
490
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
491
+ minimum=1, maximum=max_winsize, step=1, visible=True)
492
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
493
+ maximum=num_files - 1, step=1, visible=False)
494
+ elif scenegraph_type == "oneref":
495
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
496
+ minimum=1, maximum=max_winsize, step=1, visible=False)
497
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
498
+ maximum=num_files - 1, step=1, visible=True)
499
+ else:
500
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
501
+ minimum=1, maximum=max_winsize, step=1, visible=False)
502
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
503
+ maximum=num_files - 1, step=1, visible=False)
504
+ return winsize, refid
505
+
506
+ def process_images(imagesList):
507
+ return None
508
+
509
+ def process_videos(video):
510
+ return None
511
+
512
+ def upload_images_listener(image_size, file_list):
513
+ if len(file_list) == 1:
514
+ raise gradio.Error("Please enter images from at least two different views.")
515
+ print("Uploading image[0] to ImageMask:")
516
+ img_0 = load_images([file_list[0]], image_size)
517
+ i1 = img_0[0]['img'].squeeze(0)
518
+ rgb_img = rgb(i1)
519
+ return rgb_img
520
+ def upload_video_listener(image_size, video_dir):
521
+ cap = cv2.VideoCapture(video_dir)
522
+ success, frame = cap.read() # first frame
523
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
524
+ Image_frame = Image.fromarray(rgb_frame)
525
+ resized_frame = resize_images([Image_frame], image_size)
526
+ i1 = resized_frame[0]['img'].squeeze(0)
527
+ rgb_img = rgb(i1)
528
+ return rgb_img
529
+
530
+ def main_demo(sam2, tmpdirname, model, device, image_size, server_name, server_port):
531
+
532
+ # functools.partial explained: https://blog.csdn.net/wuShiJingZuo/article/details/135018810
533
+ recon_fun_image_demo = functools.partial(get_reconstructed_image,sam2, tmpdirname, model, device,
534
+ image_size)
535
+
536
+ recon_fun_video_demo = functools.partial(get_reconstructed_video, sam2, tmpdirname, model, device,
537
+ image_size)
538
+
539
+ upload_files_fun = functools.partial(upload_images_listener,image_size)
540
+ upload_video_fun = functools.partial(upload_video_listener, image_size)
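+ # the partials above pre-bind (sam2, tmpdirname, model, device, image_size), so the Gradio
+ # click handlers below only need to supply the remaining UI inputs (prompt image, files/video, options, ...)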
541
+ with gradio.Blocks() as demo1:
542
+ scene = gradio.State(None)
543
+ gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
544
+ gradio.HTML("""<h2 style="text-align: center;">
545
+ <a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
546
+ <a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
547
+ </h2>""")
548
+ gradio.HTML("""
549
+ <div style="text-align: center;">
550
+ <h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
551
+ </div>
552
+ """)
553
+ gradio.set_static_paths(paths=["static/images/"])
554
+ project_path = "static/images/project.gif"
555
+ gradio.HTML(f"""
556
+ <div align='center' >
557
+ <img src="/file={project_path}" width='720px'>
558
+ </div>
559
+ """)
560
+ gradio.HTML("<p> \
561
+ <strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
562
+ Choose an example below &#128293; &#128293; &#128293; <br>\
563
+ Or, upload by yourself: <br>\
564
+ 1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
565
+ <br> \
566
+ 💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
567
+ 💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
568
+ 💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
569
+ Note: we only take the first 16 frames for the demo. \
570
+ </p>")
571
+
572
+ with gradio.Column():
573
+ with gradio.Row():
574
+ inputfiles = gradio.File(file_count="multiple")
575
+ with gradio.Column():
576
+ image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
577
+ label="prompt (提示图)", transforms=(), width=600, height=450)
578
+ input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
579
+ with gradio.Row():
580
+ schedule = gradio.Dropdown(["linear", "cosine"],
581
+ value='linear', label="schedule", info="For global alignment!")
582
+ niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
583
+ label="num_iterations", info="For global alignment!")
584
+ scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
585
+ value='complete', label="Scenegraph",
586
+ info="Define how to make pairs",
587
+ interactive=True)
588
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
589
+ minimum=1, maximum=1, step=1, visible=False)
590
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
591
+
592
+ run_btn = gradio.Button("Run")
593
+
594
+ with gradio.Row():
595
+ # adjust the confidence threshold
596
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
597
+ # adjust the camera size in the output pointcloud
598
+ cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
599
+ with gradio.Row():
600
+ as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
601
+ # two post process implemented
602
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky")
603
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
604
+ transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
605
+
606
+ outmodel = gradio.Model3D()
607
+ outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")
608
+
609
+ inputfiles.upload(upload_files_fun, inputs=inputfiles, outputs=image_mask)
610
+
611
+ run_btn.click(fn=recon_fun_image_demo, # calls get_reconstructed_image, i.e. the DUST3R model
612
+ inputs=[image_mask, inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
613
+ mask_sky, clean_depth, transparent_cams, cam_size,
614
+ scenegraph_type, winsize, refid, input_text],
615
+ outputs=[scene, outmodel, outgallery])
616
+
617
+
618
+
619
+
620
+ # ## **************************** video *******************************************************
621
+ with gradio.Blocks() as demo2:
622
+ gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
623
+ gradio.HTML("""<h2 style="text-align: center;">
624
+ <a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
625
+ <a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
626
+ </h2>""")
627
+ gradio.HTML("""
628
+ <div style="text-align: center;">
629
+ <h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
630
+ </div>
631
+ """)
632
+ gradio.set_static_paths(paths=["static/images/"])
633
+ project_path = "static/images/project.gif"
634
+ gradio.HTML(f"""
635
+ <div align='center' >
636
+ <img src="/file={project_path}" width='720px'>
637
+ </div>
638
+ """)
639
+ gradio.HTML("<p> \
640
+ <strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
641
+ Choose an example below &#128293; &#128293; &#128293; <br>\
642
+ Or, upload by yourself: <br>\
643
+ 1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
644
+ <br> \
645
+ 💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
646
+ 💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
647
+ 💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
648
+ Note: we only take the first 16 frames for the demo. \
649
+ </p>")
650
+
651
+ with gradio.Column():
652
+ with gradio.Row():
653
+ input_video = gradio.Video(width=600, height=600)
654
+ with gradio.Column():
655
+ image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
656
+ label="prompt (提示图)", transforms=(), width=600, height=450)
657
+ input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
658
+
659
+
660
+ with gradio.Row():
661
+ schedule = gradio.Dropdown(["linear", "cosine"],
662
+ value='linear', label="schedule", info="For global alignment!")
663
+ niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
664
+ label="num_iterations", info="For global alignment!")
665
+ scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
666
+ value='complete', label="Scenegraph",
667
+ info="Define how to make pairs",
668
+ interactive=True)
669
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
670
+ minimum=1, maximum=1, step=1, visible=False)
671
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
672
+
673
+ run_btn = gradio.Button("Run")
674
+
675
+ with gradio.Row():
676
+ # adjust the confidence threshold
677
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
678
+ # adjust the camera size in the output pointcloud
679
+ cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
680
+ with gradio.Row():
681
+ as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
682
+ # two post process implemented
683
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky")
684
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
685
+ transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
686
+
687
+ outmodel = gradio.Model3D()
688
+ outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")
689
+
690
+ input_video.upload(upload_video_fun, inputs=input_video, outputs=image_mask)
691
+
692
+ run_btn.click(fn=recon_fun_video_demo, # calls get_reconstructed_video, i.e. the DUST3R model
693
+ inputs=[image_mask, input_video, schedule, niter, min_conf_thr, as_pointcloud,
694
+ mask_sky, clean_depth, transparent_cams, cam_size,
695
+ scenegraph_type, winsize, refid, input_text],
696
+ outputs=[scene, outmodel, outgallery])
697
+
698
+ app = gradio.TabbedInterface([demo1, demo2], ["3D reconstruction from images", "3D reconstruction from video"])
699
+ app.launch(share=False, server_name=server_name, server_port=server_port)
700
+
701
+
702
+ # TODO fix this bug:
703
+ # within a single run of the app, every batch of uploaded images is written to the same temporary folder after clicking Run,
704
+ # so when images of a different scene are uploaded later, images from different scenes end up in one folder,
705
+ # and mixing scenes like this breaks the segmentation and the reconstruction
706
+
707
+ ## proposed fix: create one sub-folder per scene inside that folder (named by timestamp, for example) to hold each scene's images; a minimal sketch follows after the listing
708
+
709
+
710
+ if __name__ == '__main__':
711
+ parser = get_args_parser()
712
+ args = parser.parse_args()
713
+
714
+ if args.tmp_dir is not None:
715
+ tmp_path = args.tmp_dir
716
+ os.makedirs(tmp_path, exist_ok=True)
717
+ tempfile.tempdir = tmp_path
718
+
719
+ if args.server_name is not None:
720
+ server_name = args.server_name
721
+ else:
722
+ server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
723
+
724
+ # DUST3R
725
+ model = load_model(args.weights, args.device)
726
+ # SAM2
727
+ # load the model
728
+ sam2_checkpoint = "./SAM2/checkpoints/sam2_hiera_large.pt"
729
+ model_cfg = "sam2_hiera_l.yaml"
730
+ sam2 = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
731
+ # dust3rWithSam2 will write the 3D model inside tmpdirname
732
+ with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: # folder that holds the 3D .glb files generated by DUST3R
733
+ print('Outputting stuff in', tmpdirname)
734
+ main_demo(sam2, tmpdirname, model, args.device, args.image_size, server_name, args.server_port)
735
+
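
A minimal sketch of the per-scene folder fix proposed in the TODO note above. The helper make_scene_folder is a hypothetical name not defined in this file; the idea is simply that each upload gets its own timestamp-named sub-folder so images from different scenes never share a directory.

import os
import time

def make_scene_folder(base_dir):
    # each call returns a fresh, timestamp-named sub-folder of base_dir
    scene_dir = os.path.join(base_dir, time.strftime("scene_%Y%m%d_%H%M%S"))
    os.makedirs(scene_dir, exist_ok=True)
    return scene_dir

# usage sketch: write each upload into its own scene folder instead of the shared
# 'uploaded_images' / 'frames_video' folders, e.g.
# target_folder = handle_uploaded_files(filelist, make_scene_folder(os.path.join(outdir, 'uploaded_images')))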