jbilcke-hf HF staff committed on
Commit
c882a68
·
verified ·
1 Parent(s): e02679c

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +44 -28
gradio_app.py CHANGED
@@ -47,38 +47,54 @@ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.
47
  """Create an RGBA image from RGB image and optional mask."""
48
  rgba_image = rgb_image.convert('RGBA')
49
  if mask is not None:
50
- # Convert mask to alpha channel format
 
 
 
51
  alpha = Image.fromarray((mask * 255).astype(np.uint8))
 
52
  rgba_image.putalpha(alpha)
53
  return rgba_image
54
-
55
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
56
- """Prepare image batch for model input."""
57
- # Ensure input is RGBA
58
- if input_image.mode != 'RGBA':
59
- input_image = input_image.convert('RGBA')
60
-
61
- # Resize and convert to numpy array
62
- resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
63
- img_array = np.array(resized_image).astype(np.float32) / 255.0
64
-
65
- # Split into RGB and alpha
66
- mask_cond = img_array[..., 3:4] # Alpha channel
67
- # Blend RGB with background based on alpha
68
- rgb_cond = np.clip(
69
- img_array[..., :3] * mask_cond + BACKGROUND_COLOR * (1 - mask_cond),
70
- 0,
71
- 1
72
- )
73
-
74
- batch = {
75
- "rgb_cond": torch.from_numpy(rgb_cond).unsqueeze(0),
76
- "mask_cond": torch.from_numpy(mask_cond).unsqueeze(0),
77
- "c2w_cond": c2w_cond.unsqueeze(0),
78
- "intrinsic_cond": intrinsic.unsqueeze(0),
79
- "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
80
- }
81
- return batch
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
84
  """Generate image from prompt and convert to 3D model."""
 
47
  """Create an RGBA image from RGB image and optional mask."""
48
  rgba_image = rgb_image.convert('RGBA')
49
  if mask is not None:
50
+ print("[debug] mask shape before alpha:", mask.shape)
51
+ # Ensure mask is 2D before converting to alpha
52
+ if len(mask.shape) > 2:
53
+ mask = mask.squeeze()
54
  alpha = Image.fromarray((mask * 255).astype(np.uint8))
55
+ print("[debug] alpha size:", alpha.size)
56
  rgba_image.putalpha(alpha)
57
  return rgba_image
58
+
59
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
60
+ """Prepare image batch for model input."""
61
+ # Ensure input is RGBA
62
+ if input_image.mode != 'RGBA':
63
+ input_image = input_image.convert('RGBA')
64
+
65
+ # Resize and convert to numpy array
66
+ resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
67
+ img_array = np.array(resized_image).astype(np.float32) / 255.0
68
+
69
+ print("[debug] img_array shape:", img_array.shape)
70
+
71
+ # Split into RGB and alpha
72
+ rgb = torch.from_numpy(img_array[..., :3]).float()
73
+ alpha = torch.from_numpy(img_array[..., 3:4]).float()
74
+
75
+ print("[debug] rgb tensor shape:", rgb.shape)
76
+ print("[debug] alpha tensor shape:", alpha.shape)
77
+
78
+ # Create background blend using torch.lerp()
79
+ bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
80
+ print("[debug] bg_tensor shape:", bg_tensor.shape)
81
+
82
+ rgb_cond = torch.lerp(bg_tensor, rgb, alpha)
83
+ print("[debug] rgb_cond shape:", rgb_cond.shape)
84
+
85
+ batch = {
86
+ "rgb_cond": rgb_cond.unsqueeze(0),
87
+ "mask_cond": alpha.unsqueeze(0),
88
+ "c2w_cond": c2w_cond.unsqueeze(0),
89
+ "intrinsic_cond": intrinsic.unsqueeze(0),
90
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
91
+ }
92
+
93
+ # Final shapes check
94
+ for k, v in batch.items():
95
+ print(f"[debug] {k} final shape:", v.shape)
96
+
97
+ return batch
98
 
99
  def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
100
  """Generate image from prompt and convert to 3D model."""