# Hugging Face Spaces page header: "Running on Zero"
import gradio as gr | |
from gs_train import train | |
import os | |
# Name of the sub-folder (under the cache path) that holds one directory per dataset.
DATASET_DIR = "colmap_data"


def get_dataset_folders(datasets_path):
    """Return the names of the immediate sub-directories of *datasets_path*.

    ``os.listdir`` yields entries in arbitrary order, so the result is sorted
    to keep the dataset dropdown (and its default first entry) stable.

    Args:
        datasets_path: Directory to scan.

    Returns:
        Sorted list of sub-directory names; ``[]`` when *datasets_path* does
        not exist or is not a directory.
    """
    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        # A missing dataset root (or a file in its place) simply means
        # "no datasets yet" for the UI.
        return []
    return sorted(
        entry for entry in entries
        if os.path.isdir(os.path.join(datasets_path, entry))
    )
def gs_demo_tab(cache_path):
    """Build the Gaussian Splatting training demo as a ``gr.Blocks``.

    The UI lets the user pick a dataset folder found under
    ``<cache_path>/<DATASET_DIR>``, edit model / pipeline / optimization
    parameters, run ``train`` on the selected dataset, view the video it
    returns and then load the returned model file into a 3D viewer.

    Args:
        cache_path: Root directory whose ``DATASET_DIR`` sub-folder contains
            one sub-directory per dataset.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    dataset_path = os.path.join(cache_path, DATASET_DIR)

    def update_dataset_dropdown():
        """Re-scan the dataset directory and return a refreshed dropdown."""
        print("update_dataset_dropdown, cache_path", cache_path)
        dataset_folders = get_dataset_folders(dataset_path)
        print("dataset_folders", dataset_folders)
        # Only pre-select a value when at least one dataset is available.
        default_value = dataset_folders[0] if dataset_folders else None
        return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)

    def handle_training_complete(selected_folder, *args):
        """Run ``train`` on the selected dataset and update the outputs.

        ``args`` are the parameter component values, forwarded positionally to
        ``train`` in the order they are listed in ``start_button.click``.

        Returns:
            ``[video_path, enabled load button, hidden Model3D, model path]``,
            matching the ``outputs`` of ``start_button.click``.

        Raises:
            gr.Error: if no dataset has been selected.
        """
        # os.path.join raises TypeError on None, and an empty string would
        # silently point training at the dataset root directory itself.
        if not selected_folder:
            raise gr.Error("Please select a dataset before starting training.")
        selected_data_path = os.path.join(dataset_path, selected_folder)
        # train is unpacked as (video_path, model_path) here.
        video_path, trained_model_path = train(selected_data_path, *args)
        return [
            video_path,  # video output
            gr.Button(value="Load 3D Model", interactive=True),  # enable loading
            gr.Model3D(visible=False),  # keep the 3D viewer hidden until requested
            trained_model_path,  # remembered in model_path_state
        ]

    def load_model(saved_model_path):
        """Show the trained model in the 3D viewer, or keep it hidden if none."""
        if not saved_model_path:
            return gr.Model3D(visible=False)
        return gr.Model3D(value=saved_model_path, visible=True)

    with gr.Blocks(delete_cache=(True, True)) as gs_demo:
        # Styles are injected through Markdown (not Blocks(css=...)) so they
        # still travel with this Blocks if it is rendered inside a parent app.
        gr.Markdown(
            "<style>\n"
            ".fixed-size-video video {\n"
            "    max-height: 400px !important;\n"
            "    height: 400px !important;\n"
            "    object-fit: contain;\n"
            "}\n"
            "</style>"
        )
        gr.Markdown("# Gaussian Splatting Training Demo")

        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        # Pre-populate with the datasets present now; "Refresh Datasets"
        # re-scans later. None (rather than "") keeps the value inside choices.
        initial_folders = get_dataset_folders(dataset_path)
        dataset_dropdown = gr.Dropdown(
            label="Select Dataset",
            choices=initial_folders,
            value=initial_folders[0] if initial_folders else None,
        )
        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)

        # Integer-valued parameters use precision=0 so Gradio passes ints to
        # `train` (with precision=None it passes the raw, float-typed value).
        with gr.Accordion("Model Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    sh_degree = gr.Number(label="SH Degree", value=3, precision=0)
                    model_path = gr.Textbox(label="Model Path", value="")
                    images = gr.Textbox(label="Images", value="images")
                    resolution = gr.Number(label="Resolution", value=-1, precision=0)
                    white_background = gr.Checkbox(label="White Background", value=True)
                    data_device = gr.Dropdown(label="Data Device", choices=["cuda", "cpu"], value="cuda")
                    # Named eval_checkbox to avoid shadowing the builtin eval().
                    eval_checkbox = gr.Checkbox(label="Eval", value=False)

        with gr.Accordion("Pipeline Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    convert_SHs_python = gr.Checkbox(label="Convert SHs Python", value=False)
                    compute_cov3D_python = gr.Checkbox(label="Compute Cov3D Python", value=False)
                    debug = gr.Checkbox(label="Debug", value=False)

        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    iterations = gr.Number(label="Iterations", value=1000, precision=0)
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00016)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000016)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.01)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=30000, precision=0)
                with gr.Column():
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100, precision=0)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000, precision=0)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500, precision=0)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000, precision=0)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
                    random_background = gr.Checkbox(label="Random Background", value=False)

        start_button = gr.Button("Start Training")
        # Holds the model path returned by the last training run.
        model_path_state = gr.State()

        video_output = gr.Video(
            label="Training Progress",
            height=400,  # fixed height
            width="100%",  # full width of the container
            autoplay=False,
            show_label=True,
            container=True,
            elem_classes="fixed-size-video",  # targeted by the injected CSS
        )
        # Disabled until a training run has produced a model.
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)

        # Input order must match the positional parameters expected by `train`
        # (after the dataset path, which is built from the dropdown value).
        start_button.click(
            fn=handle_training_complete,
            inputs=[
                dataset_dropdown, sh_degree, model_path, images, resolution, white_background, data_device, eval_checkbox,
                convert_SHs_python, compute_cov3D_python, debug,
                iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
                densify_from_iter, densify_until_iter, densify_grad_threshold, random_background,
            ],
            outputs=[video_output, load_model_button, output, model_path_state],
        )
        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output,
        )
    return gs_demo