Extend app to accept multiple images for scene reconstruction
Browse filesThis update modifies the Flash3D application to accept multiple input images for reconstructing a complete 3D scene. Key changes include:
Gradio Interface Changes:
Updated the input component to accept multiple images using gr.Images.
Added a gallery to display the preprocessed images.
Model Inference Updates:
Modified the preprocess function to handle multiple images.
Updated the reconstruct_and_export function to iterate over all uploaded images for scene reconstruction.
Adjusted logic to combine or save outputs from multiple views.
User Interaction:
Users can now upload several images from different angles to create a richer reconstruction.
Added sliders for adjustable parameters: padding and number of Gaussians per pixel.
This new functionality aims to provide a more comprehensive 3D reconstruction, allowing for richer inputs from multiple perspectives and generating a better quality model.
@@ -9,6 +9,7 @@ import torchvision.transforms as TT
|
|
9 |
import torchvision.transforms.functional as TTF
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import numpy as np
|
|
|
12 |
|
13 |
from networks.gaussian_predictor import GaussianPredictor
|
14 |
from util.vis3d import save_ply
|
@@ -54,50 +55,63 @@ def main():
|
|
54 |
to_tensor = TT.ToTensor() # Convert image to tensor
|
55 |
|
56 |
# Function to check if an image is uploaded by the user
|
57 |
-
def check_input_image(
|
58 |
-
print("[DEBUG] Checking input
|
59 |
-
if
|
60 |
-
print("[ERROR] No
|
61 |
-
raise gr.Error("No
|
62 |
-
print("[INFO] Input
|
63 |
-
|
64 |
-
# Function to preprocess the input
|
65 |
-
def preprocess(
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
interpolation=TT.InterpolationMode.BICUBIC
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
# Function to reconstruct the 3D model from the input image and export it as a PLY file
|
79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
80 |
-
def reconstruct_and_export(
|
81 |
"""
|
82 |
-
Passes
|
83 |
"""
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
# Export the reconstruction to a PLY file
|
96 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
97 |
-
save_ply(
|
98 |
print("[INFO] Reconstruction and export complete.")
|
99 |
|
100 |
-
return ply_out_path
|
101 |
|
102 |
# Path to save the output PLY file
|
103 |
ply_out_path = f'./mesh.ply'
|
@@ -120,18 +134,20 @@ def main():
|
|
120 |
with gr.Row(variant="panel"):
|
121 |
with gr.Column(scale=1):
|
122 |
with gr.Row():
|
123 |
-
# Input
|
124 |
-
|
125 |
-
label="Input
|
126 |
-
image_mode="RGBA",
|
127 |
-
sources="upload",
|
128 |
-
type="pil",
|
129 |
-
elem_id="
|
|
|
|
|
130 |
)
|
131 |
with gr.Row():
|
132 |
# Sliders for configurable parameters
|
133 |
-
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=
|
134 |
-
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
|
135 |
with gr.Row():
|
136 |
# Button to trigger the generation process
|
137 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
@@ -147,35 +163,35 @@ def main():
|
|
147 |
'./demo_examples/re10k_05.jpg',
|
148 |
'./demo_examples/re10k_06.jpg',
|
149 |
],
|
150 |
-
inputs=[
|
151 |
cache_examples=False,
|
152 |
-
label="Examples",
|
153 |
examples_per_page=20,
|
154 |
)
|
155 |
|
156 |
with gr.Row():
|
157 |
-
# Display the preprocessed
|
158 |
-
|
159 |
|
160 |
with gr.Column(scale=2):
|
161 |
with gr.Row():
|
162 |
with gr.Tab("Reconstruction"):
|
163 |
# 3D model viewer to display the reconstructed model
|
164 |
output_model = gr.Model3D(
|
165 |
-
height=512,
|
166 |
label="Output Model",
|
167 |
-
interactive=False
|
168 |
)
|
169 |
|
170 |
# Define the workflow for the Generate button
|
171 |
-
submit.click(fn=check_input_image, inputs=[
|
172 |
fn=preprocess,
|
173 |
-
inputs=[
|
174 |
-
outputs=[
|
175 |
).success(
|
176 |
fn=reconstruct_and_export,
|
177 |
-
inputs=[
|
178 |
-
outputs=[output_model],
|
179 |
)
|
180 |
|
181 |
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|
|
|
9 |
import torchvision.transforms.functional as TTF
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import numpy as np
|
12 |
+
from einops import rearrange
|
13 |
|
14 |
from networks.gaussian_predictor import GaussianPredictor
|
15 |
from util.vis3d import save_ply
|
|
|
55 |
to_tensor = TT.ToTensor() # Convert image to tensor
|
56 |
|
57 |
# Function to check if an image is uploaded by the user
|
58 |
+
def check_input_image(input_images):
|
59 |
+
print("[DEBUG] Checking input images...")
|
60 |
+
if not input_images or len(input_images) == 0:
|
61 |
+
print("[ERROR] No images uploaded!")
|
62 |
+
raise gr.Error("No images uploaded!")
|
63 |
+
print("[INFO] Input images are valid.")
|
64 |
+
|
65 |
+
# Function to preprocess the input images before passing them to the model
|
66 |
+
def preprocess(images, padding_value):
|
67 |
+
processed_images = []
|
68 |
+
for image in images:
|
69 |
+
# Resize and pad each image
|
70 |
+
print("[DEBUG] Preprocessing image...")
|
71 |
+
image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
|
72 |
+
pad_border_fn = TT.Pad((padding_value, padding_value))
|
73 |
+
image = pad_border_fn(image)
|
74 |
+
print("[INFO] Image preprocessing complete.")
|
75 |
+
processed_images.append(image)
|
76 |
+
return processed_images
|
77 |
+
|
78 |
+
# Function to reconstruct the 3D model from the input images and export it as a PLY file
|
|
|
79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
80 |
+
def reconstruct_and_export(images, num_gauss):
|
81 |
"""
|
82 |
+
Passes images through model, outputs reconstruction in form of a dict of tensors.
|
83 |
"""
|
84 |
+
outputs_list = []
|
85 |
+
for image in images:
|
86 |
+
print("[DEBUG] Starting reconstruction and export...")
|
87 |
+
# Convert the preprocessed image to a tensor and move it to the specified device
|
88 |
+
image = to_tensor(image).to(device).unsqueeze(0) # Add a batch dimension to the image tensor
|
89 |
+
inputs = {
|
90 |
+
("color_aug", 0, 0): image, # The input dictionary expected by the model
|
91 |
+
}
|
92 |
+
|
93 |
+
# Pass the image through the model to get the output
|
94 |
+
print("[INFO] Passing image through the model...")
|
95 |
+
outputs = model(inputs) # Perform inference to get model outputs
|
96 |
+
outputs_list.append(outputs)
|
97 |
+
|
98 |
+
# Combine or process outputs from multiple images here if necessary
|
99 |
+
# For now, we'll just save the first one for illustration
|
100 |
+
gauss_means = outputs_list[0][('gauss_means', 0, 0)]
|
101 |
+
if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
|
102 |
+
adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
|
103 |
+
print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
|
104 |
+
num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
|
105 |
+
|
106 |
+
# Debugging tensor shape
|
107 |
+
print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
|
108 |
|
109 |
# Export the reconstruction to a PLY file
|
110 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
111 |
+
save_ply(outputs_list[0], ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
|
112 |
print("[INFO] Reconstruction and export complete.")
|
113 |
|
114 |
+
return ply_out_path # Return the path to the saved PLY file
|
115 |
|
116 |
# Path to save the output PLY file
|
117 |
ply_out_path = f'./mesh.ply'
|
|
|
134 |
with gr.Row(variant="panel"):
|
135 |
with gr.Column(scale=1):
|
136 |
with gr.Row():
|
137 |
+
# Input images component for the user to upload multiple images
|
138 |
+
input_images = gr.Images(
|
139 |
+
label="Input Images",
|
140 |
+
image_mode="RGBA", # Accept RGBA images
|
141 |
+
sources="upload", # Allow users to upload images
|
142 |
+
type="pil", # The images are returned as PIL images
|
143 |
+
elem_id="content_images",
|
144 |
+
tool="editor", # Optional, for editing images
|
145 |
+
multiple=True # Allow multiple image uploads
|
146 |
)
|
147 |
with gr.Row():
|
148 |
# Sliders for configurable parameters
|
149 |
+
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=1) # Slider to set the number of Gaussians per pixel
|
150 |
+
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32) # Slider to set padding value
|
151 |
with gr.Row():
|
152 |
# Button to trigger the generation process
|
153 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
|
|
163 |
'./demo_examples/re10k_05.jpg',
|
164 |
'./demo_examples/re10k_06.jpg',
|
165 |
],
|
166 |
+
inputs=[input_images], # Load the example images into the input component
|
167 |
cache_examples=False,
|
168 |
+
label="Examples", # Label for the examples section
|
169 |
examples_per_page=20,
|
170 |
)
|
171 |
|
172 |
with gr.Row():
|
173 |
+
# Display the preprocessed images (after resizing and padding)
|
174 |
+
processed_images = gr.Gallery(label="Processed Images", interactive=False) # Output component to show the processed images
|
175 |
|
176 |
with gr.Column(scale=2):
|
177 |
with gr.Row():
|
178 |
with gr.Tab("Reconstruction"):
|
179 |
# 3D model viewer to display the reconstructed model
|
180 |
output_model = gr.Model3D(
|
181 |
+
height=512, # Height of the 3D model viewer
|
182 |
label="Output Model",
|
183 |
+
interactive=False # The viewer is not interactive
|
184 |
)
|
185 |
|
186 |
# Define the workflow for the Generate button
|
187 |
+
submit.click(fn=check_input_image, inputs=[input_images]).success(
|
188 |
fn=preprocess,
|
189 |
+
inputs=[input_images, padding_value], # Pass the input images and padding value to the preprocess function
|
190 |
+
outputs=[processed_images], # Output the processed images
|
191 |
).success(
|
192 |
fn=reconstruct_and_export,
|
193 |
+
inputs=[processed_images, num_gauss], # Pass the processed images and number of Gaussians to the reconstruction function
|
194 |
+
outputs=[output_model], # Output the reconstructed 3D model
|
195 |
)
|
196 |
|
197 |
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|