File size: 8,313 Bytes
38dbec8
 
 
 
eecf990
38dbec8
eecf990
 
38dbec8
1c05005
38dbec8
eecf990
a399d55
e2ccc8a
 
38dbec8
eecf990
e2ccc8a
 
 
 
 
38dbec8
eecf990
e2ccc8a
eecf990
1c05005
e2ccc8a
 
eecf990
 
38dbec8
1c05005
 
 
 
 
 
 
eecf990
 
 
 
 
38dbec8
aec7186
 
 
 
c882a68
 
 
aec7186
 
 
03dc078
daf9fe6
c882a68
2728300
 
 
0471bc8
03dc078
2728300
 
 
 
 
 
 
4b4ce8a
9e70cab
 
0471bc8
4b4ce8a
 
0471bc8
4b4ce8a
 
0471bc8
4b4ce8a
 
 
 
2728300
4b4ce8a
c882a68
4b4ce8a
 
 
 
 
c882a68
 
 
 
0471bc8
c882a68
eecf990
751171e
 
0471bc8
751171e
4b4ce8a
 
751171e
4b4ce8a
 
 
 
 
 
 
 
 
 
751171e
 
 
 
 
 
 
 
 
 
 
 
 
 
0471bc8
751171e
 
0471bc8
751171e
 
0471bc8
751171e
 
af1d9cb
1c05005
b0a67b8
 
 
af1d9cb
 
 
b0a67b8
eecf990
751171e
 
 
 
1c05005
e973397
1c05005
 
 
 
 
 
 
 
 
daf9fe6
 
3b58a26
aec7186
3b58a26
aec7186
3b58a26
 
 
eecf990
daf9fe6
e973397
1c05005
38dbec8
eecf990
38dbec8
dc16672
 
 
 
38dbec8
751171e
daf9fe6
eecf990
 
751171e
 
 
 
 
 
 
 
 
 
eecf990
 
e973397
1c05005
eecf990
9babed2
eecf990
 
287be50
eecf990
 
 
 
e973397
 
f779fbc
e973397
eecf990
0471bc8
eecf990
 
daf9fe6
751171e
 
0471bc8
af1d9cb
 
 
 
6599110
af1d9cb
 
 
1c05005
6599110
1c05005
af1d9cb
 
 
 
 
 
 
a6bc9a4
0471bc8
af1d9cb
 
 
 
 
 
 
 
 
eecf990
e973397
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import os
import tempfile
from typing import Any
import torch
import numpy as np
from PIL import Image
import gradio as gr
import trimesh
from transparent_background import Remover
from diffusers import DiffusionPipeline

# Import and setup SPAR3D
# The pip command runs at import time and installs the local ./texture_baker
# and ./uv_unwrapper packages before the spar3d imports below.
# os.system's return value (the exit status) is ignored, so a failed install
# is not detected here.
# NOTE(review): presumably spar3d needs these two extensions at import or
# mesh-generation time — confirm.
os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
import spar3d.utils as spar3d_utils
from spar3d.system import SPAR3D

# Constants
# Width/height of the conditioning image: used by create_batch when resizing,
# as `newsize` for the foreground crop, and when building the intrinsics.
COND_WIDTH = 512
COND_HEIGHT = 512
# Passed to spar3d_utils.default_cond_c2w to build the conditioning camera pose.
COND_DISTANCE = 2.2
# Passed to spar3d_utils.create_intrinsic_from_fov_rad.
# NOTE(review): presumably a vertical field of view in radians — confirm.
COND_FOVY = 0.591627
# RGB color that create_batch blends in wherever the mask is 0
# (same 0-1 scale as the normalized image array used there).
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Initialize models
# Everything below runs at import time: importing this module constructs the
# background remover and loads both pretrained models onto `device`.
device = spar3d_utils.get_device()
bg_remover = Remover()
# SPAR3D is put in eval mode and moved to the selected device.
spar3d_model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors"
).eval().to(device)

# Initialize FLUX model
# The text-to-image pipeline is loaded with bfloat16 weights.
dtype = torch.bfloat16
flux_pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    torch_dtype=dtype
).to(device)

# Initialize camera parameters
# These module-level values are reused by create_batch for every request.
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
# Note the argument order: field of view, then height, then width.
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray | None = None) -> Image.Image:
    """Return an RGBA copy of ``rgb_image``, optionally using ``mask`` as alpha.

    Args:
        rgb_image: Source image. It is converted with ``convert('RGBA')``;
            the input object itself is not modified.
        mask: Optional opacity mask on a 0-1 scale (0 = transparent,
            1 = opaque). Arrays with more than two dimensions are squeezed,
            so singleton axes such as ``(H, W, 1)`` are accepted. The
            resulting 2D mask must match the image size, otherwise
            ``putalpha`` raises.

    Returns:
        The RGBA image. When ``mask`` is ``None`` the alpha channel is
        whatever ``convert('RGBA')`` produced.
    """
    rgba_image = rgb_image.convert('RGBA')
    if mask is None:
        return rgba_image

    # Ensure mask is 2D before converting to alpha
    if len(mask.shape) > 2:
        mask = mask.squeeze()

    # Clip before scaling: values outside [0, 1] would otherwise wrap around
    # in the uint8 cast instead of saturating at 0 / 255.
    mask = np.clip(mask, 0.0, 1.0)
    alpha = Image.fromarray((mask * 255).astype(np.uint8))
    rgba_image.putalpha(alpha)
    return rgba_image

def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Prepare image batch for model input.

    The image is resized to ``(COND_WIDTH, COND_HEIGHT)``, scaled to 0-1 and
    split into color and alpha. The color is blended over
    ``BACKGROUND_COLOR`` using the alpha as weight, and the module-level
    camera tensors are attached. All tensors get a leading batch dimension
    of size 1 and stay in channel-last layout.

    Args:
        input_image: Image to condition on. RGBA images supply their own
            mask; RGB images get an all-ones mask. Any other mode (for
            example ``L`` or ``P``) is converted to RGBA first.

    Returns:
        Dict with keys ``rgb_cond`` [1, H, W, 3], ``mask_cond`` [1, H, W, 1],
        ``c2w_cond``, ``intrinsic_cond`` and ``intrinsic_normed_cond`` (the
        module-level camera tensors with ``unsqueeze(0)`` applied).
    """
    # Modes other than RGB/RGBA (grayscale, palette, ...) do not produce an
    # array with a trailing 3- or 4-channel axis, which the channel check and
    # the background blend below rely on. Normalize them up front.
    if input_image.mode not in ("RGB", "RGBA"):
        input_image = input_image.convert("RGBA")

    # Resize and convert input image to numpy array
    resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
    img_array = np.array(resized_image).astype(np.float32) / 255.0

    # Extract RGB and alpha channels
    if img_array.shape[-1] == 4:  # RGBA
        rgb = img_array[..., :3]
        mask = img_array[..., 3:4]
    else:  # RGB
        rgb = img_array
        mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)

    # Convert to tensors while keeping channel-last format
    rgb = torch.from_numpy(rgb).float()  # [H, W, 3]
    mask = torch.from_numpy(mask).float()  # [H, W, 1]

    # Background color broadcast over every pixel (channel-last)
    bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3)  # [1, 1, 3]

    # mask == 0 -> background color, mask == 1 -> original pixel
    rgb_cond = torch.lerp(bg_tensor, rgb, mask)  # [H, W, 3]

    # Add the batch dimension; layout stays [B, H, W, C]
    rgb_cond = rgb_cond.unsqueeze(0)  # [1, H, W, 3]
    mask = mask.unsqueeze(0)  # [1, H, W, 1]

    # Create the batch dictionary
    batch = {
        "rgb_cond": rgb_cond,  # [1, H, W, 3]
        "mask_cond": mask,  # [1, H, W, 1]
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }

    for k, v in batch.items():
        print(f"[debug] {k} final shape:", v.shape)

    return batch

def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
    """Process batch through model and generate point cloud.

    Args:
        batch: Dict of model inputs. Only ``batch["rgb_cond"]`` is inspected
            here (its leading dimension must be 1); the whole dict is passed
            to ``system.forward_pdiff_cond``.
        system: Object providing ``forward_pdiff_cond(batch)`` and
            ``sampler.sample_batch_progressive(...)``.
        guidance_scale: Forwarded to the sampler.
        seed: Currently unused; kept so existing callers keep working.
            Randomness comes from the global torch RNG state.
        device: Forwarded to the sampler.

    Returns:
        The ``xstart`` tensor of the last sampler step with its last two
        axes swapped, cast to float, passed through
        ``spar3d_utils.normalize_pc_bbox`` and randomly subsampled to at most
        512 entries along dimension 1.

    Raises:
        ValueError: If the batch size is not 1.
        RuntimeError: If the sampler yields no steps.
        Exception: Anything raised by ``forward_pdiff_cond`` is re-raised
            after diagnostic output is printed.
    """
    batch_size = batch["rgb_cond"].shape[0]
    # Explicit raise instead of assert so the check survives `python -O`.
    if batch_size != 1:
        raise ValueError(f"Expected batch size 1, got {batch_size}")

    # Inference only: avoid building an autograd graph across the sampling loop.
    with torch.no_grad():
        # Generate point cloud tokens
        try:
            cond_tokens = system.forward_pdiff_cond(batch)
        except Exception as e:
            print("\n[ERROR] Failed in forward_pdiff_cond:")
            print(e)
            print("\nInput tensor properties:")
            print("rgb_cond dtype:", batch["rgb_cond"].dtype)
            print("rgb_cond device:", batch["rgb_cond"].device)
            print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
            raise

        # Sample points
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device
        )

        # Only the final step's prediction is kept
        samples = None
        for step in sample_iter:
            samples = step["xstart"]

    if samples is None:
        raise RuntimeError("Point cloud sampler produced no samples")

    pc_cond = samples.permute(0, 2, 1).float()

    # Normalize point cloud
    pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)

    # Subsample to 512 points
    pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]

    return pc_cond

def generate_and_process_3d(prompt: str) -> str | None:
    """Generate image from prompt and convert to 3D model.

    Steps: generate a 1024x1024 image with the FLUX pipeline, remove its
    background, crop to the foreground at the conditioning size, sample a
    point cloud with ``forward_model`` and generate a textured mesh with
    SPAR3D, then export that mesh as GLB.

    Args:
        prompt: Text prompt passed to the FLUX pipeline.

    Returns:
        Path to ``output.glb`` inside a newly created temporary directory,
        or ``None`` if any step raised (the error and traceback are printed).
        The temporary directory is not removed by this function.
    """
    width = 1024
    height = 1024

    # Generate random seed
    seed = np.random.randint(0, np.iinfo(np.int32).max)

    try:
        # Set random seeds
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Generate image using FLUX
        generator = torch.Generator(device=device).manual_seed(seed)
        generated_image = flux_pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=4,
            generator=generator,
            guidance_scale=0.0
        ).images[0]

        rgb_image = generated_image.convert('RGB')

        # bg_remover returns a PIL Image already, no need to convert
        no_bg_image = bg_remover.process(rgb_image)
        print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")

        # Convert to RGBA if not already
        rgba_image = no_bg_image.convert('RGBA')
        print(f"[debug] rgba_image mode: {rgba_image.mode}")

        processed_image = spar3d_utils.foreground_crop(
            rgba_image,
            crop_ratio=1.3,
            newsize=(COND_WIDTH, COND_HEIGHT),
            no_crop=False
        )

        # Show the processed image alpha channel for debugging
        alpha = np.array(processed_image)[:, :, 3]
        print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")

        # Prepare batch for processing
        batch = create_batch(processed_image)
        batch = {k: v.to(device) for k, v in batch.items()}

        # Both model stages are inference only, so gradients are disabled
        # for the point cloud sampling as well as the mesh generation.
        with torch.no_grad():
            # Generate point cloud (outside autocast, as before)
            pc_cond = forward_model(
                batch,
                spar3d_model,
                guidance_scale=3.0,
                seed=seed,
                device=device
            )
            batch["pc_cond"] = pc_cond

            # Generate mesh
            with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
                trimesh_mesh, _ = spar3d_model.generate_mesh(
                    batch,
                    1024,  # texture_resolution
                    remesh="none",
                    vertex_count=-1,
                    estimate_illumination=True
                )
                trimesh_mesh = trimesh_mesh[0]

        # Export to GLB
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')

        trimesh_mesh.export(output_path, file_type="glb", include_normals=True)

        return output_path

    except Exception as e:
        # Top-level request boundary: report the failure and return None
        # so the caller gets no model instead of an exception.
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# Create Gradio app using Blocks
# Components are laid out in creation order: heading, prompt box, button,
# then the 3D viewer, each in its own row.
with gr.Blocks() as demo:
    gr.Markdown("# Text to 3D")
    gr.Markdown("This space is based on [Stable Point-Aware 3D](https://huggingface.co/spaces/stabilityai/stable-point-aware-3d) by Stability AI.")
    
    # Text prompt that is passed to generate_and_process_3d
    with gr.Row():
        prompt_input = gr.Text(
            label="Enter your prompt",
            placeholder="eg. isometric 3D castle"
        )
    
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
    
    # Viewer that receives the value returned by generate_and_process_3d
    # (a .glb path, or None on failure)
    with gr.Row():
        model_output = gr.Model3D(
            label="Generated .GLB model",
            clear_color=[0.0, 0.0, 0.0, 0.0],
        )
    
    # Event handler
    # One input (the prompt) and one output (the viewer); the handler is
    # registered under the API name "generate".
    generate_btn.click(
        fn=generate_and_process_3d,
        inputs=[prompt_input],
        outputs=[model_output],
        api_name="generate"
    )
    
# Only start the app when run as a script; queue() is enabled before launch().
if __name__ == "__main__":
    demo.queue().launch()