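# Source: Hugging Face Space file view -- author yansong1616, commit 1dd77e9 ("Update app.py", verified), 34.2 kB.
# (Web-view residue "raw / history / blame" kept as a comment so the file remains valid Python.)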
# -*- coding: utf-8 -*-
import argparse
import gradio
import os
import torch
import numpy as np
import tempfile
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
from dust3r.inference import inference, load_model
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images, rgb, resize_images
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from SAM2.sam2.build_sam import build_sam2_video_predictor
import matplotlib.pyplot as plt
import shutil
import json
from PIL import Image
import math
import cv2
import sys  # used by sys.path below; the import block above does not import it

plt.ion()  # matplotlib interactive mode
# Add the SAM2 package directory to the module search path.
# NOTE(review): this runs after `from SAM2.sam2.build_sam import ...` above, so it can only
# affect imports performed later (e.g. lazily inside SAM2) -- confirm it is still needed.
sys.path.append(os.path.join(os.path.dirname(__file__), 'SAM2'))
torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
# Batch size for DUSt3R pairwise inference (used in get_reconstructed_scene).
batch_size = 1
########################## 引入grounding_dino #############################
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
@functools.lru_cache(maxsize=2)
def _load_grounding_dino(model_id, device):
    """Load the Grounding DINO processor and model once per (model_id, device) pair."""
    processor = AutoProcessor.from_pretrained(model_id)
    grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
    return processor, grounding_model


def _normalize_grounding_prompt(text):
    """Return ``text`` stripped, lower-cased and terminated with a dot (Grounding DINO's query format)."""
    text = text.strip().lower()
    return text if text.endswith(".") else text + "."


def get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text):
    """Detect the object described by ``input_text`` on one frame with Grounding DINO.

    Args:
        video_dir: folder of frames; only .jpg/.jpeg files are considered.
        ann_frame_idx: index of the frame to run detection on, counted in the
            name-sorted list of those files.
        ann_obj_id: unused; kept for interface compatibility.
        input_text: free-form text prompt. It is lower-cased and given a trailing dot,
            because Grounding DINO text queries need that form.

    Returns:
        ``(boxes, labels)`` of the first post-processed result. The boxes are presumably
        in pixel coordinates of the frame, since ``target_sizes`` is the frame's
        (height, width).

    Raises:
        gradio.Error: if ``video_dir`` contains no JPEG frames.
    """
    frame_names = sorted(
        (p for p in os.listdir(video_dir)
         if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]),
        # os.listdir order is arbitrary; sort so ``ann_frame_idx`` picks a deterministic frame.
        # NOTE(review): this should match the order SAM2 uses for its frame 0 -- confirm.
        key=lambda p: os.path.splitext(p)[0],
    )
    if not frame_names:
        raise gradio.Error("No JPEG frames found in '{}'".format(video_dir))

    # Close the file handle and make sure the processor gets a 3-channel image.
    with Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])) as img:
        image = img.convert("RGB")

    # Init the grounding dino model from huggingface (cached across calls).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor, grounding_model = _load_grounding_dino("IDEA-Research/grounding-dino-tiny", device)

    inputs = processor(images=image, text=_normalize_grounding_prompt(input_text),
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.25,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]],
    )
    return results[0]["boxes"], results[0]["labels"]
def get_masks_from_grounded_sam2(h, w, predictor, video_dir, input_text):
    """Segment the target named by ``input_text`` in every frame of ``video_dir``.

    Grounding DINO detects the target on frame 0. The first detected box prompts the SAM2
    video predictor, which then propagates the mask through all frames.

    Args:
        h, w: height/width the returned masks are resized to.
        predictor: SAM2 video predictor (init_state / reset_state /
            add_new_points_or_box / propagate_in_video).
        video_dir: folder of frames read by both SAM2 and Grounding DINO.
        input_text: text prompt describing the target.

    Returns:
        list of boolean masks of shape (h, w), one per propagated frame.

    Raises:
        gradio.Error: if Grounding DINO finds nothing matching ``input_text``.
    """
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)
    ann_frame_idx = 0  # the frame index we interact with
    ann_obj_id = 1  # SAM2 object id of the single tracked target

    print("Running Groundding DINO......")
    input_boxes, labels = get_mask_from_grounding_dino(video_dir, ann_frame_idx, ann_obj_id, input_text)
    print("Groundding DINO run over!")
    if len(labels) < 1:
        raise gradio.Error("The images you input do not contain the target in '{}'".format(input_text))

    # Only the first box detected on frame 0 is used as the SAM2 prompt.
    predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        box=input_boxes[0],
    )

    # Collect SAM2's per-frame results: {frame_idx: {obj_id: bool mask}}.
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return resize_mask_to_img(video_segments, w, h)
def handle_uploaded_files(uploaded_files, target_folder):
    """Copy uploaded files into ``target_folder`` and return that folder.

    Args:
        uploaded_files: iterable of uploads. Each item is a path (``str``/``os.PathLike``)
            or an object exposing the temporary path as ``.name`` (Gradio file uploads).
        target_folder: destination directory, created with parents if missing. Files
            already in it are left untouched.

    Returns:
        ``target_folder``.
    """
    os.makedirs(target_folder, exist_ok=True)
    for file in uploaded_files:
        # Paths are used as-is (``Path.name`` would be only the basename); other upload
        # objects carry their temporary path in ``.name``.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name
        target_path = os.path.join(target_folder, os.path.basename(file_path))
        shutil.copy2(file_path, target_path)
        print("copy images from {} to {}".format(file_path, target_path))
    return target_folder
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on the axes ``ax`` as a translucent RGBA overlay.

    The colour is a fixed blue unless ``random_color`` is set, in which case an RGB colour
    is drawn from ``np.random``. Alpha is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_mask_sam2(mask, ax, obj_id=None, random_color=False):
    """Overlay ``mask`` on ``ax`` with a per-object colour and alpha 0.6.

    With ``random_color`` the RGB part is drawn from ``np.random``. Otherwise it is entry
    ``obj_id`` (0 when ``obj_id`` is None) of matplotlib's "tab10" colormap.
    """
    if not random_color:
        palette = plt.get_cmap("tab10")
        base_rgb = palette(0 if obj_id is None else obj_id)[:3]
        rgba = np.array([*base_rgb, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    rows, cols = mask.shape[-2:]
    ax.imshow(mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1))
def show_points(coords, labels, ax, marker_size=375):
    """Scatter click prompts on ``ax``: label 1 as green stars, then label 0 as red stars.

    ``coords`` is indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y); ``labels`` is
    compared element-wise against 1 and 0 to select the points.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*', s=marker_size,
                   edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    """Draw ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle of line width 2."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
def get_args_parser():
    """Build the command-line parser for the demo.

    ``--server_name`` defaults to ``0.0.0.0`` (all interfaces). Because that default is not
    None, the ``args.server_name is None`` fallback in the ``__main__`` block is never
    taken, so ``--local_network`` does not change the bind address. argparse still rejects
    passing both flags explicitly, since they share a mutually exclusive group.
    """
    parser = argparse.ArgumentParser()
    parser_url = parser.add_mutually_exclusive_group()
    parser_url.add_argument("--local_network", action='store_true', default=False,
                            help="make app accessible on local network: address will be set to 0.0.0.0")
    parser_url.add_argument("--server_name", type=str, default="0.0.0.0",
                            help="server url, default is 0.0.0.0")
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
                                                         "If None, will search for an available port starting at 7860."),
                        default=None)
    parser.add_argument("--weights", type=str, default="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
                        required=False, help="path to the model weights")
    parser.add_argument("--device", type=str, default='cpu', help="pytorch device")
    parser.add_argument("--tmp_dir", type=str, default="./", help="value for tempfile.tempdir")
    return parser
# Save the rendered 3D scene to the returned outfile path.
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
    """Export masked geometry plus camera frustums to ``<outdir>/scene.glb``.

    Args:
        outdir: directory receiving ``scene.glb``.
        imgs, pts3d, mask: per-view colours, 3D points and boolean keep-masks.
        focals, cams2world: per-camera focal lengths and camera-to-world poses.
        cam_size: on-screen width of each camera glyph.
        cam_color: a list with one colour per camera, a single colour, or None to cycle
            through ``CAM_COLORS``.
        as_pointcloud: export a coloured point cloud instead of a mesh.
        transparent_cams: draw cameras without their image.

    Returns:
        path of the written .glb file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d, imgs, focals, cams2world = map(to_numpy, (pts3d, imgs, focals, cams2world))

    scene = trimesh.Scene()
    if as_pointcloud:
        points = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        colors = np.concatenate([c[m] for c, m in zip(imgs, mask)])
        scene.add_geometry(trimesh.PointCloud(points.reshape(-1, 3), colors=colors.reshape(-1, 3)))
    else:
        meshes = [pts3d_to_trimesh(imgs[k], pts3d[k], mask[k]) for k in range(len(imgs))]
        scene.add_geometry(trimesh.Trimesh(**cat_meshes(meshes)))

    # One camera glyph per pose.
    for k, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            edge_color = cam_color[k]
        else:
            edge_color = cam_color or CAM_COLORS[k % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, edge_color,
                      None if transparent_cams else imgs[k], focals[k],
                      imsize=imgs[k].shape[1::-1], screen_width=cam_size)

    # Re-express the scene relative to the first camera (OpenGL convention, rotated 180° about y).
    rot_y180 = np.eye(4)
    rot_y180[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot_y180))

    outfile = os.path.join(outdir, 'scene.glb')
    print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile
def get_3D_model_from_scene(outdir, scene, sam2_masks, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """Extract the SAM2-filtered 3D model (glb file) from a reconstructed scene.

    A point is kept only where DUSt3R's confidence mask (thresholded at ``min_conf_thr``)
    and the corresponding SAM2 mask are both true.

    Args:
        outdir: directory receiving ``scene.glb``.
        scene: global-aligner result, or None.
        sam2_masks: one boolean mask per scene image, same size as the image.
        min_conf_thr: confidence threshold, mapped through ``scene.conf_trf``.
        as_pointcloud, transparent_cams, cam_size: forwarded to the glb exporter.
        mask_sky, clean_depth: optional post-processing of the scene.

    Returns:
        path of the exported .glb, or None when ``scene`` is None.

    Raises:
        gradio.Error: if the number of SAM2 masks differs from the number of scene masks.
    """
    if scene is None:
        return None

    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()
    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    conf_masks = to_numpy(scene.get_masks())

    # Explicit check instead of ``assert`` so it also fires under ``python -O`` and the user
    # sees why the export failed.
    if len(conf_masks) != len(sam2_masks):
        raise gradio.Error(
            "Got {} SAM2 masks for {} reconstructed images; the segmented frames do not "
            "match the reconstructed images".format(len(sam2_masks), len(conf_masks)))

    # Intersect DUSt3R's confidence masks with SAM2's segmentation masks.
    msk = [np.logical_and(conf, seg) for conf, seg in zip(conf_masks, sam2_masks)]
    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size)
# Split a video into (roughly) a fixed number of frames.
def video_to_frames_fix(video_path, output_folder, frame_interval=10, target_fps=6):
    """Sample about ``target_fps`` evenly spaced frames from a video and save them as JPEGs.

    Despite its name, ``target_fps`` is the approximate total number of frames to keep:
    the sampling step is ``ceil(total_frames / target_fps)``. When ``target_fps`` is falsy,
    or the container reports no usable frame count, ``frame_interval`` is used as the step.

    Args:
        video_path: video file readable by OpenCV.
        output_folder: destination directory (created if missing); frames are written as
            ``frame_XXXX.jpg`` in sampling order.
        frame_interval: fallback sampling step (see above).
        target_fps: approximate number of frames to keep.

    Returns:
        list of saved frame paths, in video order.
    """
    os.makedirs(output_folder, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    file_list = []
    try:
        frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        if target_fps and frames_num > 0:
            frame_interval = math.ceil(frames_num / target_fps)
        # Never let the step drop below 1: ``frame_count % 0`` would raise ZeroDivisionError.
        frame_interval = max(1, int(frame_interval))
        print(f"总帧数: {frames_num} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")

        frame_count = 0
        success, frame = cap.read()
        while success:
            if frame_count % frame_interval == 0:
                frame_filename = os.path.join(output_folder, f"frame_{len(file_list):04d}.jpg")
                cv2.imwrite(frame_filename, frame)
                file_list.append(frame_filename)
            frame_count += 1
            success, frame = cap.read()
    finally:
        # Release the capture even if reading or writing fails.
        cap.release()
    print(f"视频处理完成,共保存了 {len(file_list)} 帧到文件夹 '{output_folder}'.")
    return file_list
def video_to_frames(video_path, output_folder, frame_interval=10, target_fps=2):
    """Save every ``frame_interval``-th frame of a video as ``frame_XXXX.jpg``.

    ``target_fps`` is accepted but unused (the fps-derived interval is disabled); the
    fixed ``frame_interval`` is always applied. Returns the saved paths in video order.
    """
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    cap = cv2.VideoCapture(video_path)
    actual_fps = cap.get(cv2.CAP_PROP_FPS)
    print(f"实际帧率: {actual_fps} FPS, 动态帧间隔: 每隔 {frame_interval} 帧保存一次.")

    file_list = []
    frame_index = 0
    while True:
        success, frame = cap.read()
        if not success:
            break
        if frame_index % frame_interval == 0:
            saved_frame_count = len(file_list)
            frame_filename = os.path.join(output_folder, f"frame_{saved_frame_count:04d}.jpg")
            cv2.imwrite(frame_filename, frame)
            file_list.append(frame_filename)
        frame_index += 1
    cap.release()

    saved_frame_count = len(file_list)
    print(f"视频处理完成,共保存了 {saved_frame_count} 帧到文件夹 '{output_folder}'.")
    return file_list
def overlay_mask_on_image(image, mask, color=(0, 1, 0), alpha=0.5):
    """Blend a solid-colour rendering of ``mask`` onto ``image``.

    Args:
        image: (H, W, 3) array. ``color`` is applied on the same value scale as the image;
            the default ``(0, 1, 0)`` is full green only for images in [0, 1] (assumed to
            be the case for the DUSt3R images passed here -- confirm for other callers).
        mask: (H, W) mask; converted to bool so non-bool masks select pixels rather than
            acting as integer indices.
        color: overlay colour (a tuple, so the default is not a shared mutable object).
        alpha: weight of the coloured mask in the blend.

    Returns:
        the blended (H, W, 3) image from ``cv2.addWeighted``.
    """
    mask = np.asarray(mask, dtype=bool)
    # Black image of the same size, coloured where the mask is set.
    mask_colored = np.zeros_like(image)
    mask_colored[mask] = color
    return cv2.addWeighted(image, 1 - alpha, mask_colored, alpha, 0)
def get_reconstructed_video(sam2, outdir, model, device, image_size, image_mask, video_dir, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid, input_text):
    """Reconstruct and segment a scene from a video file.

    Frames are sampled into ``<outdir>/frames_video`` and passed to
    ``get_reconstructed_scene``. That folder is also handed to SAM2 as ``images_folder``.
    Despite its name, ``video_dir`` is the path of the video file.

    Returns:
        ``(scene, outfile, gallery_images)`` from ``get_reconstructed_scene``.
    """
    target_dir = os.path.join(outdir, 'frames_video')
    # Drop frames from a previous video: SAM2 is given this folder path (not ``file_list``),
    # so leftover frames would make its masks disagree with the images DUSt3R reconstructs.
    shutil.rmtree(target_dir, ignore_errors=True)
    file_list = video_to_frames_fix(video_dir, target_dir)
    return get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, file_list, schedule, niter,
                                   min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                                   scenegraph_type, winsize, refid, target_dir, input_text)
def get_reconstructed_image(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid, input_text):
    """Reconstruct and segment a scene from uploaded images.

    The uploads are copied into ``<outdir>/uploaded_images``, which is handed to SAM2. DUSt3R
    itself loads ``filelist``.
    NOTE(review): DUSt3R sees the files in upload order while SAM2 reads the folder; if SAM2
    orders that folder differently (e.g. by name), masks may pair with the wrong views -- confirm.

    Returns:
        ``(scene, outfile, gallery_images)`` from ``get_reconstructed_scene``.
    """
    upload_dir = os.path.join(outdir, 'uploaded_images')
    # Start from an empty folder so images from a previously uploaded scene are not mixed
    # into this one (otherwise SAM2 segments a different image set than DUSt3R reconstructs).
    shutil.rmtree(upload_dir, ignore_errors=True)
    target_folder = handle_uploaded_files(filelist, upload_dir)
    return get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter,
                                   min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                                   scenegraph_type, winsize, refid, target_folder, input_text)
def get_reconstructed_scene(sam2, outdir, model, device, image_size, image_mask, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid, images_folder, input_text=None):
    """Run DUSt3R + global alignment on ``filelist``, segment with SAM2 and export a 3D model.

    SAM2 is prompted with fixed clicks when ``input_text`` is empty/blank, otherwise with
    the Grounding DINO box for ``input_text``. It reads its frames from ``images_folder``.
    NOTE(review): ``image_mask`` (the user's drawing) is not used.

    Returns:
        ``(scene, outfile, gallery)`` where ``gallery`` holds four images per view: RGB,
        normalised depth, jet-coloured confidence, and the SAM2 mask overlay.

    Raises:
        gradio.Error: if the input images do not all share the same size.
    """
    imgs = load_images(filelist, size=image_size)
    img_size = imgs[0]["true_shape"]
    if any(not np.equal(img["true_shape"], img_size).all() for img in imgs[1:]):
        raise gradio.Error("Please ensure that the images you enter are of the same size")
    if len(imgs) == 1:
        # DUSt3R works on pairs: duplicate the single view.
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)
    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode)
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=0.01)

    # Depth is normalised by the max over all images; confidences get the jet colormap.
    rgbimg = scene.imgs
    depths = to_numpy(scene.get_depthmaps())
    confs = to_numpy([c for c in scene.im_conf])
    cmap = plt.get_cmap('jet')
    depths_max = max([d.max() for d in depths])
    depths = [d / depths_max for d in depths]
    confs_max = max([d.max() for d in confs])
    confs = [cmap(d / confs_max) for d in confs]

    # Segment the target with SAM2 (click prompts, or Grounding DINO + SAM2 for a text prompt).
    h, w = rgbimg[0].shape[:-1]
    if not input_text or input_text.isspace():
        masks = get_masks_from_sam2(h, w, sam2, images_folder)
    else:
        masks = get_masks_from_grounded_sam2(h, w, sam2, images_folder, input_text)
    if len(masks) == 1 and len(rgbimg) == 2:
        # The single input view was duplicated above; reuse its mask for the copy.
        masks = [masks[0], masks[0]]

    gallery = []
    for i in range(len(rgbimg)):
        gallery.extend([rgbimg[i], rgb(depths[i]), rgb(confs[i]),
                        overlay_mask_on_image(rgbimg[i], masks[i])])  # shows SAM2's segmentation

    # Filter DUSt3R's 3D reconstruction with the SAM2 masks.
    outfile = get_3D_model_from_scene(outdir, scene, masks, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size)
    return scene, outfile, gallery
def resize_mask_to_img(masks, target_width, target_height):
    """Merge each frame's per-object SAM2 masks and resize them to the image size.

    Args:
        masks: dict ``{frame_idx: {obj_id: mask}}`` as built from
            ``predictor.propagate_in_video``. Each mask has a leading axis that is squeezed
            away (presumably shape (1, H, W) -- TODO confirm).
        target_width, target_height: size of the returned masks.

    Returns:
        list, in the dict's iteration order, of boolean arrays of shape
        (target_height, target_width). A frame without any object mask yields an
        all-True mask, i.e. it does not filter anything.
    """
    resized_masks = []
    for objects_mask in masks.values():
        frame_masks = list(objects_mask.values())
        if not frame_masks:
            # Identical to resizing an all-ones mask, but does not rely on frame 0 holding
            # object id 1 to learn the source size.
            resized_masks.append(np.ones((target_height, target_width), dtype=bool))
            continue
        # Union of all objects present in this frame.
        union_mask = np.logical_or.reduce(frame_masks)
        mask_image = Image.fromarray(union_mask.squeeze(0).astype(np.uint8) * 255)
        resized = mask_image.resize((target_width, target_height), Image.NEAREST)
        resized_masks.append(np.array(resized) > 0)
    return resized_masks
def get_masks_from_sam2(h, w, predictor, video_dir, points=None, labels=None):
    """Segment one object through all frames with SAM2, prompted by clicks on frame 0.

    Args:
        h, w: height/width the returned masks are resized to.
        predictor: SAM2 video predictor.
        video_dir: frame folder passed to ``predictor.init_state``.
        points: (N, 2) click coordinates on frame 0, presumably (x, y) pixels as SAM2
            expects -- TODO confirm. Defaults to the demo's fixed clicks
            ``[[360, 550], [340, 400]]``.
        labels: per-click labels (1 = positive, 0 = negative). Defaults to all positive.

    Returns:
        list of boolean (h, w) masks, one per propagated frame.
    """
    if points is None:
        points = [[360, 550], [340, 400]]
    if labels is None:
        labels = [1] * len(points)
    points = np.asarray(points, dtype=np.float32)
    labels = np.asarray(labels, dtype=np.int32)

    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)
    predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        points=points,
        labels=labels,
    )

    # Collect SAM2's per-frame results: {frame_idx: {obj_id: bool mask}}.
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return resize_mask_to_img(video_segments, w, h)
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Rebuild the window-size and reference-id sliders for the chosen scene-graph type.

    The window-size slider is shown only for "swin" and the reference-id slider only for
    "oneref". Their ranges follow the number of input files (1 when ``inputfiles`` is None).
    Returns ``(winsize_slider, refid_slider)``.
    """
    n_files = 1 if inputfiles is None else len(inputfiles)
    largest_window = max(1, (n_files - 1) // 2)
    show_window = scenegraph_type == "swin"
    show_ref = scenegraph_type == "oneref"
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=largest_window,
                            minimum=1, maximum=largest_window, step=1, visible=show_window)
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=n_files - 1, step=1, visible=show_ref)
    return winsize, refid
def process_images(imagesList):
    """Placeholder image handler: ignores ``imagesList`` and returns None."""
def process_videos(video):
    """Placeholder video handler: ignores ``video`` and returns None."""
def upload_images_listener(image_size, file_list):
    """Return the first uploaded image, as loaded by DUSt3R at ``image_size``, for the prompt editor.

    Raises:
        gradio.Error: when exactly one image is uploaded (two views are required).
    """
    if len(file_list) == 1:
        raise gradio.Error("Please enter images from at least two different views.")
    print("Uploading image[0] to ImageMask:")
    first_view = load_images([file_list[0]], image_size)[0]
    return rgb(first_view['img'].squeeze(0))
def upload_video_listener(image_size, video_dir):
    """Return the video's first frame, resized like DUSt3R inputs, for the prompt editor.

    Args:
        image_size: target size passed to ``resize_images``.
        video_dir: path of the uploaded video file.

    Raises:
        gradio.Error: if the first frame cannot be decoded.
    """
    cap = cv2.VideoCapture(video_dir)
    try:
        success, frame = cap.read()  # first frame
    finally:
        # Release the capture handle (it was previously leaked).
        cap.release()
    if not success or frame is None:
        # Fail with a readable message instead of a cv2.cvtColor error on ``None``.
        raise gradio.Error("Could not read the first frame of the uploaded video; "
                           "please check the codec (h.264 is recommended).")
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized_frame = resize_images([Image.fromarray(rgb_frame)], image_size)
    return rgb(resized_frame[0]['img'].squeeze(0))
def _render_demo_header():
    """Render the title, links, description, banner GIF and usage notes shared by both tabs."""
    gradio.HTML('<h1 style="text-align: center;">DUST3R With SAM2: Segmenting Everything In 3D</h1>')
    gradio.HTML("""<h2 style="text-align: center;">
<a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a>
<a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a>
</h2>""")
    gradio.HTML("""
<div style="text-align: center;">
<h2 style="text-align: center;">DUSt3R can unify various 3D vision tasks and set new SoTAs on monocular/multi-view depth estimation as well as relative pose estimation.</h2>
</div>
""")
    gradio.set_static_paths(paths=["static/images/"])
    project_path = "static/images/project.gif"
    gradio.HTML(f"""
<div align='center' >
<img src="/file={project_path}" width='720px'>
</div>
""")
    gradio.HTML("<p> \
<strong>DUST3R+SAM2: One touch for any segmentation in a video.</strong> <br>\
Choose an example below &#128293; &#128293; &#128293; <br>\
Or, upload by yourself: <br>\
1. Upload a video to be tested to 'video'. If failed, please check the codec, we recommend h.264 by default. <br>2. Upload a prompt image to 'prompt' and draw <strong>a point or line on the target</strong>. <br>\
<br> \
💎 SAM segments the target with any point or scribble, then SegGPT segments the whole video. <br>\
💎 Examples below were never trained and are randomly selected for testing in the wild. <br>\
💎 Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. <br> \
Note: we only take the first 16 frames for the demo. \
</p>")


def _build_reconstruction_tab(make_media_input, upload_fn, run_fn, scenegraph_fn=None):
    """Build one tab inside the current ``gradio.Blocks`` context and wire its events.

    Args:
        make_media_input: zero-argument callable creating the input component.
        upload_fn: handler for the input's upload event; its output fills the prompt ImageMask.
        run_fn: reconstruction handler bound to the Run button.
        scenegraph_fn: optional handler that refreshes the window-size / reference-id sliders
            when the scene-graph type changes (only meaningful when the input is a file list).
    """
    # Each tab owns its scene state.
    scene = gradio.State(None)
    _render_demo_header()
    with gradio.Column():
        with gradio.Row():
            media_input = make_media_input()
            with gradio.Column():
                image_mask = gradio.ImageMask(image_mode="RGB", type="numpy", brush=gradio.Brush(),
                                              label="prompt (提示图)", transforms=(), width=600, height=450)
                input_text = gradio.Textbox(info="please enter object here", label="Text Prompt")
        with gradio.Row():
            schedule = gradio.Dropdown(["linear", "cosine"],
                                       value='linear', label="schedule", info="For global alignment!")
            niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
                                  label="num_iterations", info="For global alignment!")
            scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
                                              value='complete', label="Scenegraph",
                                              info="Define how to make pairs",
                                              interactive=True)
            winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                    minimum=1, maximum=1, step=1, visible=False)
            refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
        run_btn = gradio.Button("Run")
        with gradio.Row():
            # adjust the confidence threshold
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
            # adjust the camera size in the output pointcloud
            cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
        with gradio.Row():
            as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
            # two post process implemented
            mask_sky = gradio.Checkbox(value=False, label="Mask sky")
            clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
            transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
        outmodel = gradio.Model3D()
        outgallery = gradio.Gallery(label='rgb,depth,confidence,mask', columns=4, height="100%")

        media_input.upload(upload_fn, inputs=media_input, outputs=image_mask)
        if scenegraph_fn is not None:
            # Without this listener the swin/oneref sliders could never become visible.
            scenegraph_type.change(scenegraph_fn,
                                   inputs=[media_input, winsize, refid, scenegraph_type],
                                   outputs=[winsize, refid])
        run_btn.click(fn=run_fn,
                      inputs=[image_mask, media_input, schedule, niter, min_conf_thr, as_pointcloud,
                              mask_sky, clean_depth, transparent_cams, cam_size,
                              scenegraph_type, winsize, refid, input_text],
                      outputs=[scene, outmodel, outgallery])


def main_demo(sam2, tmpdirname, model, device, image_size, server_name, server_port):
    """Build the two-tab Gradio app (reconstruction from images / from a video) and launch it.

    Args:
        sam2: SAM2 video predictor.
        tmpdirname: output directory for uploads, sampled frames and the .glb model.
        model: DUSt3R model.
        device: torch device for DUSt3R inference.
        image_size: DUSt3R input size.
        server_name, server_port: forwarded to ``launch``.
    """
    # Bind the fixed leading arguments of the handlers (functools.partial).
    recon_fun_image_demo = functools.partial(get_reconstructed_image, sam2, tmpdirname, model, device,
                                             image_size)
    recon_fun_video_demo = functools.partial(get_reconstructed_video, sam2, tmpdirname, model, device,
                                             image_size)
    upload_files_fun = functools.partial(upload_images_listener, image_size)
    upload_video_fun = functools.partial(upload_video_listener, image_size)

    with gradio.Blocks() as demo1:
        # Runs get_reconstructed_image (DUSt3R on uploaded images).
        _build_reconstruction_tab(lambda: gradio.File(file_count="multiple"),
                                  upload_files_fun, recon_fun_image_demo,
                                  scenegraph_fn=set_scenegraph_options)
    # ## **************************** video *******************************************************
    with gradio.Blocks() as demo2:
        # Runs get_reconstructed_video (DUSt3R on frames sampled from a video).
        _build_reconstruction_tab(lambda: gradio.Video(width=600, height=600),
                                  upload_video_fun, recon_fun_video_demo)

    app = gradio.TabbedInterface([demo1, demo2], ["3d rebuilding by images", "3d rebuilding by video"])
    app.launch(share=False, server_name=server_name, server_port=server_port)
# TODO 修改bug:
#在项目的一次启动中,上传的多组图片在点击run后,会保存在同一个临时文件夹中,
# 这样后面再上传其他场景的图片时,不同场景下的图片会存在于一个文件夹中,
# 不同场景的图片导致分割与重建错误
## 目前构思的解决:在文件夹下再基于创建一个文件夹存放不同场景的图片,可以基于时间命名该文件夹
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # Point tempfile (and thus the TemporaryDirectory below) at --tmp_dir.
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    if args.server_name is not None:
        server_name = args.server_name
    else:
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    # DUST3R
    model = load_model(args.weights, args.device)

    # SAM2: load the video predictor.
    sam2_checkpoint = "./SAM2/checkpoints/sam2_hiera_large.pt"
    model_cfg = "sam2_hiera_l.yaml"
    # Choose the device explicitly, with the same rule as the Grounding DINO loader. GPU
    # hosts keep SAM2 on CUDA, and CPU-only hosts no longer depend on the builder's default
    # device (presumably "cuda", as in upstream SAM2 -- confirm for this vendored copy).
    sam2_device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2 = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=sam2_device)

    # dust3rWithSam2 writes the 3D model (.glb) and the frame/upload folders inside tmpdirname.
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        print('Outputing stuff in', tmpdirname)
        main_demo(sam2, tmpdirname, model, args.device, args.image_size, server_name, args.server_port)