# mast3r-3dgs / demo/gs_demo.py
# Last commit 665b2f0 by ostapagon: "Add more fixes to demo files" (7.32 kB).
import gradio as gr
from gs_train import train
import os
# Sub-directory of the tab's ``cache_path`` whose child folders are offered as
# selectable datasets in the "Select Dataset" dropdown.
DATASET_DIR = "colmap_data"
def get_dataset_folders(datasets_path):
    """Return the names of the sub-directories of *datasets_path*, sorted.

    Regular files are ignored. The result is sorted because ``os.listdir``
    returns entries in arbitrary order; sorting keeps the dropdown stable and
    makes the "first folder" default selection deterministic.

    Args:
        datasets_path: Directory to scan.

    Returns:
        Sorted list of folder names (names only, not full paths). Empty when
        *datasets_path* does not exist or is not a directory.
    """
    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        # Missing dataset root (or a file in its place) simply means
        # "no datasets yet" for the UI.
        return []
    return sorted(
        name for name in entries if os.path.isdir(os.path.join(datasets_path, name))
    )
def gs_demo_tab(cache_path, delete_cache=(86400, 86400)):
    """Build the Gaussian Splatting training tab.

    The tab lists the dataset folders found under
    ``os.path.join(cache_path, DATASET_DIR)``. It also exposes the model,
    pipeline and optimization parameters that are forwarded to
    ``gs_train.train``. After training it shows the video returned by
    ``train`` and lets the user load the returned 3D model on demand.

    Args:
        cache_path: Root cache directory; datasets are looked up in its
            ``DATASET_DIR`` sub-folder.
        delete_cache: ``(frequency, age)`` in seconds, passed to
            ``gr.Blocks(delete_cache=...)``. The former hard-coded
            ``(True, True)`` evaluated to ``(1, 1)``, which purged generated
            video/model files one second after they were created.

    Returns:
        The ``gr.Blocks`` instance holding the tab.
    """
    dataset_path = os.path.join(cache_path, DATASET_DIR)

    def update_dataset_dropdown():
        """Re-scan ``dataset_path`` and return a refreshed dataset dropdown."""
        dataset_folders = get_dataset_folders(dataset_path)
        # Only pre-select a value when at least one dataset folder exists.
        default_value = dataset_folders[0] if dataset_folders else None
        return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)

    def handle_training_complete(selected_folder, *args):
        """Train on the selected dataset and update the result widgets.

        ``*args`` are the parameter widget values. They are forwarded to
        ``train`` in the order listed in ``start_button.click(inputs=...)``.
        Returns values for
        ``[video_output, load_model_button, output, model_path_state]``.
        """
        if not selected_folder:
            # os.path.join(dataset_path, "") would point train() at the dataset
            # root itself, and None would raise a TypeError.
            raise gr.Error("Please select a dataset before starting training.")
        selected_data_path = os.path.join(dataset_path, selected_folder)
        video_path, trained_model_path = train(selected_data_path, *args)
        return [
            video_path,  # video output
            gr.Button(value="Load 3D Model", interactive=True),  # enable loading
            gr.Model3D(visible=False),  # keep the 3D viewer hidden until requested
            trained_model_path,  # remembered in model_path_state
        ]

    def load_model(model_path):
        """Show the trained model in the 3D viewer, or hide it if none exists."""
        if not model_path:
            return gr.Model3D(visible=False)
        return gr.Model3D(value=model_path, visible=True)

    with gr.Blocks(delete_cache=delete_cache) as gs_demo:
        gr.Markdown("""
<style>
    .fixed-size-video video {
        max-height: 400px !important;
        height: 400px !important;
        object-fit: contain;
    }
</style>
""")
        gr.Markdown("# Gaussian Splatting Training Demo")

        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        # Populate the dropdown up front so existing datasets are selectable
        # without pressing "Refresh Datasets" first.
        initial_folders = get_dataset_folders(dataset_path)
        dataset_dropdown = gr.Dropdown(
            label="Select Dataset",
            choices=initial_folders,
            value=initial_folders[0] if initial_folders else None,
        )
        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)

        # Integer-valued parameters use precision=0 so Gradio delivers ints;
        # with the default precision=None gr.Number yields floats (e.g. 1000.0).
        with gr.Accordion("Model Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    sh_degree = gr.Number(label="SH Degree", value=3, precision=0)
                    model_path_input = gr.Textbox(label="Model Path", value="")
                    images = gr.Textbox(label="Images", value="images")
                    resolution = gr.Number(label="Resolution", value=-1, precision=0)
                    white_background = gr.Checkbox(label="White Background", value=True)
                    data_device = gr.Dropdown(label="Data Device", choices=["cuda", "cpu"], value="cuda")
                    eval_checkbox = gr.Checkbox(label="Eval", value=False)

        with gr.Accordion("Pipeline Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    convert_SHs_python = gr.Checkbox(label="Convert SHs Python", value=False)
                    compute_cov3D_python = gr.Checkbox(label="Compute Cov3D Python", value=False)
                    debug = gr.Checkbox(label="Debug", value=False)

        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    iterations = gr.Number(label="Iterations", value=1000, precision=0)
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00016)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000016)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.01)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=30000, precision=0)
                with gr.Column():
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100, precision=0)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000, precision=0)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500, precision=0)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000, precision=0)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
                    random_background = gr.Checkbox(label="Random Background", value=False)

        start_button = gr.Button("Start Training")

        # Holds the model path returned by train() until "Load 3D Model" is clicked.
        model_path_state = gr.State()

        video_output = gr.Video(
            label="Training Progress",
            height=400,  # fixed height, matching the .fixed-size-video CSS
            width="100%",  # full width of the container
            autoplay=False,
            show_label=True,
            container=True,
            elem_classes="fixed-size-video",
        )
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)

        # The order of these inputs must match train()'s positional parameters
        # after the dataset path.
        start_button.click(
            fn=handle_training_complete,
            inputs=[
                dataset_dropdown, sh_degree, model_path_input, images, resolution,
                white_background, data_device, eval_checkbox,
                convert_SHs_python, compute_cov3D_python, debug,
                iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
                densify_from_iter, densify_until_iter, densify_grad_threshold, random_background,
            ],
            outputs=[video_output, load_model_button, output, model_path_state],
        )

        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output,
        )

    return gs_demo