File size: 7,326 Bytes
6db5fd9
 
 
 
12ecc72
6db5fd9
 
 
 
 
 
 
12ecc72
d38e5ca
12ecc72
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b46de64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6db5fd9
 
 
 
 
0c0ba9f
6db5fd9
665b2f0
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
b46de64
 
 
 
 
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
b46de64
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665b2f0
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b46de64
6db5fd9
 
b46de64
6db5fd9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
from gs_train import train
import os

from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR

def get_dataset_folders(datasets_path):
    """Return the names of the sub-directories of *datasets_path*, sorted.

    Each sub-directory is treated as one selectable dataset. Plain files are
    skipped. A missing path, or a path that exists but is not a directory,
    yields an empty list so the UI can show an empty dropdown instead of
    crashing.

    Args:
        datasets_path: Directory that contains one folder per dataset.

    Returns:
        Sorted list of sub-directory names (bare names, not full paths).
    """
    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        # Best-effort: no dataset directory yet (or it is a file) -> no choices.
        return []
    # os.listdir order is arbitrary; sort so the dropdown (and its default
    # first entry) stays stable between refreshes.
    return sorted(
        name
        for name in entries
        if os.path.isdir(os.path.join(datasets_path, name))
    )

def gs_demo_tab():
    """Build the "3D Gaussian Splatting Reconstruction" demo UI.

    The tab lets the user pick a dataset folder found under
    ``os.path.join(CACHE_PATH, DATASET_DIR)``, adjust optimization
    parameters, run ``train`` on it, and then view the resulting video and
    load the reconstructed 3D model.

    Returns:
        The ``gr.Blocks`` instance holding the tab's components and event
        wiring.
    """
    # Folder holding one sub-directory per dataset (listed in the dropdown).
    dataset_path = os.path.join(CACHE_PATH, DATASET_DIR)

    # NOTE(review): Gradio documents ``delete_cache`` as
    # (frequency_seconds, age_seconds); ``(True, True)`` is therefore the same
    # as ``(1, 1)`` -- cached files older than ~1 s are purged every ~1 s.
    # Confirm this aggressive cleanup is intended, since large .ply outputs
    # may take longer than that to download.
    with gr.Blocks(delete_cache=(True, True)) as gs_demo:
        gr.Markdown("""
        <style>
        .fixed-size-video video {
            max-height: 400px !important;
            height: 400px !important;
            object-fit: contain;
        }
        </style>
        """)
        
        # Centered title
        gr.Markdown("""
        <h2 style="text-align: center;">3D Gaussian Splatting Reconstruction</h2>
        """)

        # Instructions
        gr.Markdown('''
        <div style="padding: 10px; background-color: #e9f7ef; border-radius: 5px; margin-bottom: 10px;">
            <h3>Instructions for 3DGS Demo</h3>
            <ul style="text-align: left; color: #333;">
                <li>Make sure to press "Refresh Datasets" to obtain an updated list of datasets from Stage 1. They are in the format run_0, run_1, run_...</li>
                <li>Adjust optimization parameters if needed, and press "Start Training".</li>
                <li>It is recommended to use 7k iterations to avoid exceeding the 3-minute limit. If you still exceed the limit, reduce the number of iterations.</li>
                <li>After reconstruction is finished, you can view it as a small video generated or download the full 3DGS reconstruction below the video.</li>
                <li>Press "Load 3D Model" to view the full 3DGS reconstruction.</li>
            </ul>
            <p><b>Note: 3DGS '.ply' models could be heavy, so it may take some time to download and view them in the 3D model section.</b></p>
        </div>
        ''')

        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        # Start with no selection (None rather than "", which is not one of the
        # choices); the list is filled when "Refresh Datasets" is pressed.
        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value=None)

        def update_dataset_dropdown():
            """Re-scan ``dataset_path`` and return a refreshed dataset dropdown."""
            print("update_dataset_dropdown, cache_path", CACHE_PATH)
            dataset_folders = get_dataset_folders(dataset_path)
            print("dataset_folders", dataset_folders)
            # Preselect the first folder only when at least one exists.
            default_value = dataset_folders[0] if dataset_folders else None
            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)
        
        # Refresh the dataset list on demand.
        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)
        
        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00032)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000032)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.02)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=15000)
                with gr.Column():
                    # Single "Feature LR" field: a duplicate (value 0.005) used to
                    # be rendered in the first column but was never wired to
                    # training because this assignment overwrote it.
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
        # NOTE(review): with minimum=1 and step=5 the reachable values are
        # 1, 6, 11, ...; the default 7000 is not on that grid -- confirm the
        # intended step/minimum.
        iterations = gr.Slider(label="Iterations", value=7000, minimum=1, maximum=15000, step=5)
        
        start_button = gr.Button("Start Training")
        
        # Holds the model path returned by train() until "Load 3D Model" is pressed.
        model_path_state = gr.State()
        
        # Video preview of the result, rendered at a fixed height.
        video_output = gr.Video(
            label="Training Progress", 
            height=400,  # Fixed height
            width="100%",  # Full width of container
            autoplay=False,  # Prevent autoplay
            show_label=True,
            container=True,
            elem_classes="fixed-size-video"  # Matches the CSS injected above
        )
        # Disabled until a training run has produced a model.
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)
        
        def handle_training_complete(selected_folder, *args):
            """Train on the selected dataset and update the result widgets.

            ``*args`` are forwarded to ``train`` unchanged, in the order of the
            ``inputs`` list given to ``start_button.click`` below.

            Returns:
                [video path, re-enabled "Load 3D Model" button, hidden
                Model3D, model path to store in ``model_path_state``].

            Raises:
                gr.Error: if no dataset is selected or its folder no longer
                    exists, so the UI shows a message instead of a traceback.
            """
            if not selected_folder:
                raise gr.Error('Please press "Refresh Datasets" and select a dataset first.')
            selected_data_path = os.path.join(dataset_path, selected_folder)
            if not os.path.isdir(selected_data_path):
                raise gr.Error(f'Dataset "{selected_folder}" was not found. Press "Refresh Datasets" and try again.')
            # train() is expected to return (video_path, model_path).
            video_path, model_path = train(selected_data_path, *args)
            return [
                video_path,           # video output
                gr.Button(value="Load 3D Model", interactive=True),  # enable loading the model
                gr.Model3D(visible=False),  # keep 3D model hidden until requested
                model_path            # store model path in state
            ]
        
        def load_model(model_path):
            """Show the stored model in the Model3D viewer, or keep it hidden if none."""
            if not model_path:
                return gr.Model3D(visible=False)
            return gr.Model3D(value=model_path, visible=True)
        
        # Connect the start training button
        start_button.click(
            fn=handle_training_complete,
            inputs=[
                dataset_dropdown, iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
                densify_from_iter, densify_until_iter, densify_grad_threshold
            ],
            outputs=[video_output, load_model_button, output, model_path_state]
        )
        
        # Connect the load model button
        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output
        )
    return gs_demo