import os

import gradio as gr
from gs_train import train
from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR

# ``gr.Blocks(delete_cache=...)`` expects ``(frequency, age)`` in seconds.
# Do not pass booleans here: ``(True, True)`` coerces to ``(1, 1)``, i.e. purge
# every second any cached file older than one second -- fast enough to remove
# the training video or a heavy '.ply' model before the browser has fetched it.
_DELETE_CACHE = (3600, 3600)

# Markdown shown at the top of the tab (text kept verbatim).
_TITLE_MD = "\n\n3D Gaussian Splatting Reconstruction\n\n"
_INSTRUCTIONS_MD = (
    "\n\nInstructions for 3DGS Demo\n\n"
    "Note: 3DGS '.ply' models could be heavy, so it may take some time to "
    "download and view them in the 3D model section.\n\n"
)


def get_dataset_folders(datasets_path):
    """Return the sorted names of the immediate sub-directories of *datasets_path*.

    Returns an empty list when *datasets_path* does not exist or is not a
    directory. Sorting makes the dropdown order (and therefore its default
    entry) deterministic, since ``os.listdir`` returns entries in arbitrary
    order.
    """
    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        return []
    return sorted(
        name for name in entries
        if os.path.isdir(os.path.join(datasets_path, name))
    )


def gs_demo_tab():
    """Build and return the Gradio Blocks UI for the 3DGS training demo.

    The tab lets the user pick a dataset folder located under
    ``os.path.join(CACHE_PATH, DATASET_DIR)``, tune optimization parameters,
    run ``train`` on the selected folder, watch the returned video, and then
    load the returned model file into a ``gr.Model3D`` viewer.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    dataset_path = os.path.join(CACHE_PATH, DATASET_DIR)

    with gr.Blocks(delete_cache=_DELETE_CACHE) as gs_demo:
        # NOTE(review): this Markdown only holds a single space; it looks like
        # leftover/placeholder content -- confirm whether it is still needed.
        gr.Markdown(" ")
        # Centered title
        gr.Markdown(_TITLE_MD)
        # Instructions
        gr.Markdown(_INSTRUCTIONS_MD)

        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        # Starts empty; the refresh button fills in the choices. ``None`` (not
        # "") avoids an initial value that is not among the choices.
        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value=None)

        def update_dataset_dropdown():
            """Re-scan the dataset directory and return a refreshed dropdown."""
            print("update_dataset_dropdown, cache_path", CACHE_PATH)
            dataset_folders = get_dataset_folders(dataset_path)
            print("dataset_folders", dataset_folders)
            # Only pre-select an entry when at least one folder is available.
            default_value = dataset_folders[0] if dataset_folders else None
            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)

        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)

        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00032)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000032)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.02)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=15000)
                with gr.Column():
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)

        iterations = gr.Slider(label="Iterations", value=7000, minimum=1, maximum=15000, step=5)
        start_button = gr.Button("Start Training")

        # Holds the model path returned by ``train`` until the user asks to view it.
        model_path_state = gr.State()

        video_output = gr.Video(
            label="Training Progress",
            height=400,  # Fixed height
            width="100%",  # Full width of container
            autoplay=False,  # Prevent autoplay
            show_label=True,
            container=True,
            elem_classes="fixed-size-video",  # Custom class for potential CSS
        )
        # Disabled until a training run has produced a model.
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)

        def handle_training_complete(selected_folder, *args):
            """Train on the selected dataset and update the UI with the results.

            ``args`` are the parameter widget values in the order listed in
            ``train_inputs`` below; they are forwarded to ``train`` unchanged.

            Raises:
                gr.Error: if no dataset has been selected.
            """
            if not selected_folder:
                raise gr.Error("Please select a dataset before starting training.")
            selected_data_path = os.path.join(dataset_path, selected_folder)
            video_path, model_path = train(selected_data_path, *args)
            return [
                video_path,  # video output
                gr.Button(value="Load 3D Model", interactive=True),  # enable loading
                gr.Model3D(visible=False),  # keep 3D model hidden until requested
                model_path,  # store model path in state
            ]

        def load_model(model_path):
            """Show the trained model in the 3D viewer, or keep it hidden if none."""
            if not model_path:
                return gr.Model3D(visible=False)
            return gr.Model3D(value=model_path, visible=True)

        # Order matters: it must match the positional parameters ``train`` expects.
        train_inputs = [
            dataset_dropdown,
            iterations,
            position_lr_init,
            position_lr_final,
            position_lr_delay_mult,
            position_lr_max_steps,
            feature_lr,
            opacity_lr,
            scaling_lr,
            rotation_lr,
            percent_dense,
            lambda_dssim,
            densification_interval,
            opacity_reset_interval,
            densify_from_iter,
            densify_until_iter,
            densify_grad_threshold,
        ]

        start_button.click(
            fn=handle_training_complete,
            inputs=train_inputs,
            outputs=[video_output, load_model_button, output, model_path_state],
        )

        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output,
        )

    return gs_demo