#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# gradio demo
# --------------------------------------------------------
import argparse
import gradio
import os
import torch
import numpy as np
import tempfile
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
from dust3r.inference import inference, load_model
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images, rgb
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
import matplotlib.pyplot as plt
# Batch size handed to dust3r's `inference` call by this demo.
batch_size = 1
# Allow TF32 matrix multiplications (for gpu >= Ampere and pytorch >= 1.12).
torch.backends.cuda.matmul.allow_tf32 = True
def show_mask(mask, ax, random_color=False):
    """Overlay a semi-transparent segmentation mask on a plotting axis.

    Args:
        mask: array whose last two dimensions are (height, width). It is
            reshaped to (h, w, 1), so it must hold exactly h * w elements.
        ax: object exposing ``imshow`` (e.g. a matplotlib ``Axes``).
        random_color: when True the mask is drawn in a random RGB colour,
            otherwise in a fixed blue. The alpha channel is 0.6 either way.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Without this `else` the fixed colour would always overwrite the
        # random one and `random_color` would have no effect.
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the RGBA colour over the pixels: entries where the mask is
    # zero/False become fully transparent black.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars, coloured by their label.

    Rows of ``coords`` whose label equals 1 are drawn in green first, then
    rows whose label equals 0 are drawn in red. Column 0 of ``coords`` is
    used as x and column 1 as y.

    Args:
        coords: 2-D array indexed as ``coords[:, 0]`` / ``coords[:, 1]``.
        labels: array compared element-wise against 1 and 0 to select rows.
        ax: object exposing ``scatter`` (e.g. a matplotlib ``Axes``).
        marker_size: forwarded to ``scatter`` as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
def show_box(box, ax):
    """Draw ``box`` on ``ax`` as a green, unfilled rectangle outline.

    Args:
        box: indexable with four entries; entries 0 and 1 give the anchor
            corner (x, y), entries 2 and 3 the opposite corner.
        ax: object exposing ``add_patch`` (e.g. a matplotlib ``Axes``).
    """
    anchor_x, anchor_y = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle((anchor_x, anchor_y), width, height,
                            edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)
# Segment Anything (SAM) predictor, built once at import time.
# NOTE(review): no active code in the visible part of this file calls
# `predictor`; confirm it is still needed before paying for the checkpoint load.
from SAM import SamPredictor
from SAM.build_sam import sam_model_registry

model_type = "vit_b"
sam_checkpoint = "checkpoints/sam_vit_b_01ec64.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Run SAM on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    sam.to("cuda")
else:
    sam.to("cpu")
predictor = SamPredictor(sam)
def get_args_parser():
    """Build the command-line argument parser for the gradio demo.

    ``--local_network`` and ``--server_name`` belong to one mutually
    exclusive group, so passing both is rejected by argparse.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser_url = parser.add_mutually_exclusive_group()
    parser_url.add_argument("--local_network", action='store_true', default=False,
                            help="make app accessible on local network: address will be set to 0.0.0.0")
    parser_url.add_argument("--server_name", type=str, default=None,
                            help="server url, default is 127.0.0.1")
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    # The original call was never closed; the explicit default documents the
    # "None means search for a free port" behaviour described in the help.
    parser.add_argument("--server_port", type=int, default=None,
                        help=("will start gradio app on this port (if available). "
                              "If None, will search for an available port starting at 7860."))
    parser.add_argument("--weights", type=str, default="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
                        required=False, help="path to the model weights")
    parser.add_argument("--device", type=str, default='cpu', help="pytorch device")
    parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
    return parser
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
    """Assemble a trimesh scene from the reconstruction and export it as GLB.

    Args:
        outdir: directory in which ``scene.glb`` is written.
        imgs: per-view images; also used as point/vertex colours.
        pts3d: per-view 3D points, indexed by ``mask`` in the same way as ``imgs``.
        mask: per-view boolean masks selecting the points to keep.
        focals: per-camera focal values, forwarded to ``add_scene_cam``.
        cams2world: per-camera 4x4 camera-to-world poses.
        cam_size: forwarded to ``add_scene_cam`` as ``screen_width``.
        cam_color: a list with one colour per camera, a single colour used for
            every camera, or None/falsy to cycle through ``CAM_COLORS``.
        as_pointcloud: export a coloured point cloud instead of a mesh.
        transparent_cams: when True, cameras are drawn without their image.

    Returns:
        str: path of the exported file, ``<outdir>/scene.glb``.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        # zip stops at the shortest input, so views without points/mask
        # (allowed by the assert above) are skipped instead of raising.
        meshes = [pts3d_to_trimesh(img, pts, msk) for img, pts, msk in zip(imgs, pts3d, mask)]
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    # Transform the whole scene by the inverse of (first camera pose @ OPENGL
    # @ a 180 degree turn about y), i.e. express it relative to the first camera.
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))

    outfile = os.path.join(outdir, 'scene.glb')
    print('(exporting 3D scene to', outfile, ')')
    # Actually write the file whose path is returned to the caller.
    scene.export(file_obj=outfile)
    return outfile
def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    extract 3D_model (glb file) from a reconstructed scene

    Args:
        outdir: directory in which the GLB file is written.
        scene: reconstructed scene object, or None when nothing has been
            reconstructed yet.
        min_conf_thr: confidence threshold; it is passed through
            ``scene.conf_trf`` and stored on ``scene.min_conf_thr`` before
            the masks are queried.
        as_pointcloud: export a point cloud instead of a mesh.
        mask_sky: apply ``scene.mask_sky()`` first.
        clean_depth: apply ``scene.clean_pointcloud()`` first.
        transparent_cams: draw cameras without their image.
        cam_size: size of the cameras in the exported scene.

    Returns:
        Path of the exported GLB file, or None when ``scene`` is None.
    """
    if scene is None:
        return None
    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()
    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    # The threshold must be set before get_masks() so the masks reflect it.
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    # TODO: an experimental SAM refinement was prototyped here (prompt the
    # module-level `predictor` with two points near the centre of each image
    # and AND the predicted mask into msk[i]); it is not enabled.
    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size)
def get_reconstructed_scene(outdir, model, device, image_size, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid):
    """
    from a list of images, run dust3r inference, global aligner.
    then run get_3D_model_from_scene

    Returns:
        tuple: ``(scene, outfile, imgs)`` where ``scene`` is the aligned
        scene, ``outfile`` the path returned by ``get_3D_model_from_scene``
        and ``imgs`` a flat list holding, for every view in order, its rgb
        image, its normalized depth map and its colour-mapped confidence map.
    """
    imgs = load_images(filelist, size=image_size)
    if len(imgs) == 1:
        # A single image is duplicated so that a pair can still be formed.
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    # "swin" and "oneref" scene graphs carry their parameter in the name.
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)

    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode)
    lr = 0.01

    # Only the multi-view optimizer is run through global alignment here.
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size)

    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
    rgbimg = scene.imgs
    depths = to_numpy(scene.get_depthmaps())
    confs = to_numpy([c for c in scene.im_conf])
    cmap = plt.get_cmap('jet')
    depths_max = max([d.max() for d in depths])
    depths = [d/depths_max for d in depths]
    confs_max = max([d.max() for d in confs])
    confs = [cmap(d/confs_max) for d in confs]

    # Interleave (rgb, depth, confidence) per view for the gallery.
    imgs = []
    for image, depth, conf in zip(rgbimg, depths, confs):
        imgs.extend((image, rgb(depth), rgb(conf)))

    return scene, outfile, imgs
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Rebuild the two scene-graph sliders for the current files and graph type.

    The window-size slider is only visible for the "swin" scene graph and
    the reference-id slider only for "oneref"; for any other type both are
    hidden. Both sliders are reset to their default value.

    Args:
        inputfiles: uploaded files, or None when nothing is uploaded.
        winsize: current window-size value (unused; replaced by a new slider).
        refid: current reference-id value (unused; replaced by a new slider).
        scenegraph_type: selected scene-graph type.

    Returns:
        tuple: ``(winsize_slider, refid_slider)``.
    """
    num_files = len(inputfiles) if inputfiles is not None else 1
    max_winsize = max(1, (num_files - 1)//2)
    # Clamp so that an empty file list cannot produce maximum < minimum.
    max_refid = max(0, num_files - 1)
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1,
                            visible=(scenegraph_type == "swin"))
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=max_refid, step=1,
                          visible=(scenegraph_type == "oneref"))
    return winsize, refid
def main_demo(tmpdirname, model, device, image_size, server_name, server_port):
    """Build the gradio interface of the demo and launch it.

    Args:
        tmpdirname: directory in which exported 3D models are written.
        model: loaded network, forwarded to ``get_reconstructed_scene``.
        device: pytorch device, forwarded to ``get_reconstructed_scene``.
        image_size: image size, forwarded to ``get_reconstructed_scene``.
        server_name: forwarded to ``demo.launch``.
        server_port: forwarded to ``demo.launch``.
    """
    # Pre-bind the arguments that do not come from UI components.
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
        # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                schedule = gradio.Dropdown(["linear", "cosine"],
                                           value='linear', label="schedule", info="For global alignment!")
                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
                                      label="num_iterations", info="For global alignment!")
                scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs")
                winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                        minimum=1, maximum=1, step=1, visible=False)
                refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                # two post process implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # events
            # The scene-graph sliders depend on both the graph type and the
            # number of uploaded files, so either change rebuilds them.
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, winsize, refid, scenegraph_type],
                                   outputs=[winsize, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, winsize, refid, scenegraph_type],
                              outputs=[winsize, refid])
            run_btn.click(fn=recon_fun,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])

            # Every display option re-exports the model from the stored scene
            # without rerunning the inference.
            refresh_inputs = [scene, min_conf_thr, as_pointcloud, mask_sky,
                              clean_depth, transparent_cams, cam_size]
            # NOTE(review): `release` is used for the threshold slider so the
            # export runs once per drag rather than on every step - confirm.
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=refresh_inputs,
                                 outputs=outmodel)
            for control in (cam_size, as_pointcloud, mask_sky, clean_depth, transparent_cams):
                control.change(fn=model_from_scene_fun,
                               inputs=refresh_inputs,
                               outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # Redirect every temporary file of this process to the requested directory.
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    # An explicit --server_name wins; otherwise listen on all interfaces when
    # --local_network is set and on the loopback address by default.
    if args.server_name is not None:
        server_name = args.server_name
    else:
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    model = load_model(args.weights, args.device)
    # dust3r will write the 3D model inside tmpdirname
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        print('Outputing stuff in', tmpdirname)
        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port)