Spaces:
Running
on
Zero
Running
on
Zero
kxhit
committed on
Commit
·
ce137a5
1
Parent(s):
c1a2c97
my app
Browse files
README.md
CHANGED
@@ -4,6 +4,7 @@ emoji: 📸📸➡️🖼️🖼️🖼️
|
|
4 |
app_file: app.py
|
5 |
sdk: gradio
|
6 |
sdk_version: 4.31.0
|
|
|
7 |
---
|
8 |
[comment]: <> (# EscherNet: A Generative Model for Scalable View Synthesis)
|
9 |
|
|
|
4 |
app_file: app.py
|
5 |
sdk: gradio
|
6 |
sdk_version: 4.31.0
|
7 |
+
short_description: 3D novel view synthesis from any number of images!
|
8 |
---
|
9 |
[comment]: <> (# EscherNet: A Generative Model for Scalable View Synthesis)
|
10 |
|
app.py
CHANGED
@@ -1,125 +1,786 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
import spaces
|
4 |
import torch
|
5 |
-
|
6 |
-
import rerun as rr
|
7 |
-
import rerun.blueprint as rrb
|
8 |
-
from pathlib import Path
|
9 |
-
import uuid
|
10 |
-
|
11 |
-
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
|
12 |
-
from mini_dust3r.model import AsymmetricCroCo3DStereo
|
13 |
-
|
14 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
-
model = AsymmetricCroCo3DStereo.from_pretrained(
|
16 |
-
"naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
|
17 |
-
).to(DEVICE)
|
18 |
-
|
19 |
-
|
20 |
-
def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    """Build the Rerun viewer layout for a reconstruction.

    A 3D view rooted at ``log_path`` is always shown.  When there are at most
    four images, a vertical stack of per-camera 2D views (one per image,
    rooted at ``{log_path}/camera_{i}/pinhole/``) is placed next to it with a
    3:1 column split; with more images the 2D views are omitted so the layout
    does not get cluttered.
    """
    num_images = len(image_name_list)
    scene_view = rrb.Spatial3DView(origin=f"{log_path}")

    # More than four images: 3D view only.
    if num_images > 4:
        return rrb.Blueprint(
            rrb.Horizontal(
                scene_view,
            ),
            collapse_panels=True,
        )

    camera_views = []
    for cam_idx in range(num_images):
        camera_views.append(
            rrb.Spatial2DView(
                origin=f"{log_path}/camera_{cam_idx}/pinhole/",
                contents=[
                    "+ $origin/**",
                ],
            )
        )

    layout = rrb.Horizontal(
        contents=[
            scene_view,
            rrb.Vertical(contents=camera_views),
        ],
        column_shares=[3, 1],
    )
    return rrb.Blueprint(layout, collapse_panels=True)
|
51 |
-
|
52 |
-
|
53 |
-
@spaces.GPU
|
54 |
-
def predict(image_name_list: list[str] | str):
|
55 |
-
# check if is list or string and if not raise error
|
56 |
-
if not isinstance(image_name_list, list) and not isinstance(image_name_list, str):
|
57 |
-
raise gr.Error(
|
58 |
-
f"Input must be a list of strings or a string, got: {type(image_name_list)}"
|
59 |
-
)
|
60 |
-
uuid_str = str(uuid.uuid4())
|
61 |
-
filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
|
62 |
-
rr.init(f"{uuid_str}")
|
63 |
-
log_path = Path("world")
|
64 |
-
|
65 |
-
if isinstance(image_name_list, str):
|
66 |
-
image_name_list = [image_name_list]
|
67 |
-
|
68 |
-
optimized_results: OptimizedResult = inferece_dust3r(
|
69 |
-
image_dir_or_list=image_name_list,
|
70 |
-
model=model,
|
71 |
-
device=DEVICE,
|
72 |
-
batch_size=1,
|
73 |
-
)
|
74 |
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
)
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
with gr.Column():
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
with gr.Column():
|
117 |
-
multi_files = gr.File(file_count="multiple")
|
118 |
-
run_btn_multi = gr.Button("Run")
|
119 |
-
rerun_viewer_multi = Rerun(height=900)
|
120 |
-
run_btn_multi.click(
|
121 |
-
fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
|
122 |
-
)
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
-
|
|
|
|
|
|
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
+
print("cuda is available: ", torch.cuda.is_available())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
import gradio as gr
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
import rembg
|
9 |
+
import numpy as np
|
10 |
+
import math
|
11 |
+
import open3d as o3d
|
12 |
+
from PIL import Image
|
13 |
+
import torchvision
|
14 |
+
import trimesh
|
15 |
+
from skimage.io import imsave
|
16 |
+
import imageio
|
17 |
+
import cv2
|
18 |
+
import matplotlib.pyplot as pl
|
19 |
+
pl.ion()
|
20 |
|
21 |
+
CaPE_TYPE = "6DoF"
|
22 |
+
device = 'cuda' #if torch.cuda.is_available() else 'cpu'
|
23 |
+
weight_dtype = torch.float16
|
24 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
25 |
|
26 |
+
# EscherNet
|
27 |
+
# create angles in archimedean spiral with N steps
|
28 |
+
def get_archimedean_spiral(sphere_radius, num_steps=250):
    """Sample points along a spherical (Archimedean) spiral.

    See https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral"
    (c = a / pi).  The spiral parameter ``i`` runs from 0.01 up to (but not
    including) ``a = 40`` in increments of ``a / num_steps``; the polar angle
    is ``theta = i / a * pi`` and the azimuth is ``-i``.

    Args:
        sphere_radius: Radius of the sphere the points lie on.
        num_steps: Number of increments used to sweep the spiral; must be a
            positive number.

    Returns:
        A tuple ``(translations, angles)``:
        ``translations`` is an ``(N, 3)`` array of points on the sphere whose
        third coordinate is ``r * cos(theta)`` (so the sweep starts near the
        +z pole); ``angles`` is an ``(N, 2)`` array of
        ``[azimuth, polar]`` in degrees.

    Raises:
        ValueError: If ``num_steps`` is not positive.  (Zero would divide by
            zero and a negative value would make the sweep loop forever.)
    """
    if num_steps <= 0:
        raise ValueError(f"num_steps must be positive, got {num_steps!r}")

    a = 40
    r = sphere_radius
    step = a / num_steps

    translations = []
    angles = []

    # Start slightly off the pole (i = 0.01 rather than 0).
    i = 0.01
    while i < a:
        theta = i / a * math.pi
        x = r * math.sin(theta) * math.cos(-i)
        z = r * math.sin(-theta + math.pi) * math.sin(-i)
        y = r * -math.cos(theta)

        # Axes are permuted from the x-z-plane parameterisation: (x, y, z) -> (x, z, -y).
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-i), np.rad2deg(theta)])

        i += step

    return np.array(translations), np.stack(angles)
|
55 |
+
|
56 |
+
def look_at(origin, target, up):
    """Build a 4x4 matrix from a viewing direction.

    The forward direction is the normalised ``target - origin``; ``right`` is
    ``up x forward`` (normalised) and the recomputed up vector is
    ``forward x right``.  The returned matrix has columns
    ``[right, new_up, -forward, target]`` with ``[0, 0, 0, 1]`` as last row.
    Note that the translation column is ``target``, not ``origin``.

    Args:
        origin: 3-vector the direction is measured from.
        target: 3-vector the direction points to; also used as translation.
        up: 3-vector giving the approximate up direction.  It must not be
            parallel to ``target - origin``, otherwise ``right`` has zero
            length and the result contains NaNs.

    Returns:
        A ``(4, 4)`` numpy array.
    """
    forward = (target - origin)
    forward = forward / np.linalg.norm(forward)
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)
    new_up = np.cross(forward, right)
    rotation_matrix = np.column_stack((right, new_up, -forward, target))
    # np.row_stack is a deprecated alias (NumPy 2.0); np.vstack is the same operation.
    matrix = np.vstack((rotation_matrix, [0, 0, 0, 1]))
    return matrix
|
65 |
+
|
66 |
+
import einops
|
67 |
+
import sys
|
68 |
+
|
69 |
+
sys.path.insert(0, "./6DoF/") # TODO change it when deploying
|
70 |
+
# use the customized diffusers modules
|
71 |
+
from diffusers import DDIMScheduler
|
72 |
+
from dataset import get_pose
|
73 |
+
from CN_encoder import CN_encoder
|
74 |
+
from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
|
75 |
+
from segment_anything import sam_model_registry, SamPredictor
|
76 |
+
|
77 |
+
# import rembg
|
78 |
+
from carvekit.api.high import HiInterface
|
79 |
+
|
80 |
+
|
81 |
+
pretrained_model_name_or_path = "kxic/EscherNet_demo"
|
82 |
+
resolution = 256
|
83 |
+
h,w = resolution,resolution
|
84 |
+
guidance_scale = 3.0
|
85 |
+
radius = 2.2
|
86 |
+
bg_color = [1., 1., 1., 1.]
|
87 |
+
image_transforms = torchvision.transforms.Compose(
|
88 |
+
[
|
89 |
+
torchvision.transforms.Resize((resolution, resolution)), # 256, 256
|
90 |
+
torchvision.transforms.ToTensor(),
|
91 |
+
torchvision.transforms.Normalize([0.5], [0.5])
|
92 |
+
]
|
93 |
)
|
94 |
+
xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
|
95 |
+
# keep only the first half of the spiral (the top half)
|
96 |
+
xyzs_spiral = xyzs_spiral[:100]
|
97 |
+
angles_spiral = angles_spiral[:100]
|
98 |
+
|
99 |
+
# Init pipeline
|
100 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
|
101 |
+
image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
|
102 |
+
pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
|
103 |
+
pretrained_model_name_or_path,
|
104 |
+
revision=None,
|
105 |
+
scheduler=scheduler,
|
106 |
+
image_encoder=None,
|
107 |
+
safety_checker=None,
|
108 |
+
feature_extractor=None,
|
109 |
+
torch_dtype=weight_dtype,
|
110 |
+
)
|
111 |
+
pipeline.image_encoder = image_encoder.to(weight_dtype)
|
112 |
+
|
113 |
+
pipeline.set_progress_bar_config(disable=False)
|
114 |
+
|
115 |
+
pipeline = pipeline.to(device)
|
116 |
+
|
117 |
+
# pipeline.enable_xformers_memory_efficient_attention()
|
118 |
+
# enable vae slicing
|
119 |
+
pipeline.enable_vae_slicing()
|
120 |
+
# pipeline.enable_xformers_memory_efficient_attention()
|
121 |
+
|
122 |
+
|
123 |
+
#### object segmentation
|
124 |
+
def sam_init():
    """Load the Segment Anything ViT-H model and wrap it in a ``SamPredictor``.

    The checkpoint is expected at ``./sam_pt/sam_vit_h_4b8939.pth``; when it
    is missing it is downloaded with ``wget`` into ``./sam_pt/``.

    Returns:
        A ``SamPredictor`` whose model has been moved to the module-level
        ``device``.

    Raises:
        subprocess.CalledProcessError: If the ``wget`` download fails.
        FileNotFoundError: If ``wget`` is not installed.
    """
    import subprocess

    sam_checkpoint = "./sam_pt/sam_vit_h_4b8939.pth"
    if not os.path.exists(sam_checkpoint):
        # check=True surfaces a failed download here instead of as an
        # obscure checkpoint-loading error further down.
        subprocess.run(
            [
                "wget",
                "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                "-P",
                "./sam_pt/",
            ],
            check=True,
        )
    model_type = "vit_h"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
    predictor = SamPredictor(sam)
    return predictor
|
133 |
+
|
134 |
+
def create_carvekit_interface():
    """Create the CarveKit ``HiInterface`` with this demo's fixed settings.

    See the ``HiInterface`` doc strings for the meaning of each option.

    Returns:
        The configured ``HiInterface`` instance.
    """
    carvekit_options = dict(
        object_type="object",  # Can be "object" or "hairs-like".
        batch_size_seg=6,
        batch_size_matting=1,
        device="cpu",
        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=True,
    )
    return HiInterface(**carvekit_options)
|
148 |
+
|
149 |
+
|
150 |
+
# rembg_session = rembg.new_session()
|
151 |
+
rembg_session = create_carvekit_interface()
|
152 |
+
predictor = sam_init()
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
@spaces.GPU(duration=120)
def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
    """Generate novel views with EscherNet and write them to an mp4 video.

    Args:
        eschernet_input_dict: Dict with keys ``'imgs'`` (4-D array indexed as
            ``[view, row, col, channel]`` whose last channel is used as an
            alpha mask, values in 0..255), ``'poses'`` (one invertible 4x4
            matrix per input view) and ``'radii'`` (camera radii; only the
            maximum is used).
        sample_steps: Number of diffusion inference steps.
        sample_seed: Seed for the torch random generator.
        nvs_num: Number of novel views to generate; must be between 1 and the
            number of precomputed spiral poses.
        nvs_mode: Currently unused; kept for interface compatibility.

    Returns:
        Path of the written ``output.mp4`` (inside ``<tmpdirname>/eschernet``).

    Raises:
        gr.Error: If ``nvs_num`` is outside the supported range.
    """
    T_out = int(nvs_num)
    T_in = len(eschernet_input_dict['imgs'])
    num_spiral_poses = len(xyzs_spiral)
    if not 1 <= T_out <= num_spiral_poses:
        # A larger value would make the slice step below zero; a smaller one is meaningless.
        raise gr.Error(
            f"Number of novel views must be between 1 and {num_spiral_poses}, got {nvs_num}."
        )

    # set the random seed
    generator = torch.Generator(device=device).manual_seed(int(sample_seed))

    ####### output pose
    # Take every `stride`-th spiral pose; only the first T_out of them are used below.
    stride = num_spiral_poses // T_out
    xyzs = xyzs_spiral[::stride]
    angles_out = angles_spiral[::stride]

    ####### input's max radius for translation scaling
    radii = eschernet_input_dict['radii']
    max_t = np.max(radii)

    ####### input pose
    pose_in = []
    for T_in_index in range(T_in):
        pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
        pose[1:3, :] *= -1  # coordinate system conversion
        # NOTE(review): this rescales the homogeneous element [3, 3], not the
        # translation column — confirm this matches what the model expects.
        pose[3, 3] *= 1. / max_t * radius  # scale radius to [1.5, 2.2]
        pose_in.append(pose)

    ####### input image
    img = eschernet_input_dict['imgs'] / 255.
    img[img[:, :, :, -1] == 0.] = bg_color  # fill fully transparent pixels with the background colour
    # TODO batch image_transforms
    input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]

    ####### nvs pose
    # NOTE(review): only the 6DoF path defines pose_in_inv / pose_out_inv
    # below; the 4DoF path would fail at the pipeline call — confirm before
    # switching CaPE_TYPE.
    pose_out = []
    for T_out_index in range(T_out):
        azimuth, polar = angles_out[T_out_index]
        if CaPE_TYPE == "4DoF":
            pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
        elif CaPE_TYPE == "6DoF":
            pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
            pose = np.linalg.inv(pose)
            pose[2, :] *= -1
            pose_out.append(get_pose(pose))

    # [B, T, C, H, W]
    input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
    # [B, T, 4]
    pose_in = np.stack(pose_in)
    pose_out = np.stack(pose_out)

    if CaPE_TYPE == "6DoF":
        pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
        pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
        pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
        pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)

    pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
    pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)

    input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
    assert T_in == input_image.shape[0]
    assert T_in == pose_in.shape[1]
    assert T_out == pose_out.shape[1]

    # run inference
    pipeline.enable_xformers_memory_efficient_attention()
    image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
                     poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
                     height=h, width=w, T_in=T_in, T_out=T_out,
                     guidance_scale=guidance_scale,
                     # Honour the requested step count (was hard-coded to 50).
                     num_inference_steps=int(sample_steps),
                     generator=generator,
                     output_type="numpy").images

    # save output: start from an empty output directory
    output_dir = os.path.join(tmpdirname, "eschernet")
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Convert the generated frames to uint8 and encode them as a video.
    frames = np.stack([(image[i] * 255).astype(np.uint8) for i in range(T_out)])
    video_path = os.path.join(output_dir, "output.mp4")
    imageio.mimwrite(video_path, frames, fps=10, codec='h264')

    return video_path
|
250 |
+
|
251 |
+
# TODO mesh it
|
252 |
+
@spaces.GPU(duration=120)
def make3d():
    """Placeholder for mesh extraction; not implemented, does nothing and returns None."""
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
############################ Dust3r as Pose Estimation ############################
|
259 |
+
from scipy.spatial.transform import Rotation
|
260 |
+
import copy
|
261 |
+
|
262 |
+
from dust3r.inference import inference
|
263 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
264 |
+
from dust3r.image_pairs import make_pairs
|
265 |
+
from dust3r.utils.image import load_images, rgb
|
266 |
+
from dust3r.utils.device import to_numpy
|
267 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
268 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
269 |
+
import math
|
270 |
+
|
271 |
+
@spaces.GPU(duration=120)
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False, same_focals=False):
    """Assemble a trimesh scene from a reconstruction and export it as ``scene.glb``.

    The scene contains a coordinate-axis marker, the geometry (one merged
    point cloud when ``as_pointcloud`` is set, otherwise one merged mesh built
    from the per-image meshes) and one camera marker per entry of
    ``cams2world``.  Before export the whole scene is transformed by the
    inverse of ``cams2world[0] @ OPENGL @ rot`` where ``rot`` is a 180 degree
    rotation about y.

    Args:
        outdir: Directory the file ``scene.glb`` is written to.
        imgs: Per-view images (also used as per-point colours).
        pts3d: Per-view 3D points.
        mask: Per-view boolean masks selecting the points to keep.
        focals: Focal lengths; a single shared value is read from index 0
            when ``same_focals`` is true.
        cams2world: Camera poses, one per camera marker.
        cam_size: Passed to ``add_scene_cam`` as ``screen_width``.
        cam_color: A list of per-camera colours, a single colour, or None to
            cycle through ``CAM_COLORS``.
        as_pointcloud: Export a point cloud instead of a mesh.
        transparent_cams: When true, camera markers are added without image.
        silent: Suppress the export message.
        same_focals: Whether all cameras share ``focals[0]``.

    Returns:
        Path of the exported ``scene.glb``.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
    if not same_focals:
        assert (len(cams2world) == len(focals))

    pts3d, imgs = to_numpy(pts3d), to_numpy(imgs)
    focals, cams2world = to_numpy(focals), to_numpy(cams2world)

    glb_scene = trimesh.Scene()

    # add axes
    glb_scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))

    if as_pointcloud:
        # full pointcloud: masked points with the matching masked pixel colours
        points = np.concatenate([view_pts[view_mask] for view_pts, view_mask in zip(pts3d, mask)])
        colors = np.concatenate([view_img[view_mask] for view_img, view_mask in zip(imgs, mask)])
        glb_scene.add_geometry(trimesh.PointCloud(points.reshape(-1, 3), colors=colors.reshape(-1, 3)))
    else:
        per_view_meshes = [pts3d_to_trimesh(imgs[k], pts3d[k], mask[k]) for k in range(len(imgs))]
        glb_scene.add_geometry(trimesh.Trimesh(**cat_meshes(per_view_meshes)))

    # add each camera
    for cam_idx, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            edge_color = cam_color[cam_idx]
        else:
            edge_color = cam_color or CAM_COLORS[cam_idx % len(CAM_COLORS)]
        focal = focals[0] if same_focals else focals[cam_idx]
        cam_image = None if transparent_cams else imgs[cam_idx]
        add_scene_cam(glb_scene, pose_c2w, edge_color,
                      cam_image, focal,
                      imsize=imgs[cam_idx].shape[1::-1], screen_width=cam_size)

    flip_y = np.eye(4)
    flip_y[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    glb_scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ flip_y))

    outfile = os.path.join(outdir, 'scene.glb')
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    glb_scene.export(file_obj=outfile)
    return outfile
|
323 |
+
|
324 |
+
@spaces.GPU(duration=120)
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
    """Extract a 3D model (glb file) from a reconstructed scene.

    Poses and points are taken from ``scene.vis_poses`` / ``scene.vis_pts3d``
    when both attributes are present; otherwise the scene's own
    ``get_im_poses()`` / ``get_pts3d()`` values are used, so the function
    also works on a scene that never had those attributes attached.

    Args:
        outdir: Directory the glb file is written to.
        silent: Suppress the export message.
        scene: Reconstructed scene, or None.
        min_conf_thr: Confidence threshold, passed through ``scene.conf_trf``
            before being stored on the scene.
        as_pointcloud: Export a point cloud instead of a mesh.
        mask_sky: Apply ``scene.mask_sky()`` first.
        clean_depth: Apply ``scene.clean_pointcloud()`` first.
        transparent_cams: Draw camera markers without their image.
        cam_size: Camera marker size.
        same_focals: Whether all cameras share one focal length.

    Returns:
        Path of the exported glb file, or None when ``scene`` is None.
    """
    if scene is None:
        return None
    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = to_numpy(scene.imgs)
    focals = to_numpy(scene.get_focals().cpu())

    # Use the attached visualisation values only when both are available, so
    # poses and points are always expressed in the same frame.
    vis_poses = getattr(scene, "vis_poses", None)
    vis_pts3d = getattr(scene, "vis_pts3d", None)
    if vis_poses is not None and vis_pts3d is not None:
        cams2world = vis_poses
        pts3d = vis_pts3d
    else:
        cams2world = to_numpy(scene.get_im_poses().cpu())
        # 3D pointcloud from depthmap, poses and intrinsics
        pts3d = to_numpy(scene.get_pts3d())

    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())

    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
                                        same_focals=same_focals)
|
355 |
+
|
356 |
+
@spaces.GPU(duration=120)
def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid, same_focals):
    """Run DUSt3R on a list of images and prepare the inputs for EscherNet.

    Steps: load the images with background removal, run pairwise inference
    and global alignment, re-express cameras and points in an object-centred
    frame (origin at the centroid of the masked object points, dominant
    RANSAC plane normal rotated onto the z axis, then a 180 degree flip about
    x), save per-view poses/images and the camera radii into the output
    directory, and finally export the glb model.

    Args:
        filelist: Input image files.
        schedule: Schedule passed to ``compute_global_alignment``.
        niter: Number of global-alignment iterations.
        min_conf_thr: Confidence threshold for the point masks.
        as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size:
            Forwarded to ``get_3D_model_from_scene``.
        scenegraph_type: Scene-graph type; ``"swin"`` and ``"oneref"`` get
            ``-<winsize>`` / ``-<refid>`` appended.
        winsize: Window size for the ``"swin"`` scene graph.
        refid: Reference image id for the ``"oneref"`` scene graph.
        same_focals: Whether all cameras share one focal length.

    Returns:
        ``(scene, outfile, imgs, eschernet_input)`` where ``imgs`` alternates
        the RGB and RGBA image of every view and ``eschernet_input`` is a
        dict with keys ``"poses"``, ``"radii"`` and ``"imgs"``.

    Raises:
        gr.Error: If no confident 3D point falls on the segmented object.
    """
    silent = False
    image_size = 224
    # remove the directory if it already exists
    outdir = tmpdirname
    if os.path.exists(outdir):
        shutil.rmtree(outdir)
    os.makedirs(outdir, exist_ok=True)
    imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
    if len(imgs) == 1:
        # A single image is duplicated so that a pair can be formed.
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=1, verbose=not silent)

    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
    lr = 0.01

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    # Gallery images: RGB followed by RGBA for every view.
    rgbimg = scene.imgs
    imgs = []
    rgbaimg = []
    for i in range(len(rgbimg)):  # when only 1 image, scene.imgs is two
        # The duplicated second view reuses the single RGBA image.
        rgba = imgs_rgba[0] if (len(imgs_rgba) == 1 and i == 1) else imgs_rgba[i]
        imgs.append(rgbimg[i])
        imgs.append(rgba)
        rgbaimg.append(np.array(rgba))
    rgbaimg = np.array(rgbaimg)

    # for eschernet
    # get optimized values from scene
    rgbimg = to_numpy(scene.imgs)
    cams2world = to_numpy(scene.get_im_poses().cpu())

    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    obj_mask = rgbaimg[..., 3] > 0

    # Object-centred frame: origin at the centroid of confident object points.
    scene_pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
    pts_obj = np.concatenate([p[m & obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
    if pts_obj.shape[0] == 0:
        # Without this check the centroid would be NaN and poison every pose.
        raise gr.Error("No confident 3D points were found on the segmented object; "
                       "try different images or a lower confidence threshold.")
    centroid = np.mean(pts_obj, axis=0)  # obj center
    obj2world = np.eye(4)
    obj2world[:3, 3] = -centroid  # T_wc

    # get z_up vector: fit the dominant plane and use its normal
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(scene_pts)
    plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
    # get the normalised normal vector dim = 3
    normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
    # the normal direction should be pointing up
    # NOTE(review): this treats +y of the reconstruction frame as "up" — confirm.
    if normal[1] < 0:
        normal = -normal

    # Rotate the plane normal onto the z axis.  The rotation vector must be
    # angle * unit_axis; the raw cross product has length sin(angle).
    z_up = np.eye(4)
    z_up[:3, :3] = _rotation_between(normal, np.array([0.0, 0.0, 1.0]))
    obj2world = z_up @ obj2world
    # flip 180
    flip_rot = np.array([[1, 0, 0, 0],
                         [0, -1, 0, 0],
                         [0, 0, -1, 0],
                         [0, 0, 0, 1]])
    obj2world = flip_rot @ obj2world

    # get new cams2obj
    cams2obj = [obj2world @ cam2world for cam2world in cams2world]
    # transform pts3d to the new coordinate system: p' = R p + t
    rotation, translation = obj2world[:3, :3], obj2world[:3, 3]
    for i, view_pts in enumerate(pts3d):
        pts3d[i] = view_pts @ rotation.T + translation
    cams2world = np.array(cams2obj)
    # TODO rewrite hack: stash the transformed values for get_3D_model_from_scene
    scene.vis_poses = cams2world.copy()
    scene.vis_pts3d = pts3d.copy()

    # save cams2world and images per view: "000.npy", "000.png", "000_rgba.png", ...
    for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
        np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
        pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
        pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
    # One radius per camera (axis=-1); without the axis the norm collapses
    # all cameras into a single value that grows with the number of views.
    radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3], axis=-1)
    np.save(os.path.join(outdir, "radii.npy"), radii)

    eschernet_input = {"poses": cams2world,
                       "radii": radii,
                       "imgs": rgbaimg}
    print("got eschernet input")
    outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size, same_focals=same_focals)

    return scene, outfile, imgs, eschernet_input


def _rotation_between(source, target):
    """Return the 3x3 rotation matrix that rotates unit vector ``source`` onto unit vector ``target``."""
    axis = np.cross(source, target)
    sin_angle = np.linalg.norm(axis)
    cos_angle = float(np.clip(np.dot(source, target), -1.0, 1.0))
    if sin_angle < 1e-8:
        if cos_angle > 0:
            # Already aligned.
            return np.eye(3)
        # Opposite directions: rotate 180 degrees about any axis perpendicular to source.
        perpendicular = np.cross(source, np.array([1.0, 0.0, 0.0]))
        if np.linalg.norm(perpendicular) < 1e-8:
            perpendicular = np.cross(source, np.array([0.0, 1.0, 0.0]))
        perpendicular = perpendicular / np.linalg.norm(perpendicular)
        return Rotation.from_rotvec(np.pi * perpendicular).as_matrix()
    angle = np.arctan2(sin_angle, cos_angle)
    return Rotation.from_rotvec(angle * axis / sin_angle).as_matrix()
|
515 |
+
|
516 |
+
|
517 |
+
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Rebuild the scene-graph sliders for the current inputs and graph type.

    The window-size slider ranges from 1 to ``max(1, ceil((n - 1) / 2))`` and
    the reference-id slider from 0 to ``n - 1``, where ``n`` is the number of
    input files (1 when ``inputfiles`` is None).  The window-size slider is
    visible only for ``"swin"`` and the reference-id slider only for
    ``"oneref"``; for any other type both are hidden.

    Returns:
        ``(winsize, refid)`` as freshly created ``gr.Slider`` components; the
        incoming ``winsize`` / ``refid`` values are not used.
    """
    num_files = 1 if inputfiles is None else len(inputfiles)
    max_winsize = max(1, math.ceil((num_files - 1) / 2))

    show_winsize = scenegraph_type == "swin"
    show_refid = scenegraph_type == "oneref"

    winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
                        minimum=1, maximum=max_winsize, step=1, visible=show_winsize)
    refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
                      maximum=num_files - 1, step=1, visible=show_refid)
    return winsize, refid
|
536 |
+
|
537 |
+
|
538 |
+
def get_examples(path):
    """Collect example image sets from a folder of per-object sub-folders.

    Every visible sub-directory of ``path`` is one example; its visible
    regular files, sorted by name, are that example's images.  Stray files in
    ``path``, hidden entries (names starting with ``.``), nested directories
    and empty sub-directories are ignored instead of raising or producing
    broken examples.

    Args:
        path: Directory containing one sub-directory per example object.

    Returns:
        A list with one entry per object, each of the form ``[img_files]``
        where ``img_files`` is the list of image paths, ordered by object
        name.
    """
    objs = []
    for obj_name in sorted(os.listdir(path)):
        obj_dir = os.path.join(path, obj_name)
        if obj_name.startswith(".") or not os.path.isdir(obj_dir):
            continue
        img_files = [
            os.path.join(obj_dir, img_file)
            for img_file in sorted(os.listdir(obj_dir))
            if not img_file.startswith(".") and os.path.isfile(os.path.join(obj_dir, img_file))
        ]
        if img_files:
            objs.append([img_files])
    print("objs = ", objs)
    return objs
|
547 |
+
|
548 |
+
def preview_input(inputfiles):
    """Read the selected input files as images for preview.

    Args:
        inputfiles: Iterable of image files, or None.

    Returns:
        None when ``inputfiles`` is None, otherwise a list with the result of
        ``pl.imread`` for every file, in input order.
    """
    if inputfiles is None:
        return None
    return [pl.imread(img_file) for img_file in inputfiles]
|
556 |
+
|
557 |
+
# def main():
|
558 |
+
# dustr init
|
559 |
+
silent = False
|
560 |
+
image_size = 224
|
561 |
+
weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
|
562 |
+
model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
|
563 |
+
# dust3r will write the 3D model inside tmpdirname
|
564 |
+
# with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
|
565 |
+
tmpdirname = os.path.join('logs/user_object')
|
566 |
+
# remove the directory if it already exists
|
567 |
+
if os.path.exists(tmpdirname):
|
568 |
+
shutil.rmtree(tmpdirname)
|
569 |
+
os.makedirs(tmpdirname, exist_ok=True)
|
570 |
+
if not silent:
|
571 |
+
print('Outputing stuff in', tmpdirname)
|
572 |
+
|
573 |
+
_HEADER_ = '''
|
574 |
+
<h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
|
575 |
+
<b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
|
576 |
+
|
577 |
+
Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
|
578 |
+
|
579 |
+
<a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
|
580 |
+
<a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
|
581 |
+
<a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
|
582 |
+
|
583 |
+
<h4><b>Tips:</b></h4>
|
584 |
+
|
585 |
+
- Our model can take <b>any number input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
|
586 |
+
|
587 |
+
- Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.
|
588 |
+
|
589 |
+
- The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.
|
590 |
+
|
591 |
+
- The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
|
592 |
+
|
593 |
+
'''
|
594 |
+
|
595 |
+
_CITE_ = r"""
|
596 |
+
📝 <b>Citation</b>:
|
597 |
+
```bibtex
|
598 |
+
@article{kong2024eschernet,
|
599 |
+
title={EscherNet: A Generative Model for Scalable View Synthesis},
|
600 |
+
author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
|
601 |
+
journal={arXiv preprint arXiv:2402.03908},
|
602 |
+
year={2024}
|
603 |
+
}
|
604 |
+
```
|
605 |
+
"""
|
606 |
+
|
607 |
+
with gr.Blocks() as demo:
|
608 |
+
gr.Markdown(_HEADER_)
|
609 |
+
# mv_images = gr.State()
|
610 |
+
scene = gr.State(None)
|
611 |
+
eschernet_input = gr.State(None)
|
612 |
+
with gr.Row(variant="panel"):
|
613 |
+
# left column
|
614 |
with gr.Column():
|
615 |
+
with gr.Row():
|
616 |
+
input_image = gr.File(file_count="multiple")
|
617 |
+
with gr.Row():
|
618 |
+
run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
|
619 |
+
with gr.Row():
|
620 |
+
processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
|
621 |
+
with gr.Row(variant="panel"):
|
622 |
+
# input examples under "examples" folder
|
623 |
+
gr.Examples(
|
624 |
+
examples=get_examples('examples'),
|
625 |
+
inputs=[input_image],
|
626 |
+
label="Examples (click one set of images to start!)",
|
627 |
+
examples_per_page=20
|
628 |
+
)
|
629 |
+
|
630 |
+
|
631 |
+
|
632 |
+
|
633 |
+
|
634 |
+
# right column
|
635 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
636 |
|
637 |
+
with gr.Row():
|
638 |
+
outmodel = gr.Model3D()
|
639 |
+
|
640 |
+
with gr.Row():
|
641 |
+
gr.Markdown('''
|
642 |
+
<h4><b>Check that the pose (the blue axis is the estimated z-up direction) and the segmentation look correct. If not, remove the incorrect images and try again.</b></h4>
|
643 |
+
''')
|
644 |
+
|
645 |
+
with gr.Row():
|
646 |
+
with gr.Group():
|
647 |
+
do_remove_background = gr.Checkbox(
|
648 |
+
label="Remove Background", value=True
|
649 |
+
)
|
650 |
+
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
|
651 |
+
|
652 |
+
sample_steps = gr.Slider(
|
653 |
+
label="Sample Steps",
|
654 |
+
minimum=30,
|
655 |
+
maximum=75,
|
656 |
+
value=50,
|
657 |
+
step=5,
|
658 |
+
visible=False
|
659 |
+
)
|
660 |
+
|
661 |
+
nvs_num = gr.Slider(
|
662 |
+
label="Number of Novel Views",
|
663 |
+
minimum=5,
|
664 |
+
maximum=100,
|
665 |
+
value=30,
|
666 |
+
step=1
|
667 |
+
)
|
668 |
+
|
669 |
+
nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
|
670 |
+
value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
|
671 |
+
|
672 |
+
with gr.Row():
|
673 |
+
gr.Markdown('''
|
674 |
+
<h4><b>Choose the number of novel views you want and generate! The more output images, the longer it takes.</b></h4>
|
675 |
+
''')
|
676 |
+
|
677 |
+
with gr.Row():
|
678 |
+
submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
|
679 |
+
|
680 |
+
with gr.Row():
|
681 |
+
with gr.Column():
|
682 |
+
output_video = gr.Video(
|
683 |
+
label="video", format="mp4",
|
684 |
+
width=379,
|
685 |
+
autoplay=True,
|
686 |
+
interactive=False
|
687 |
+
)
|
688 |
+
|
689 |
+
with gr.Row():
|
690 |
+
gr.Markdown('''
|
691 |
+
<h4><b>The novel views are generated on an archimedean spiral (rotating around z-up axis and looking at the object center). You can download the video.</b></h4>
|
692 |
+
''')
|
693 |
+
|
694 |
+
gr.Markdown(_CITE_)
|
695 |
+
|
696 |
+
# DUSt3R parameters are kept invisible (visible=False) so the UI stays clean
|
697 |
+
with gr.Column():
|
698 |
+
with gr.Row():
|
699 |
+
schedule = gr.Dropdown(["linear", "cosine"],
|
700 |
+
value='linear', label="schedule", info="For global alignment!", visible=False)
|
701 |
+
niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
|
702 |
+
label="num_iterations", info="For global alignment!", visible=False)
|
703 |
+
scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
|
704 |
+
value='complete', label="Scenegraph",
|
705 |
+
info="Define how to make pairs",
|
706 |
+
interactive=True, visible=False)
|
707 |
+
same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
|
708 |
+
winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
|
709 |
+
minimum=1, maximum=1, step=1, visible=False)
|
710 |
+
refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
|
711 |
+
|
712 |
+
with gr.Row():
|
713 |
+
# adjust the confidence threshold
|
714 |
+
min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
|
715 |
+
# adjust the camera size in the output pointcloud
|
716 |
+
cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
|
717 |
+
with gr.Row():
|
718 |
+
as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
|
719 |
+
# two post process implemented
|
720 |
+
mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
|
721 |
+
clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
|
722 |
+
transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
|
723 |
+
|
724 |
+
# events
|
725 |
+
# scenegraph_type.change(set_scenegraph_options,
|
726 |
+
# inputs=[input_image, winsize, refid, scenegraph_type],
|
727 |
+
# outputs=[winsize, refid])
|
728 |
+
# min_conf_thr.release(fn=model_from_scene_fun,
|
729 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
730 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
731 |
+
# outputs=outmodel)
|
732 |
+
# cam_size.change(fn=model_from_scene_fun,
|
733 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
734 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
735 |
+
# outputs=outmodel)
|
736 |
+
# as_pointcloud.change(fn=model_from_scene_fun,
|
737 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
738 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
739 |
+
# outputs=outmodel)
|
740 |
+
# mask_sky.change(fn=model_from_scene_fun,
|
741 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
742 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
743 |
+
# outputs=outmodel)
|
744 |
+
# clean_depth.change(fn=model_from_scene_fun,
|
745 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
746 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
747 |
+
# outputs=outmodel)
|
748 |
+
# transparent_cams.change(model_from_scene_fun,
|
749 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
750 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
751 |
+
# outputs=outmodel)
|
752 |
+
# run_dust3r.click(fn=recon_fun,
|
753 |
+
# inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
|
754 |
+
# mask_sky, clean_depth, transparent_cams, cam_size,
|
755 |
+
# scenegraph_type, winsize, refid, same_focals],
|
756 |
+
# outputs=[scene, outmodel, processed_image, eschernet_input])
|
757 |
+
|
758 |
+
# events
|
759 |
+
input_image.change(set_scenegraph_options,
|
760 |
+
inputs=[input_image, winsize, refid, scenegraph_type],
|
761 |
+
outputs=[winsize, refid])
|
762 |
+
run_dust3r.click(fn=get_reconstructed_scene,
|
763 |
+
inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
|
764 |
+
mask_sky, clean_depth, transparent_cams, cam_size,
|
765 |
+
scenegraph_type, winsize, refid, same_focals],
|
766 |
+
outputs=[scene, outmodel, processed_image, eschernet_input])
|
767 |
+
|
768 |
+
|
769 |
+
# events
|
770 |
+
input_image.change(fn=preview_input,
|
771 |
+
inputs=[input_image],
|
772 |
+
outputs=[processed_image])
|
773 |
+
|
774 |
+
submit.click(fn=run_eschernet,
|
775 |
+
inputs=[eschernet_input, sample_steps, sample_seed,
|
776 |
+
nvs_num, nvs_mode],
|
777 |
+
outputs=[output_video])
|
778 |
+
|
779 |
+
|
780 |
+
|
781 |
+
# demo.queue(max_size=10)
|
782 |
+
# demo.launch(share=True, server_name="0.0.0.0", server_port=None)
|
783 |
+
demo.queue(max_size=10).launch()
|
784 |
|
785 |
+
# if __name__ == '__main__':
|
786 |
+
# main()
|