Ryukijano committed on
Commit
af2e32a
·
verified ·
1 Parent(s): 6a66177

Extend app to accept multiple images for scene reconstruction

Browse files

This update modifies the Flash3D application to accept multiple input images for reconstructing a complete 3D scene. Key changes include:

Gradio Interface Changes:

- Updated the input component to accept multiple images using gr.Images.
- Added a gallery to display the preprocessed images.

Model Inference Updates:

- Modified the preprocess function to handle multiple images.
- Updated the reconstruct_and_export function to iterate over all uploaded images for scene reconstruction.
- Adjusted logic to combine or save outputs from multiple views (at present, only the first image's output is exported to the PLY file).

User Interaction:

- Users can now upload several images from different angles to create a richer reconstruction.
- Added sliders for adjustable parameters: padding and number of Gaussians per pixel.

This new functionality aims to provide a more comprehensive 3D reconstruction, allowing for richer inputs from multiple perspectives and generating a better-quality model.

Files changed (1) hide show
  1. app.py +72 -56
app.py CHANGED
@@ -9,6 +9,7 @@ import torchvision.transforms as TT
9
  import torchvision.transforms.functional as TTF
10
  from huggingface_hub import hf_hub_download
11
  import numpy as np
 
12
 
13
  from networks.gaussian_predictor import GaussianPredictor
14
  from util.vis3d import save_ply
@@ -54,50 +55,63 @@ def main():
54
  to_tensor = TT.ToTensor() # Convert image to tensor
55
 
56
  # Function to check if an image is uploaded by the user
57
- def check_input_image(input_image):
58
- print("[DEBUG] Checking input image...")
59
- if input_image is None:
60
- print("[ERROR] No image uploaded!")
61
- raise gr.Error("No image uploaded!")
62
- print("[INFO] Input image is valid.")
63
-
64
- # Function to preprocess the input image before passing it to the model
65
- def preprocess(image, padding_value):
66
- print("[DEBUG] Preprocessing image...")
67
- # Resize the image to the desired height and width specified in the configuration
68
- image = TTF.resize(
69
- image, (cfg.dataset.height, cfg.dataset.width),
70
- interpolation=TT.InterpolationMode.BICUBIC
71
- )
72
- # Apply padding to the image
73
- pad_border_fn = TT.Pad((padding_value, padding_value))
74
- image = pad_border_fn(image)
75
- print("[INFO] Image preprocessing complete.")
76
- return image
77
-
78
- # Function to reconstruct the 3D model from the input image and export it as a PLY file
79
  @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
- def reconstruct_and_export(image, num_gauss):
81
  """
82
- Passes image through model, outputs reconstruction in form of a dict of tensors.
83
  """
84
- print("[DEBUG] Starting reconstruction and export...")
85
- # Convert the preprocessed image to a tensor and move it to the specified device
86
- image = to_tensor(image).to(device).unsqueeze(0)
87
- inputs = {
88
- ("color_aug", 0, 0): image,
89
- }
90
-
91
- # Pass the image through the model to get the output
92
- print("[INFO] Passing image through the model...")
93
- outputs = model(inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # Export the reconstruction to a PLY file
96
  print(f"[INFO] Saving output to {ply_out_path}...")
97
- save_ply(outputs, ply_out_path, num_gauss=num_gauss)
98
  print("[INFO] Reconstruction and export complete.")
99
 
100
- return ply_out_path
101
 
102
  # Path to save the output PLY file
103
  ply_out_path = f'./mesh.ply'
@@ -120,18 +134,20 @@ def main():
120
  with gr.Row(variant="panel"):
121
  with gr.Column(scale=1):
122
  with gr.Row():
123
- # Input image component for the user to upload an image
124
- input_image = gr.Image(
125
- label="Input Image",
126
- image_mode="RGBA",
127
- sources="upload",
128
- type="pil",
129
- elem_id="content_image",
 
 
130
  )
131
  with gr.Row():
132
  # Sliders for configurable parameters
133
- num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
134
- padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
135
  with gr.Row():
136
  # Button to trigger the generation process
137
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
@@ -147,35 +163,35 @@ def main():
147
  './demo_examples/re10k_05.jpg',
148
  './demo_examples/re10k_06.jpg',
149
  ],
150
- inputs=[input_image],
151
  cache_examples=False,
152
- label="Examples",
153
  examples_per_page=20,
154
  )
155
 
156
  with gr.Row():
157
- # Display the preprocessed image (after resizing and padding)
158
- processed_image = gr.Image(label="Processed Image", interactive=False)
159
 
160
  with gr.Column(scale=2):
161
  with gr.Row():
162
  with gr.Tab("Reconstruction"):
163
  # 3D model viewer to display the reconstructed model
164
  output_model = gr.Model3D(
165
- height=512,
166
  label="Output Model",
167
- interactive=False
168
  )
169
 
170
  # Define the workflow for the Generate button
171
- submit.click(fn=check_input_image, inputs=[input_image]).success(
172
  fn=preprocess,
173
- inputs=[input_image, padding_value],
174
- outputs=[processed_image],
175
  ).success(
176
  fn=reconstruct_and_export,
177
- inputs=[processed_image, num_gauss],
178
- outputs=[output_model],
179
  )
180
 
181
  # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
 
9
  import torchvision.transforms.functional as TTF
10
  from huggingface_hub import hf_hub_download
11
  import numpy as np
12
+ from einops import rearrange
13
 
14
  from networks.gaussian_predictor import GaussianPredictor
15
  from util.vis3d import save_ply
 
55
  to_tensor = TT.ToTensor() # Convert image to tensor
56
 
57
  # Function to check if an image is uploaded by the user
58
+ def check_input_image(input_images):
59
+ print("[DEBUG] Checking input images...")
60
+ if not input_images or len(input_images) == 0:
61
+ print("[ERROR] No images uploaded!")
62
+ raise gr.Error("No images uploaded!")
63
+ print("[INFO] Input images are valid.")
64
+
65
+ # Function to preprocess the input images before passing them to the model
66
+ def preprocess(images, padding_value):
67
+ processed_images = []
68
+ for image in images:
69
+ # Resize and pad each image
70
+ print("[DEBUG] Preprocessing image...")
71
+ image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
72
+ pad_border_fn = TT.Pad((padding_value, padding_value))
73
+ image = pad_border_fn(image)
74
+ print("[INFO] Image preprocessing complete.")
75
+ processed_images.append(image)
76
+ return processed_images
77
+
78
+ # Function to reconstruct the 3D model from the input images and export it as a PLY file
 
79
  @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
+ def reconstruct_and_export(images, num_gauss):
81
  """
82
+ Passes images through model, outputs reconstruction in form of a dict of tensors.
83
  """
84
+ outputs_list = []
85
+ for image in images:
86
+ print("[DEBUG] Starting reconstruction and export...")
87
+ # Convert the preprocessed image to a tensor and move it to the specified device
88
+ image = to_tensor(image).to(device).unsqueeze(0) # Add a batch dimension to the image tensor
89
+ inputs = {
90
+ ("color_aug", 0, 0): image, # The input dictionary expected by the model
91
+ }
92
+
93
+ # Pass the image through the model to get the output
94
+ print("[INFO] Passing image through the model...")
95
+ outputs = model(inputs) # Perform inference to get model outputs
96
+ outputs_list.append(outputs)
97
+
98
+ # Combine or process outputs from multiple images here if necessary
99
+ # For now, we'll just save the first one for illustration
100
+ gauss_means = outputs_list[0][('gauss_means', 0, 0)]
101
+ if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
102
+ adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
103
+ print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
104
+ num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
105
+
106
+ # Debugging tensor shape
107
+ print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
108
 
109
  # Export the reconstruction to a PLY file
110
  print(f"[INFO] Saving output to {ply_out_path}...")
111
+ save_ply(outputs_list[0], ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
112
  print("[INFO] Reconstruction and export complete.")
113
 
114
+ return ply_out_path # Return the path to the saved PLY file
115
 
116
  # Path to save the output PLY file
117
  ply_out_path = f'./mesh.ply'
 
134
  with gr.Row(variant="panel"):
135
  with gr.Column(scale=1):
136
  with gr.Row():
137
+ # Input images component for the user to upload multiple images
138
+ input_images = gr.Images(
139
+ label="Input Images",
140
+ image_mode="RGBA", # Accept RGBA images
141
+ sources="upload", # Allow users to upload images
142
+ type="pil", # The images are returned as PIL images
143
+ elem_id="content_images",
144
+ tool="editor", # Optional, for editing images
145
+ multiple=True # Allow multiple image uploads
146
  )
147
  with gr.Row():
148
  # Sliders for configurable parameters
149
+ num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=1) # Slider to set the number of Gaussians per pixel
150
+ padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32) # Slider to set padding value
151
  with gr.Row():
152
  # Button to trigger the generation process
153
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
 
163
  './demo_examples/re10k_05.jpg',
164
  './demo_examples/re10k_06.jpg',
165
  ],
166
+ inputs=[input_images], # Load the example images into the input component
167
  cache_examples=False,
168
+ label="Examples", # Label for the examples section
169
  examples_per_page=20,
170
  )
171
 
172
  with gr.Row():
173
+ # Display the preprocessed images (after resizing and padding)
174
+ processed_images = gr.Gallery(label="Processed Images", interactive=False) # Output component to show the processed images
175
 
176
  with gr.Column(scale=2):
177
  with gr.Row():
178
  with gr.Tab("Reconstruction"):
179
  # 3D model viewer to display the reconstructed model
180
  output_model = gr.Model3D(
181
+ height=512, # Height of the 3D model viewer
182
  label="Output Model",
183
+ interactive=False # The viewer is not interactive
184
  )
185
 
186
  # Define the workflow for the Generate button
187
+ submit.click(fn=check_input_image, inputs=[input_images]).success(
188
  fn=preprocess,
189
+ inputs=[input_images, padding_value], # Pass the input images and padding value to the preprocess function
190
+ outputs=[processed_images], # Output the processed images
191
  ).success(
192
  fn=reconstruct_and_export,
193
+ inputs=[processed_images, num_gauss], # Pass the processed images and number of Gaussians to the reconstruction function
194
+ outputs=[output_model], # Output the reconstructed 3D model
195
  )
196
 
197
  # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)