ostapagon commited on
Commit
bd0195a
·
1 Parent(s): c6a15a0

Fix cache_dir error

Browse files
Files changed (4) hide show
  1. app.py +1 -1
  2. demo/demo_globals.py +2 -1
  3. demo/gs_train.py +2 -2
  4. demo/mast3r_demo.py +3 -3
app.py CHANGED
@@ -28,7 +28,7 @@ if __name__ == '__main__':
28
  with gr.Tab("MASt3R Demo"):
29
  mast3r_demo_tab()
30
  with gr.Tab("Gaussian Splatting Demo"):
31
- gs_demo_tab(cache_path)
32
 
33
  demo.launch(show_error=True, share=None, server_name=None, server_port=None)
34
  # demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=5555)
 
28
  with gr.Tab("MASt3R Demo"):
29
  mast3r_demo_tab()
30
  with gr.Tab("Gaussian Splatting Demo"):
31
+ gs_demo_tab()
32
 
33
  demo.launch(show_error=True, share=None, server_name=None, server_port=None)
34
  # demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=5555)
demo/demo_globals.py CHANGED
@@ -7,8 +7,9 @@ import torch
7
  from mast3r.utils.misc import hash_md5
8
  from mast3r.model import AsymmetricMASt3R
9
 
 
10
  weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
11
- # weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
12
  tmpdirname = tempfile.TemporaryDirectory(suffix='demo')
13
  chkpt_tag = hash_md5(weights_path)
14
  CACHE_PATH = os.path.join(tmpdirname.name, chkpt_tag)
 
7
  from mast3r.utils.misc import hash_md5
8
  from mast3r.model import AsymmetricMASt3R
9
 
10
+ DATASET_DIR = "colmap_data"
11
  weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
12
+ weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
13
  tmpdirname = tempfile.TemporaryDirectory(suffix='demo')
14
  chkpt_tag = hash_md5(weights_path)
15
  CACHE_PATH = os.path.join(tmpdirname.name, chkpt_tag)
demo/gs_train.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
8
  import importlib.util
9
  from dataclasses import dataclass, field
10
 
11
- # import spaces
12
 
13
 
14
  @dataclass
@@ -60,7 +60,7 @@ class TrainingArgs:
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
- # @spaces.GPU(duration=20)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
 
8
  import importlib.util
9
  from dataclasses import dataclass, field
10
 
11
+ import spaces
12
 
13
 
14
  @dataclass
 
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
+ @spaces.GPU(duration=20)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
demo/mast3r_demo.py CHANGED
@@ -34,7 +34,7 @@ import matplotlib.pyplot as pl
34
  import torch
35
 
36
 
37
- from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT
38
 
39
  class SparseGAState():
40
  def __init__(self, cache_dir=None, outfile_name=None):
@@ -219,12 +219,12 @@ def get_reconstructed_scene(image_size, current_scene_state,
219
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
220
  matching_conf_thr=matching_conf_thr, **kw)
221
 
222
- base_colmapdata_dir = os.path.join(CACHE_PATH, 'colmap_data')
223
  os.makedirs(base_colmapdata_dir, exist_ok=True)
224
  colmap_data_dir = get_next_dir(base_colmapdata_dir)
225
  #
226
  save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
227
-
228
  if current_scene_state is not None and \
229
  current_scene_state.outfile_name is not None:
230
  outfile_name = current_scene_state.outfile_name
 
34
  import torch
35
 
36
 
37
+ from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR
38
 
39
  class SparseGAState():
40
  def __init__(self, cache_dir=None, outfile_name=None):
 
219
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
220
  matching_conf_thr=matching_conf_thr, **kw)
221
 
222
+ base_colmapdata_dir = os.path.join(CACHE_PATH, DATASET_DIR)
223
  os.makedirs(base_colmapdata_dir, exist_ok=True)
224
  colmap_data_dir = get_next_dir(base_colmapdata_dir)
225
  #
226
  save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
227
+
228
  if current_scene_state is not None and \
229
  current_scene_state.outfile_name is not None:
230
  outfile_name = current_scene_state.outfile_name