Spaces:
Running
on
Zero
Running
on
Zero
Fix bug with tmp folder for multiuser
Browse files- app.py +14 -0
- demo/demo_globals.py +2 -4
- demo/gs_demo.py +8 -4
- demo/gs_train.py +1 -1
- demo/mast3r_demo.py +5 -7
app.py
CHANGED
@@ -2,11 +2,22 @@ import sys
|
|
2 |
sys.path.append('wild-gaussian-splatting/mast3r/')
|
3 |
sys.path.append('demo/')
|
4 |
|
|
|
5 |
import gradio as gr
|
6 |
import torch
|
7 |
from mast3r.demo import get_args_parser
|
8 |
from mast3r_demo import mast3r_demo_tab
|
9 |
from gs_demo import gs_demo_tab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
if __name__ == '__main__':
|
12 |
with gr.Blocks() as demo:
|
@@ -30,4 +41,7 @@ if __name__ == '__main__':
|
|
30 |
with gr.Tab("3DGS"):
|
31 |
gs_demo_tab()
|
32 |
|
|
|
|
|
|
|
33 |
demo.launch(show_error=True, share=None, server_name=None, server_port=None)
|
|
|
2 |
sys.path.append('wild-gaussian-splatting/mast3r/')
|
3 |
sys.path.append('demo/')
|
4 |
|
5 |
+
import os
|
6 |
import gradio as gr
|
7 |
import torch
|
8 |
from mast3r.demo import get_args_parser
|
9 |
from mast3r_demo import mast3r_demo_tab
|
10 |
from gs_demo import gs_demo_tab
|
11 |
+
from demo_globals import CACHE_PATH
|
12 |
+
import shutil
|
13 |
+
|
14 |
+
def start_session(req: gr.Request):
|
15 |
+
user_dir = os.path.join(CACHE_PATH, str(req.session_hash))
|
16 |
+
os.makedirs(user_dir, exist_ok=True)
|
17 |
+
|
18 |
+
def end_session(req: gr.Request):
|
19 |
+
user_dir = os.path.join(CACHE_PATH, str(req.session_hash))
|
20 |
+
shutil.rmtree(user_dir)
|
21 |
|
22 |
if __name__ == '__main__':
|
23 |
with gr.Blocks() as demo:
|
|
|
41 |
with gr.Tab("3DGS"):
|
42 |
gs_demo_tab()
|
43 |
|
44 |
+
demo.load(start_session)
|
45 |
+
demo.unload(end_session)
|
46 |
+
|
47 |
demo.launch(show_error=True, share=None, server_name=None, server_port=None)
|
demo/demo_globals.py
CHANGED
@@ -9,10 +9,8 @@ from mast3r.model import AsymmetricMASt3R
|
|
9 |
|
10 |
DATASET_DIR = "colmap_data"
|
11 |
weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
|
12 |
-
|
13 |
-
|
14 |
-
chkpt_tag = hash_md5(weights_path)
|
15 |
-
CACHE_PATH = os.path.join(tmpdirname.name, chkpt_tag)
|
16 |
os.makedirs(CACHE_PATH, exist_ok=True)
|
17 |
|
18 |
DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
9 |
|
10 |
DATASET_DIR = "colmap_data"
|
11 |
weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
|
12 |
+
|
13 |
+
CACHE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
|
|
|
14 |
os.makedirs(CACHE_PATH, exist_ok=True)
|
15 |
|
16 |
DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
demo/gs_demo.py
CHANGED
@@ -13,7 +13,7 @@ def get_dataset_folders(datasets_path):
|
|
13 |
|
14 |
def gs_demo_tab():
|
15 |
# datasets_path = "/app/data/scenes/"
|
16 |
-
|
17 |
def start_training(selected_folder, *args):
|
18 |
selected_data_path = os.path.join(datasets_path, selected_folder)
|
19 |
return train(selected_data_path, *args)
|
@@ -55,8 +55,10 @@ def gs_demo_tab():
|
|
55 |
refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
|
56 |
dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")
|
57 |
|
58 |
-
def update_dataset_dropdown():
|
59 |
-
|
|
|
|
|
60 |
# Update the dataset folders list
|
61 |
dataset_folders = get_dataset_folders(dataset_path)
|
62 |
print("dataset_folders", dataset_folders)
|
@@ -108,7 +110,9 @@ def gs_demo_tab():
|
|
108 |
load_model_button = gr.Button("Load 3D Model", interactive=False)
|
109 |
output = gr.Model3D(label="3D Model Output", visible=False)
|
110 |
|
111 |
-
def handle_training_complete(selected_folder, *args):
|
|
|
|
|
112 |
# Construct the full path to the selected dataset
|
113 |
selected_data_path = os.path.join(dataset_path, selected_folder)
|
114 |
# Call the training function with the full path
|
|
|
13 |
|
14 |
def gs_demo_tab():
|
15 |
# datasets_path = "/app/data/scenes/"
|
16 |
+
|
17 |
def start_training(selected_folder, *args):
|
18 |
selected_data_path = os.path.join(datasets_path, selected_folder)
|
19 |
return train(selected_data_path, *args)
|
|
|
55 |
refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
|
56 |
dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")
|
57 |
|
58 |
+
def update_dataset_dropdown(req: gr.Request):
|
59 |
+
USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
|
60 |
+
print("update_dataset_dropdown, user_path", USER_DIR)
|
61 |
+
dataset_path = os.path.join(USER_DIR, DATASET_DIR)
|
62 |
# Update the dataset folders list
|
63 |
dataset_folders = get_dataset_folders(dataset_path)
|
64 |
print("dataset_folders", dataset_folders)
|
|
|
110 |
load_model_button = gr.Button("Load 3D Model", interactive=False)
|
111 |
output = gr.Model3D(label="3D Model Output", visible=False)
|
112 |
|
113 |
+
def handle_training_complete(selected_folder, req: gr.Request, *args):
|
114 |
+
USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
|
115 |
+
dataset_path = os.path.join(USER_DIR, DATASET_DIR)
|
116 |
# Construct the full path to the selected dataset
|
117 |
selected_data_path = os.path.join(dataset_path, selected_folder)
|
118 |
# Call the training function with the full path
|
demo/gs_train.py
CHANGED
@@ -160,7 +160,7 @@ def train(
|
|
160 |
|
161 |
# Log and save
|
162 |
if (iteration == opt.iterations):
|
163 |
-
point_cloud_path = os.path.join(os.path.join(
|
164 |
print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
|
165 |
scene.save(iteration)
|
166 |
|
|
|
160 |
|
161 |
# Log and save
|
162 |
if (iteration == opt.iterations):
|
163 |
+
point_cloud_path = os.path.join(os.path.join(data_source_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
|
164 |
print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
|
165 |
scene.save(iteration)
|
166 |
|
demo/mast3r_demo.py
CHANGED
@@ -176,11 +176,12 @@ def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False, mask_i
|
|
176 |
@spaces.GPU(duration=20)
|
177 |
def get_reconstructed_scene(snapshot, current_scene_state,
|
178 |
min_conf_thr, matching_conf_thr,
|
179 |
-
as_pointcloud, cam_size, shared_intrinsics, clean_depth, filelist, **kw):
|
180 |
"""
|
181 |
from a list of images, run mast3r inference, sparse global aligner.
|
182 |
then run get_3D_model_from_scene
|
183 |
"""
|
|
|
184 |
image_size = 512
|
185 |
imgs = load_images(filelist, size=image_size, verbose=not SILENT)
|
186 |
if len(imgs) == 1:
|
@@ -216,7 +217,7 @@ def get_reconstructed_scene(snapshot, current_scene_state,
|
|
216 |
scene_graph = '-'.join(scene_graph_params)
|
217 |
pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
|
218 |
|
219 |
-
base_cache_dir = os.path.join(
|
220 |
os.makedirs(base_cache_dir, exist_ok=True)
|
221 |
def get_next_dir(base_dir):
|
222 |
run_counter = 0
|
@@ -235,7 +236,7 @@ def get_reconstructed_scene(snapshot, current_scene_state,
|
|
235 |
opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
|
236 |
matching_conf_thr=matching_conf_thr, **kw)
|
237 |
|
238 |
-
base_colmapdata_dir = os.path.join(
|
239 |
os.makedirs(base_colmapdata_dir, exist_ok=True)
|
240 |
colmap_data_dir = get_next_dir(base_colmapdata_dir)
|
241 |
#
|
@@ -245,7 +246,7 @@ def get_reconstructed_scene(snapshot, current_scene_state,
|
|
245 |
current_scene_state.outfile_name is not None:
|
246 |
outfile_name = current_scene_state.outfile_name
|
247 |
else:
|
248 |
-
outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=
|
249 |
|
250 |
scene_state = SparseGAState(cache_dir, outfile_name)
|
251 |
outfile = get_3D_model_from_scene(scene, scene_state, min_conf_thr, as_pointcloud, mask_sky,
|
@@ -258,9 +259,6 @@ def get_reconstructed_scene(snapshot, current_scene_state,
|
|
258 |
|
259 |
|
260 |
def mast3r_demo_tab():
|
261 |
-
if not SILENT:
|
262 |
-
print('Outputing stuff in', CACHE_PATH)
|
263 |
-
|
264 |
def get_context():
|
265 |
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
|
266 |
title = "MASt3R Demo"
|
|
|
176 |
@spaces.GPU(duration=20)
|
177 |
def get_reconstructed_scene(snapshot, current_scene_state,
|
178 |
min_conf_thr, matching_conf_thr,
|
179 |
+
as_pointcloud, cam_size, shared_intrinsics, clean_depth, filelist, req: gradio.Request, **kw):
|
180 |
"""
|
181 |
from a list of images, run mast3r inference, sparse global aligner.
|
182 |
then run get_3D_model_from_scene
|
183 |
"""
|
184 |
+
USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
|
185 |
image_size = 512
|
186 |
imgs = load_images(filelist, size=image_size, verbose=not SILENT)
|
187 |
if len(imgs) == 1:
|
|
|
217 |
scene_graph = '-'.join(scene_graph_params)
|
218 |
pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
|
219 |
|
220 |
+
base_cache_dir = os.path.join(USER_DIR, 'cache')
|
221 |
os.makedirs(base_cache_dir, exist_ok=True)
|
222 |
def get_next_dir(base_dir):
|
223 |
run_counter = 0
|
|
|
236 |
opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
|
237 |
matching_conf_thr=matching_conf_thr, **kw)
|
238 |
|
239 |
+
base_colmapdata_dir = os.path.join(USER_DIR, DATASET_DIR)
|
240 |
os.makedirs(base_colmapdata_dir, exist_ok=True)
|
241 |
colmap_data_dir = get_next_dir(base_colmapdata_dir)
|
242 |
#
|
|
|
246 |
current_scene_state.outfile_name is not None:
|
247 |
outfile_name = current_scene_state.outfile_name
|
248 |
else:
|
249 |
+
outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=USER_DIR)
|
250 |
|
251 |
scene_state = SparseGAState(cache_dir, outfile_name)
|
252 |
outfile = get_3D_model_from_scene(scene, scene_state, min_conf_thr, as_pointcloud, mask_sky,
|
|
|
259 |
|
260 |
|
261 |
def mast3r_demo_tab():
|
|
|
|
|
|
|
262 |
def get_context():
|
263 |
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
|
264 |
title = "MASt3R Demo"
|