jbilcke-hf HF staff commited on
Commit
2728300
·
verified ·
1 Parent(s): 2bed0cd

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +43 -17
gradio_app.py CHANGED
@@ -58,13 +58,22 @@ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.
58
 
59
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
60
  """Prepare image batch for model input."""
61
- # Convert input image to numpy array and normalize
62
- img_array = np.array(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0
 
63
  print("[debug] img_array shape:", img_array.shape)
64
 
65
  # Extract RGB and alpha channels
66
- rgb = torch.from_numpy(img_array[..., :3]).float()
67
- mask = torch.from_numpy(img_array[..., 3:4]).float()
 
 
 
 
 
 
 
 
68
  print("[debug] rgb tensor shape:", rgb.shape)
69
  print("[debug] mask tensor shape:", mask.shape)
70
 
@@ -76,15 +85,16 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
76
  rgb_cond = torch.lerp(bg_tensor, rgb, mask)
77
  print("[debug] rgb_cond shape:", rgb_cond.shape)
78
 
79
- # Note: We need to permute the tensors to match the expected shape
80
- rgb_cond = rgb_cond.permute(2, 0, 1) # Change from [H, W, C] to [C, H, W]
81
- mask = mask.permute(2, 0, 1) # Change from [H, W, 1] to [1, H, W]
 
82
  print("[debug] rgb_cond after permute shape:", rgb_cond.shape)
83
  print("[debug] mask after permute shape:", mask.shape)
84
 
85
  batch = {
86
- "rgb_cond": rgb_cond.unsqueeze(0), # Add batch dimension
87
- "mask_cond": mask.unsqueeze(0),
88
  "c2w_cond": c2w_cond.unsqueeze(0),
89
  "intrinsic_cond": intrinsic.unsqueeze(0),
90
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
@@ -112,25 +122,23 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
112
  guidance_scale=0.0
113
  ).images[0]
114
 
115
- # Process the generated image
116
  print("[debug] converting the image to rgb")
117
  rgb_image = generated_image.convert('RGB')
118
 
119
- # Remove background
120
  print("[debug] removing the background by calling bg_remover.process(rgb_image)")
121
  no_bg_image = bg_remover.process(rgb_image)
122
 
123
- # Convert to numpy array to extract mask
124
  print("[debug] converting to numpy array to extract the mask")
125
  no_bg_array = np.array(no_bg_image)
126
- mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
127
 
128
- # Create RGBA image
 
 
 
129
  print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
130
  rgba_image = create_rgba_image(rgb_image, mask)
131
 
132
- # Auto crop with foreground
133
- print(f"[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
134
  processed_image = spar3d_utils.foreground_crop(
135
  rgba_image,
136
  crop_ratio=1.3,
@@ -138,8 +146,8 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
138
  no_crop=False
139
  )
140
 
 
141
  print("[debug] preparing the batch by calling create_batch(processed_image)")
142
- # Prepare batch for 3D generation
143
  batch = create_batch(processed_image)
144
  batch = {k: v.to(device) for k, v in batch.items()}
145
 
@@ -147,6 +155,24 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
147
  with torch.no_grad():
148
  print("[debug] calling torch.autocast(....) to generate the mesh")
149
  with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  trimesh_mesh, _ = spar3d_model.generate_mesh(
151
  batch,
152
  1024, # texture_resolution
 
58
 
59
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
60
  """Prepare image batch for model input."""
61
+ # Resize and convert input image to numpy array
62
+ resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
63
+ img_array = np.array(resized_image).astype(np.float32) / 255.0
64
  print("[debug] img_array shape:", img_array.shape)
65
 
66
  # Extract RGB and alpha channels
67
+ if img_array.shape[-1] == 4: # RGBA
68
+ rgb = img_array[..., :3]
69
+ mask = img_array[..., 3:4]
70
+ else: # RGB
71
+ rgb = img_array
72
+ mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
73
+
74
+ # Convert to tensors
75
+ rgb = torch.from_numpy(rgb).float()
76
+ mask = torch.from_numpy(mask).float()
77
  print("[debug] rgb tensor shape:", rgb.shape)
78
  print("[debug] mask tensor shape:", mask.shape)
79
 
 
85
  rgb_cond = torch.lerp(bg_tensor, rgb, mask)
86
  print("[debug] rgb_cond shape:", rgb_cond.shape)
87
 
88
+ # Permute the tensors to match the expected shape [B, C, H, W]
89
+ rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
90
+ mask = mask.permute(2, 0, 1).unsqueeze(0) # [1, 1, H, W]
91
+
92
  print("[debug] rgb_cond after permute shape:", rgb_cond.shape)
93
  print("[debug] mask after permute shape:", mask.shape)
94
 
95
  batch = {
96
+ "rgb_cond": rgb_cond,
97
+ "mask_cond": mask,
98
  "c2w_cond": c2w_cond.unsqueeze(0),
99
  "intrinsic_cond": intrinsic.unsqueeze(0),
100
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
 
122
  guidance_scale=0.0
123
  ).images[0]
124
 
 
125
  print("[debug] converting the image to rgb")
126
  rgb_image = generated_image.convert('RGB')
127
 
 
128
  print("[debug] removing the background by calling bg_remover.process(rgb_image)")
129
  no_bg_image = bg_remover.process(rgb_image)
130
 
 
131
  print("[debug] converting to numpy array to extract the mask")
132
  no_bg_array = np.array(no_bg_image)
 
133
 
134
+ # Create mask based on RGB values
135
+ mask = ((no_bg_array > 0).any(axis=2)).astype(np.float32)
136
+ mask = np.expand_dims(mask, axis=2) # Add channel dimension
137
+
138
  print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
139
  rgba_image = create_rgba_image(rgb_image, mask)
140
 
141
+ print(f"[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...)")
 
142
  processed_image = spar3d_utils.foreground_crop(
143
  rgba_image,
144
  crop_ratio=1.3,
 
146
  no_crop=False
147
  )
148
 
149
+ # Forward pass through SPAR3D
150
  print("[debug] preparing the batch by calling create_batch(processed_image)")
 
151
  batch = create_batch(processed_image)
152
  batch = {k: v.to(device) for k, v in batch.items()}
153
 
 
155
  with torch.no_grad():
156
  print("[debug] calling torch.autocast(....) to generate the mesh")
157
  with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
158
+ # Add point cloud conditioning to match expected input
159
+ if "pc_cond" not in batch:
160
+ # Sample tokens from model's diffusion process
161
+ cond_tokens = spar3d_model.forward_pdiff_cond(batch)
162
+ sample_iter = spar3d_model.sampler.sample_batch_progressive(
163
+ 1, # batch size
164
+ cond_tokens,
165
+ guidance_scale=3.0,
166
+ device=device,
167
+ )
168
+ for x in sample_iter:
169
+ samples = x["xstart"]
170
+ # Add point cloud to batch
171
+ batch["pc_cond"] = samples.permute(0, 2, 1).float()
172
+ batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])
173
+ # Subsample to 512 points
174
+ batch["pc_cond"] = batch["pc_cond"][:, torch.randperm(batch["pc_cond"].shape[1])[:512]]
175
+
176
  trimesh_mesh, _ = spar3d_model.generate_mesh(
177
  batch,
178
  1024, # texture_resolution