Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,321 Bytes
6db5fd9 d38e5ca 6db5fd9 665b2f0 6db5fd9 665b2f0 6db5fd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
from gs_train import train
import os
# Name of the sub-directory, joined onto the cache_path given to gs_demo_tab,
# whose immediate sub-directories are offered as selectable datasets.
DATASET_DIR = "colmap_data"
def get_dataset_folders(datasets_path):
    """Return the names of the immediate sub-directories of *datasets_path*.

    ``os.listdir`` yields entries in arbitrary order, so the names are sorted
    to give callers (e.g. one that pre-selects the first entry) a stable list.

    Args:
        datasets_path: Directory to scan.

    Returns:
        A sorted list of directory names (not full paths). An empty list is
        returned when *datasets_path* does not exist or is not a directory.
    """
    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        # A missing or invalid dataset root simply means "no datasets yet".
        return []
    return sorted(
        name
        for name in entries
        if os.path.isdir(os.path.join(datasets_path, name))
    )
def gs_demo_tab(cache_path):
    """Build the Gradio tab for the Gaussian Splatting training demo.

    Datasets are the sub-directories of ``os.path.join(cache_path,
    DATASET_DIR)``. The user refreshes and picks a dataset, adjusts the model,
    pipeline and optimization parameters, and starts training via ``train``.
    The video returned by ``train`` is shown in the page, and the returned
    model path is kept in a ``gr.State`` so it can be loaded into a
    ``gr.Model3D`` viewer on demand.

    Args:
        cache_path: Root directory containing the ``DATASET_DIR`` folder.

    Returns:
        The ``gr.Blocks`` instance holding the whole tab.
    """
    dataset_path = os.path.join(cache_path, DATASET_DIR)

    def get_context():
        # NOTE(review): Gradio documents ``delete_cache`` as a (frequency, age)
        # pair in seconds, so (True, True) means "every second, delete cached
        # files older than one second". That may remove the training video or
        # model before the browser has fetched it -- confirm this is intended.
        return gr.Blocks(delete_cache=(True, True))

    with get_context() as gs_demo:
        gr.Markdown("""
<style>
.fixed-size-video video {
max-height: 400px !important;
height: 400px !important;
object-fit: contain;
}
</style>
""")
        gr.Markdown("# Gaussian Splatting Training Demo")
        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        # No selection until "Refresh Datasets" populates the choices;
        # handle_training_complete rejects an empty selection explicitly.
        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value=None)

        def update_dataset_dropdown():
            """Rescan the dataset directory and rebuild the dropdown."""
            print("update_dataset_dropdown, cache_path", cache_path)
            dataset_folders = get_dataset_folders(dataset_path)
            print("dataset_folders", dataset_folders)
            # Only pre-select a dataset when at least one is available.
            default_value = dataset_folders[0] if dataset_folders else None
            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)

        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)

        with gr.Accordion("Model Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    sh_degree = gr.Number(label="SH Degree", value=3)
                    model_path_box = gr.Textbox(label="Model Path", value="")
                    images = gr.Textbox(label="Images", value="images")
                    resolution = gr.Number(label="Resolution", value=-1)
                    white_background = gr.Checkbox(label="White Background", value=True)
                    data_device = gr.Dropdown(label="Data Device", choices=["cuda", "cpu"], value="cuda")
                    # Named to avoid shadowing the builtin ``eval``.
                    eval_checkbox = gr.Checkbox(label="Eval", value=False)

        with gr.Accordion("Pipeline Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    convert_SHs_python = gr.Checkbox(label="Convert SHs Python", value=False)
                    compute_cov3D_python = gr.Checkbox(label="Compute Cov3D Python", value=False)
                    debug = gr.Checkbox(label="Debug", value=False)

        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    iterations = gr.Number(label="Iterations", value=1000)
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00016)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000016)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.01)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=30000)
                with gr.Column():
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
                    random_background = gr.Checkbox(label="Random Background", value=False)

        start_button = gr.Button("Start Training")
        # Holds the path of the most recently trained model between events.
        model_path_state = gr.State()
        video_output = gr.Video(
            label="Training Progress",
            height=400,  # Fixed height
            width="100%",  # Full width of container
            autoplay=False,
            show_label=True,
            container=True,
            elem_classes="fixed-size-video",  # Targeted by the CSS above
        )
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)

        def handle_training_complete(selected_folder, *args):
            """Train on the selected dataset and update the result widgets.

            Args:
                selected_folder: Dataset directory name chosen in the dropdown.
                *args: Remaining UI values, forwarded unchanged to ``train`` in
                    the order listed in the ``start_button.click`` inputs.

            Returns:
                Values for ``[video_output, load_model_button, output,
                model_path_state]``.

            Raises:
                gr.Error: If no dataset has been selected.
            """
            if not selected_folder:
                # Without this guard os.path.join(dataset_path, None) raises a
                # TypeError, and "" would silently train on the dataset root.
                raise gr.Error("Please select a dataset before starting training.")
            selected_data_path = os.path.join(dataset_path, selected_folder)
            video_path, trained_model_path = train(selected_data_path, *args)
            return [
                video_path,  # video output
                gr.Button(value="Load 3D Model", interactive=True),  # enable loading
                gr.Model3D(visible=False),  # keep viewer hidden until requested
                trained_model_path,  # stored in model_path_state
            ]

        def load_model(stored_model_path):
            """Show the stored model in the 3D viewer, or hide it if there is none."""
            if not stored_model_path:
                return gr.Model3D(visible=False)
            return gr.Model3D(value=stored_model_path, visible=True)

        start_button.click(
            fn=handle_training_complete,
            inputs=[
                dataset_dropdown, sh_degree, model_path_box, images, resolution,
                white_background, data_device, eval_checkbox,
                convert_SHs_python, compute_cov3D_python, debug,
                iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
                densify_from_iter, densify_until_iter, densify_grad_threshold, random_background,
            ],
            outputs=[video_output, load_model_button, output, model_path_state],
        )

        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output,
        )

    return gs_demo