jbilcke-hf HF staff committed on
Commit
0471bc8
·
verified ·
1 Parent(s): 9babed2

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +15 -57
gradio_app.py CHANGED
@@ -47,12 +47,10 @@ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.
47
  """Create an RGBA image from RGB image and optional mask."""
48
  rgba_image = rgb_image.convert('RGBA')
49
  if mask is not None:
50
- print("[debug] mask shape before alpha:", mask.shape)
51
  # Ensure mask is 2D before converting to alpha
52
  if len(mask.shape) > 2:
53
  mask = mask.squeeze()
54
  alpha = Image.fromarray((mask * 255).astype(np.uint8))
55
- print("[debug] alpha size:", alpha.size)
56
  rgba_image.putalpha(alpha)
57
  return rgba_image
58
 
@@ -61,8 +59,7 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
61
  # Resize and convert input image to numpy array
62
  resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
63
  img_array = np.array(resized_image).astype(np.float32) / 255.0
64
- print("[debug] img_array shape:", img_array.shape)
65
-
66
  # Extract RGB and alpha channels
67
  if img_array.shape[-1] == 4: # RGBA
68
  rgb = img_array[..., :3]
@@ -74,25 +71,18 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
74
  # Convert to tensors while keeping channel-last format
75
  rgb = torch.from_numpy(rgb).float() # [H, W, 3]
76
  mask = torch.from_numpy(mask).float() # [H, W, 1]
77
- print("[debug] rgb tensor shape:", rgb.shape)
78
- print("[debug] mask tensor shape:", mask.shape)
79
-
80
  # Create background blend (match channel-last format)
81
  bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
82
- print("[debug] bg_tensor shape:", bg_tensor.shape)
83
-
84
  # Blend RGB with background using mask (all in channel-last format)
85
  rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
86
- print("[debug] rgb_cond shape after blend:", rgb_cond.shape)
87
-
88
  # Move channels to correct dimension and add batch dimension
89
  # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
90
  rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
91
  mask = mask.unsqueeze(0) # [1, H, W, 1]
92
 
93
- print("[debug] rgb_cond final shape:", rgb_cond.shape)
94
- print("[debug] mask final shape:", mask.shape)
95
-
96
  # Create the batch dictionary
97
  batch = {
98
  "rgb_cond": rgb_cond, # [1, H, W, 3]
@@ -102,35 +92,20 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
102
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
103
  }
104
 
105
- print("\nFinal batch shapes:")
106
  for k, v in batch.items():
107
  print(f"[debug] {k} final shape:", v.shape)
108
- print("\nrgb_cond max:", batch["rgb_cond"].max())
109
- print("rgb_cond min:", batch["rgb_cond"].min())
110
- print("mask_cond unique values:", torch.unique(batch["mask_cond"]))
111
-
112
  return batch
113
 
114
  def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
115
  """Process batch through model and generate point cloud."""
116
- print("\n[debug] Starting forward_model")
117
- print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
118
- print("[debug] Input mask_cond shape:", batch["mask_cond"].shape)
119
-
120
  batch_size = batch["rgb_cond"].shape[0]
121
  assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
122
 
123
- # Print value ranges for debugging
124
- print("\nValue ranges:")
125
- print("rgb_cond max:", batch["rgb_cond"].max())
126
- print("rgb_cond min:", batch["rgb_cond"].min())
127
- print("mask_cond unique values:", torch.unique(batch["mask_cond"]))
128
-
129
  # Generate point cloud tokens
130
- print("\n[debug] Generating point cloud tokens")
131
  try:
132
  cond_tokens = system.forward_pdiff_cond(batch)
133
- print("[debug] cond_tokens shape:", cond_tokens.shape)
134
  except Exception as e:
135
  print("\n[ERROR] Failed in forward_pdiff_cond:")
136
  print(e)
@@ -141,7 +116,6 @@ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
141
  raise
142
 
143
  # Sample points
144
- print("\n[debug] Sampling points")
145
  sample_iter = system.sampler.sample_batch_progressive(
146
  batch_size,
147
  cond_tokens,
@@ -153,18 +127,14 @@ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
153
  for x in sample_iter:
154
  samples = x["xstart"]
155
 
156
- print("[debug] samples shape before permute:", samples.shape)
157
  pc_cond = samples.permute(0, 2, 1).float()
158
- print("[debug] pc_cond shape after permute:", pc_cond.shape)
159
-
160
  # Normalize point cloud
161
  pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
162
- print("[debug] pc_cond shape after normalize:", pc_cond.shape)
163
-
164
  # Subsample to 512 points
165
  pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
166
- print("[debug] pc_cond final shape:", pc_cond.shape)
167
-
168
  return pc_cond
169
 
170
  def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Image.Image | None]:
@@ -180,7 +150,6 @@ def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Im
180
 
181
  # Generate image using FLUX
182
  generator = torch.Generator(device=device).manual_seed(seed)
183
- print("[debug] generating the image using Flux")
184
  generated_image = flux_pipe(
185
  prompt=prompt,
186
  width=width,
@@ -190,10 +159,8 @@ def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Im
190
  guidance_scale=0.0
191
  ).images[0]
192
 
193
- print("[debug] converting the image to rgb")
194
  rgb_image = generated_image.convert('RGB')
195
 
196
- print("[debug] removing the background by calling bg_remover.process(rgb_image)")
197
  # bg_remover returns a PIL Image already, no need to convert
198
  no_bg_image = bg_remover.process(rgb_image)
199
  print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")
@@ -202,7 +169,6 @@ def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Im
202
  rgba_image = no_bg_image.convert('RGBA')
203
  print(f"[debug] rgba_image mode: {rgba_image.mode}")
204
 
205
- print("[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...)")
206
  processed_image = spar3d_utils.foreground_crop(
207
  rgba_image,
208
  crop_ratio=1.3,
@@ -215,7 +181,6 @@ def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Im
215
  print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")
216
 
217
  # Prepare batch for processing
218
- print("[debug] preparing the batch by calling create_batch(processed_image)")
219
  batch = create_batch(processed_image)
220
  batch = {k: v.to(device) for k, v in batch.items()}
221
 
@@ -231,7 +196,6 @@ def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Im
231
 
232
  # Generate mesh
233
  with torch.no_grad():
234
- print("[debug] calling torch.autocast(....) to generate the mesh")
235
  with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
236
  trimesh_mesh, _ = spar3d_model.generate_mesh(
237
  batch,
@@ -243,20 +207,18 @@ def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Im
243
  trimesh_mesh = trimesh_mesh[0]
244
 
245
  # Export to GLB
246
- print("[debug] creating tmp dir for the .glb output")
247
  temp_dir = tempfile.mkdtemp()
248
  output_path = os.path.join(temp_dir, 'output.glb')
249
 
250
- print("[debug] calling trimesh_mesh.export(...) to export to .glb")
251
  trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
252
 
253
- return output_path, generated_image
254
 
255
  except Exception as e:
256
  print(f"Error during generation: {str(e)}")
257
  import traceback
258
  traceback.print_exc()
259
- return None, None
260
 
261
  # Create Gradio interface
262
  demo = gr.Interface(
@@ -276,16 +238,12 @@ demo = gr.Interface(
276
  ],
277
  outputs=[
278
  gr.Model3D(
279
- label="3D Model Preview",
280
  clear_color=[0.0, 0.0, 0.0, 0.0],
281
- ),
282
- gr.Image(
283
- label="Generated Image",
284
- type="pil"
285
- ),
286
  ],
287
- title="Text to 3D Model Generator",
288
- description="Enter a text prompt to generate an image that will be converted into a 3D model",
289
  )
290
 
291
  if __name__ == "__main__":
 
47
  """Create an RGBA image from RGB image and optional mask."""
48
  rgba_image = rgb_image.convert('RGBA')
49
  if mask is not None:
 
50
  # Ensure mask is 2D before converting to alpha
51
  if len(mask.shape) > 2:
52
  mask = mask.squeeze()
53
  alpha = Image.fromarray((mask * 255).astype(np.uint8))
 
54
  rgba_image.putalpha(alpha)
55
  return rgba_image
56
 
 
59
  # Resize and convert input image to numpy array
60
  resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
61
  img_array = np.array(resized_image).astype(np.float32) / 255.0
62
+
 
63
  # Extract RGB and alpha channels
64
  if img_array.shape[-1] == 4: # RGBA
65
  rgb = img_array[..., :3]
 
71
  # Convert to tensors while keeping channel-last format
72
  rgb = torch.from_numpy(rgb).float() # [H, W, 3]
73
  mask = torch.from_numpy(mask).float() # [H, W, 1]
74
+
 
 
75
  # Create background blend (match channel-last format)
76
  bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
77
+
 
78
  # Blend RGB with background using mask (all in channel-last format)
79
  rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
80
+
 
81
  # Move channels to correct dimension and add batch dimension
82
  # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
83
  rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
84
  mask = mask.unsqueeze(0) # [1, H, W, 1]
85
 
 
 
 
86
  # Create the batch dictionary
87
  batch = {
88
  "rgb_cond": rgb_cond, # [1, H, W, 3]
 
92
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
93
  }
94
 
 
95
  for k, v in batch.items():
96
  print(f"[debug] {k} final shape:", v.shape)
97
+
 
 
 
98
  return batch
99
 
100
  def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
101
  """Process batch through model and generate point cloud."""
102
+
 
 
 
103
  batch_size = batch["rgb_cond"].shape[0]
104
  assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
105
 
 
 
 
 
 
 
106
  # Generate point cloud tokens
 
107
  try:
108
  cond_tokens = system.forward_pdiff_cond(batch)
 
109
  except Exception as e:
110
  print("\n[ERROR] Failed in forward_pdiff_cond:")
111
  print(e)
 
116
  raise
117
 
118
  # Sample points
 
119
  sample_iter = system.sampler.sample_batch_progressive(
120
  batch_size,
121
  cond_tokens,
 
127
  for x in sample_iter:
128
  samples = x["xstart"]
129
 
 
130
  pc_cond = samples.permute(0, 2, 1).float()
131
+
 
132
  # Normalize point cloud
133
  pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
134
+
 
135
  # Subsample to 512 points
136
  pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
137
+
 
138
  return pc_cond
139
 
140
  def generate_and_process_3d(prompt: str, seed: int = 42) -> tuple[str | None, Image.Image | None]:
 
150
 
151
  # Generate image using FLUX
152
  generator = torch.Generator(device=device).manual_seed(seed)
 
153
  generated_image = flux_pipe(
154
  prompt=prompt,
155
  width=width,
 
159
  guidance_scale=0.0
160
  ).images[0]
161
 
 
162
  rgb_image = generated_image.convert('RGB')
163
 
 
164
  # bg_remover returns a PIL Image already, no need to convert
165
  no_bg_image = bg_remover.process(rgb_image)
166
  print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")
 
169
  rgba_image = no_bg_image.convert('RGBA')
170
  print(f"[debug] rgba_image mode: {rgba_image.mode}")
171
 
 
172
  processed_image = spar3d_utils.foreground_crop(
173
  rgba_image,
174
  crop_ratio=1.3,
 
181
  print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")
182
 
183
  # Prepare batch for processing
 
184
  batch = create_batch(processed_image)
185
  batch = {k: v.to(device) for k, v in batch.items()}
186
 
 
196
 
197
  # Generate mesh
198
  with torch.no_grad():
 
199
  with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
200
  trimesh_mesh, _ = spar3d_model.generate_mesh(
201
  batch,
 
207
  trimesh_mesh = trimesh_mesh[0]
208
 
209
  # Export to GLB
 
210
  temp_dir = tempfile.mkdtemp()
211
  output_path = os.path.join(temp_dir, 'output.glb')
212
 
 
213
  trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
214
 
215
+ return output_path
216
 
217
  except Exception as e:
218
  print(f"Error during generation: {str(e)}")
219
  import traceback
220
  traceback.print_exc()
221
+ return None
222
 
223
  # Create Gradio interface
224
  demo = gr.Interface(
 
238
  ],
239
  outputs=[
240
  gr.Model3D(
241
+ label="Generated 3D model",
242
  clear_color=[0.0, 0.0, 0.0, 0.0],
243
+ )
 
 
 
 
244
  ],
245
+ title="Text to 3D",
246
+ description="Enter a text prompt to generate an image that will be converted into a 3D model using Stable Point-Awaire 3D by Stability AI.",
247
  )
248
 
249
  if __name__ == "__main__":