Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,326 Bytes
6db5fd9 12ecc72 6db5fd9 12ecc72 d38e5ca 12ecc72 6db5fd9 b46de64 6db5fd9 0c0ba9f 6db5fd9 665b2f0 6db5fd9 b46de64 6db5fd9 b46de64 6db5fd9 665b2f0 6db5fd9 b46de64 6db5fd9 b46de64 6db5fd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import gradio as gr
from gs_train import train
import os
from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR
def get_dataset_folders(datasets_path):
    """Return the names of the immediate sub-directories of *datasets_path*.

    ``os.listdir`` yields entries in arbitrary order. The names are therefore
    sorted in natural order (``run_2`` before ``run_10``), so the result (and
    any "first entry" default chosen from it) is deterministic.

    Args:
        datasets_path: Directory whose child folders are listed.

    Returns:
        A sorted list of folder names. Returns an empty list when
        *datasets_path* does not exist or is not a directory.
    """
    import re

    def natural_key(name):
        # re.split with a capturing group alternates text/digits, so odd
        # indices are always digit runs; compare those numerically. The raw
        # name breaks ties such as "run_01" vs "run_1" deterministically.
        parts = re.split(r"(\d+)", name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        return []
    folders = [f for f in entries if os.path.isdir(os.path.join(datasets_path, f))]
    return sorted(folders, key=natural_key)
# CSS injected via Markdown so the training video keeps a fixed 400px height.
_VIDEO_CSS = """
<style>
.fixed-size-video video {
max-height: 400px !important;
height: 400px !important;
object-fit: contain;
}
</style>
"""

# Centered page title.
_TITLE_HTML = """
<h2 style="text-align: center;">3D Gaussian Splatting Reconstruction</h2>
"""

# Usage instructions shown above the controls.
_INSTRUCTIONS_HTML = '''
<div style="padding: 10px; background-color: #e9f7ef; border-radius: 5px; margin-bottom: 10px;">
<h3>Instructions for 3DGS Demo</h3>
<ul style="text-align: left; color: #333;">
<li>Make sure to press "Refresh Datasets" to obtain an updated list of datasets from Stage 1. They are in the format run_0, run_1, run_...</li>
<li>Adjust optimization parameters if needed, and press "Start Training".</li>
<li>It is recommended to use 7k iterations to avoid exceeding the 3-minute limit. If you still exceed the limit, reduce the number of iterations.</li>
<li>After reconstruction is finished, you can view it as a small video generated or download the full 3DGS reconstruction below the video.</li>
<li>Press "Load 3D Model" to view the full 3DGS reconstruction.</li>
</ul>
<p><b>Note: 3DGS '.ply' models could be heavy, so it may take some time to download and view them in the 3D model section.</b></p>
</div>
'''


def gs_demo_tab():
    """Build the 3D Gaussian Splatting reconstruction tab.

    The tab supports the following workflow:

    - Refresh and select a dataset folder under
      ``os.path.join(CACHE_PATH, DATASET_DIR)``.
    - Tune the optimization parameters and run ``train`` on the selected folder.
    - Show the returned video.
    - On demand, load the returned model path into a ``gr.Model3D`` viewer.

    Returns:
        The ``gr.Blocks`` instance containing the tab's UI.
    """
    dataset_path = os.path.join(CACHE_PATH, DATASET_DIR)

    def get_context():
        # NOTE(review): per the Gradio docs, delete_cache is (frequency, age)
        # in seconds, so (True, True) behaves like (1, 1). Generated files
        # (video, .ply) may be purged ~1s after creation; confirm intended.
        return gr.Blocks(delete_cache=(True, True))

    with get_context() as gs_demo:
        gr.Markdown(_VIDEO_CSS)
        gr.Markdown(_TITLE_HTML)
        gr.Markdown(_INSTRUCTIONS_HTML)

        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")

        def update_dataset_dropdown():
            """Rescan the dataset directory and rebuild the dropdown."""
            print("update_dataset_dropdown, cache_path", CACHE_PATH)
            dataset_folders = get_dataset_folders(dataset_path)
            print("dataset_folders", dataset_folders)
            # Only preselect a dataset when at least one folder exists.
            default_value = dataset_folders[0] if dataset_folders else None
            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)

        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)

        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00032)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000032)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.02)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=15000)
                with gr.Column():
                    # Single Feature LR control. A duplicate (value 0.005) used
                    # to be created and then shadowed, so it was never wired in.
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
                    iterations = gr.Slider(label="Iterations", value=7000, minimum=1, maximum=15000, step=5)

        start_button = gr.Button("Start Training")

        # Holds the model path returned by the last training run so the
        # "Load 3D Model" button can display it later.
        model_path_state = gr.State()

        video_output = gr.Video(
            label="Training Progress",
            height=400,
            width="100%",
            autoplay=False,
            show_label=True,
            container=True,
            elem_classes="fixed-size-video",  # matched by the rule in _VIDEO_CSS
        )
        # Disabled until a training run has produced a model.
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)

        def handle_training_complete(selected_folder, *args):
            """Train on the selected dataset and update the output widgets.

            Args:
                selected_folder: Dataset folder name chosen in the dropdown.
                *args: Iterations and optimization parameters, forwarded to
                    ``train`` in the order listed in ``start_button.click``.

            Returns:
                Values for ``[video_output, load_model_button, output,
                model_path_state]``.

            Raises:
                gr.Error: If no valid dataset folder is selected. This covers
                    the initial empty value, ``None`` when no datasets were
                    found, and stale or non-listed names.
            """
            if not selected_folder:
                raise gr.Error('Please press "Refresh Datasets" and select a dataset before training.')
            # Only accept names that are actual folders in the dataset
            # directory, so an empty or foreign value can never point
            # training at the root or outside it.
            if selected_folder not in get_dataset_folders(dataset_path):
                raise gr.Error(f"Dataset folder not found: {selected_folder}")
            selected_data_path = os.path.join(dataset_path, selected_folder)
            video_path, model_path = train(selected_data_path, *args)
            return [
                video_path,
                gr.Button(value="Load 3D Model", interactive=True),
                gr.Model3D(visible=False),  # keep the viewer hidden until requested
                model_path,
            ]

        def load_model(model_path):
            """Show the trained model, or keep the viewer hidden if none exists."""
            if not model_path:
                return gr.Model3D(visible=False)
            return gr.Model3D(value=model_path, visible=True)

        start_button.click(
            fn=handle_training_complete,
            inputs=[
                dataset_dropdown, iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
                densify_from_iter, densify_until_iter, densify_grad_threshold,
            ],
            outputs=[video_output, load_model_button, output, model_path_state],
        )

        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output,
        )

    return gs_demo