ostapagon commited on
Commit
89e21df
·
1 Parent(s): fccded4

Fix bug with tmp folder for multiuser

Browse files
Files changed (5) hide show
  1. app.py +14 -0
  2. demo/demo_globals.py +2 -4
  3. demo/gs_demo.py +8 -4
  4. demo/gs_train.py +1 -1
  5. demo/mast3r_demo.py +5 -7
app.py CHANGED
@@ -2,11 +2,22 @@ import sys
2
  sys.path.append('wild-gaussian-splatting/mast3r/')
3
  sys.path.append('demo/')
4
 
 
5
  import gradio as gr
6
  import torch
7
  from mast3r.demo import get_args_parser
8
  from mast3r_demo import mast3r_demo_tab
9
  from gs_demo import gs_demo_tab
 
 
 
 
 
 
 
 
 
 
10
 
11
  if __name__ == '__main__':
12
  with gr.Blocks() as demo:
@@ -30,4 +41,7 @@ if __name__ == '__main__':
30
  with gr.Tab("3DGS"):
31
  gs_demo_tab()
32
 
 
 
 
33
  demo.launch(show_error=True, share=None, server_name=None, server_port=None)
 
2
  sys.path.append('wild-gaussian-splatting/mast3r/')
3
  sys.path.append('demo/')
4
 
5
+ import os
6
  import gradio as gr
7
  import torch
8
  from mast3r.demo import get_args_parser
9
  from mast3r_demo import mast3r_demo_tab
10
  from gs_demo import gs_demo_tab
11
+ from demo_globals import CACHE_PATH
12
+ import shutil
13
+
14
+ def start_session(req: gr.Request):
15
+ user_dir = os.path.join(CACHE_PATH, str(req.session_hash))
16
+ os.makedirs(user_dir, exist_ok=True)
17
+
18
+ def end_session(req: gr.Request):
19
+ user_dir = os.path.join(CACHE_PATH, str(req.session_hash))
20
+ shutil.rmtree(user_dir)
21
 
22
  if __name__ == '__main__':
23
  with gr.Blocks() as demo:
 
41
  with gr.Tab("3DGS"):
42
  gs_demo_tab()
43
 
44
+ demo.load(start_session)
45
+ demo.unload(end_session)
46
+
47
  demo.launch(show_error=True, share=None, server_name=None, server_port=None)
demo/demo_globals.py CHANGED
@@ -9,10 +9,8 @@ from mast3r.model import AsymmetricMASt3R
9
 
10
  DATASET_DIR = "colmap_data"
11
  weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
12
- # weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
13
- tmpdirname = tempfile.TemporaryDirectory(suffix='demo')
14
- chkpt_tag = hash_md5(weights_path)
15
- CACHE_PATH = os.path.join(tmpdirname.name, chkpt_tag)
16
  os.makedirs(CACHE_PATH, exist_ok=True)
17
 
18
  DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
9
 
10
  DATASET_DIR = "colmap_data"
11
  weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
12
+
13
+ CACHE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
 
14
  os.makedirs(CACHE_PATH, exist_ok=True)
15
 
16
  DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
demo/gs_demo.py CHANGED
@@ -13,7 +13,7 @@ def get_dataset_folders(datasets_path):
13
 
14
  def gs_demo_tab():
15
  # datasets_path = "/app/data/scenes/"
16
- dataset_path = os.path.join(CACHE_PATH, DATASET_DIR)
17
  def start_training(selected_folder, *args):
18
  selected_data_path = os.path.join(datasets_path, selected_folder)
19
  return train(selected_data_path, *args)
@@ -55,8 +55,10 @@ def gs_demo_tab():
55
  refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
56
  dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")
57
 
58
- def update_dataset_dropdown():
59
- print("update_dataset_dropdown, cache_path", CACHE_PATH)
 
 
60
  # Update the dataset folders list
61
  dataset_folders = get_dataset_folders(dataset_path)
62
  print("dataset_folders", dataset_folders)
@@ -108,7 +110,9 @@ def gs_demo_tab():
108
  load_model_button = gr.Button("Load 3D Model", interactive=False)
109
  output = gr.Model3D(label="3D Model Output", visible=False)
110
 
111
- def handle_training_complete(selected_folder, *args):
 
 
112
  # Construct the full path to the selected dataset
113
  selected_data_path = os.path.join(dataset_path, selected_folder)
114
  # Call the training function with the full path
 
13
 
14
  def gs_demo_tab():
15
  # datasets_path = "/app/data/scenes/"
16
+
17
  def start_training(selected_folder, *args):
18
  selected_data_path = os.path.join(datasets_path, selected_folder)
19
  return train(selected_data_path, *args)
 
55
  refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
56
  dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")
57
 
58
+ def update_dataset_dropdown(req: gr.Request):
59
+ USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
60
+ print("update_dataset_dropdown, user_path", USER_DIR)
61
+ dataset_path = os.path.join(USER_DIR, DATASET_DIR)
62
  # Update the dataset folders list
63
  dataset_folders = get_dataset_folders(dataset_path)
64
  print("dataset_folders", dataset_folders)
 
110
  load_model_button = gr.Button("Load 3D Model", interactive=False)
111
  output = gr.Model3D(label="3D Model Output", visible=False)
112
 
113
+ def handle_training_complete(selected_folder, req: gr.Request, *args):
114
+ USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
115
+ dataset_path = os.path.join(USER_DIR, DATASET_DIR)
116
  # Construct the full path to the selected dataset
117
  selected_data_path = os.path.join(dataset_path, selected_folder)
118
  # Call the training function with the full path
demo/gs_train.py CHANGED
@@ -160,7 +160,7 @@ def train(
160
 
161
  # Log and save
162
  if (iteration == opt.iterations):
163
- point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
164
  print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
165
  scene.save(iteration)
166
 
 
160
 
161
  # Log and save
162
  if (iteration == opt.iterations):
163
+ point_cloud_path = os.path.join(os.path.join(data_source_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
164
  print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
165
  scene.save(iteration)
166
 
demo/mast3r_demo.py CHANGED
@@ -176,11 +176,12 @@ def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False, mask_i
176
  @spaces.GPU(duration=20)
177
  def get_reconstructed_scene(snapshot, current_scene_state,
178
  min_conf_thr, matching_conf_thr,
179
- as_pointcloud, cam_size, shared_intrinsics, clean_depth, filelist, **kw):
180
  """
181
  from a list of images, run mast3r inference, sparse global aligner.
182
  then run get_3D_model_from_scene
183
  """
 
184
  image_size = 512
185
  imgs = load_images(filelist, size=image_size, verbose=not SILENT)
186
  if len(imgs) == 1:
@@ -216,7 +217,7 @@ def get_reconstructed_scene(snapshot, current_scene_state,
216
  scene_graph = '-'.join(scene_graph_params)
217
  pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
218
 
219
- base_cache_dir = os.path.join(CACHE_PATH, 'cache')
220
  os.makedirs(base_cache_dir, exist_ok=True)
221
  def get_next_dir(base_dir):
222
  run_counter = 0
@@ -235,7 +236,7 @@ def get_reconstructed_scene(snapshot, current_scene_state,
235
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
236
  matching_conf_thr=matching_conf_thr, **kw)
237
 
238
- base_colmapdata_dir = os.path.join(CACHE_PATH, DATASET_DIR)
239
  os.makedirs(base_colmapdata_dir, exist_ok=True)
240
  colmap_data_dir = get_next_dir(base_colmapdata_dir)
241
  #
@@ -245,7 +246,7 @@ def get_reconstructed_scene(snapshot, current_scene_state,
245
  current_scene_state.outfile_name is not None:
246
  outfile_name = current_scene_state.outfile_name
247
  else:
248
- outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=CACHE_PATH)
249
 
250
  scene_state = SparseGAState(cache_dir, outfile_name)
251
  outfile = get_3D_model_from_scene(scene, scene_state, min_conf_thr, as_pointcloud, mask_sky,
@@ -258,9 +259,6 @@ def get_reconstructed_scene(snapshot, current_scene_state,
258
 
259
 
260
  def mast3r_demo_tab():
261
- if not SILENT:
262
- print('Outputing stuff in', CACHE_PATH)
263
-
264
  def get_context():
265
  css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
266
  title = "MASt3R Demo"
 
176
  @spaces.GPU(duration=20)
177
  def get_reconstructed_scene(snapshot, current_scene_state,
178
  min_conf_thr, matching_conf_thr,
179
+ as_pointcloud, cam_size, shared_intrinsics, clean_depth, filelist, req: gradio.Request, **kw):
180
  """
181
  from a list of images, run mast3r inference, sparse global aligner.
182
  then run get_3D_model_from_scene
183
  """
184
+ USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
185
  image_size = 512
186
  imgs = load_images(filelist, size=image_size, verbose=not SILENT)
187
  if len(imgs) == 1:
 
217
  scene_graph = '-'.join(scene_graph_params)
218
  pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
219
 
220
+ base_cache_dir = os.path.join(USER_DIR, 'cache')
221
  os.makedirs(base_cache_dir, exist_ok=True)
222
  def get_next_dir(base_dir):
223
  run_counter = 0
 
236
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
237
  matching_conf_thr=matching_conf_thr, **kw)
238
 
239
+ base_colmapdata_dir = os.path.join(USER_DIR, DATASET_DIR)
240
  os.makedirs(base_colmapdata_dir, exist_ok=True)
241
  colmap_data_dir = get_next_dir(base_colmapdata_dir)
242
  #
 
246
  current_scene_state.outfile_name is not None:
247
  outfile_name = current_scene_state.outfile_name
248
  else:
249
+ outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=USER_DIR)
250
 
251
  scene_state = SparseGAState(cache_dir, outfile_name)
252
  outfile = get_3D_model_from_scene(scene, scene_state, min_conf_thr, as_pointcloud, mask_sky,
 
259
 
260
 
261
  def mast3r_demo_tab():
 
 
 
262
  def get_context():
263
  css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
264
  title = "MASt3R Demo"