jbilcke-hf HF staff commited on
Commit
1c05005
·
verified ·
1 Parent(s): 287be50

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +78 -24
gradio_app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  import gradio as gr
8
  import trimesh
9
  from transparent_background import Remover
 
10
 
11
  # Import and setup SPAR3D
12
  os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
@@ -23,12 +24,19 @@ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
23
  # Initialize models
24
  device = spar3d_utils.get_device()
25
  bg_remover = Remover()
26
- model = SPAR3D.from_pretrained(
27
  "stabilityai/stable-point-aware-3d",
28
  config_name="config.yaml",
29
  weight_name="model.safetensors"
30
  ).eval().to(device)
31
 
 
 
 
 
 
 
 
32
  # Initialize camera parameters
33
  c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
34
  intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
@@ -59,20 +67,30 @@ def create_batch(input_image: Image) -> dict[str, Any]:
59
  }
60
  return batch
61
 
62
- def process_image(image_path: str) -> str:
63
- """Process image and return path to GLB file."""
64
  try:
65
- # Load image
66
- input_image = Image.open(image_path)
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # Remove background if needed
69
- if input_image.mode != 'RGBA':
70
- input_image = bg_remover.process(input_image.convert("RGB"))
71
 
72
  # Auto crop
73
  input_image = spar3d_utils.foreground_crop(
74
  input_image,
75
- crop_ratio=1.3, # Default padding ratio
76
  newsize=(COND_WIDTH, COND_HEIGHT),
77
  no_crop=False
78
  )
@@ -83,10 +101,10 @@ def process_image(image_path: str) -> str:
83
 
84
  # Generate mesh
85
  with torch.no_grad():
86
- with torch.autocast(device_type=device, dtype=torch.bfloat16) if "cuda" in device else nullcontext():
87
- trimesh_mesh, _ = model.generate_mesh(
88
  batch,
89
- 1024, # <- texture_resolution
90
  remesh="none",
91
  vertex_count=-1,
92
  estimate_illumination=True
@@ -97,24 +115,60 @@ def process_image(image_path: str) -> str:
97
  temp_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
98
  trimesh_mesh.export(temp_file.name, file_type="glb", include_normals=True)
99
 
100
- return temp_file.name
101
 
102
  except Exception as e:
103
- return str(e)
104
 
105
  # Create Gradio interface
 
 
 
 
 
 
106
  demo = gr.Interface(
107
- fn=process_image,
108
- inputs=gr.File(
109
- label="Upload Image",
110
- file_types=["image"],
111
- ),
112
- outputs=gr.File(
113
- label="Download GLB",
114
- file_types=[".glb"],
115
- ),
116
- title="SPAR3D Image to GLB Converter",
117
- description="Upload an image (JPG, PNG, or WebP) and get back a 3D model in GLB format",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  )
119
 
120
  if __name__ == "__main__":
 
7
  import gradio as gr
8
  import trimesh
9
  from transparent_background import Remover
10
+ from diffusers import DiffusionPipeline
11
 
12
  # Import and setup SPAR3D
13
  os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
 
24
  # Initialize models
25
  device = spar3d_utils.get_device()
26
  bg_remover = Remover()
27
+ spar3d_model = SPAR3D.from_pretrained(
28
  "stabilityai/stable-point-aware-3d",
29
  config_name="config.yaml",
30
  weight_name="model.safetensors"
31
  ).eval().to(device)
32
 
33
+ # Initialize FLUX model
34
+ dtype = torch.bfloat16
35
+ flux_pipe = DiffusionPipeline.from_pretrained(
36
+ "black-forest-labs/FLUX.1-schnell",
37
+ torch_dtype=dtype
38
+ ).to(device)
39
+
40
  # Initialize camera parameters
41
  c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
42
  intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
 
67
  }
68
  return batch
69
 
70
+ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> str:
71
+ """Generate image from prompt and convert to 3D model."""
72
  try:
73
+ # Generate image using FLUX
74
+ generator = torch.Generator().manual_seed(seed)
75
+ generated_image = flux_pipe(
76
+ prompt=prompt,
77
+ width=width,
78
+ height=height,
79
+ num_inference_steps=4,
80
+ generator=generator,
81
+ guidance_scale=0.0
82
+ ).images[0]
83
+
84
+ # Convert PIL image to RGBA
85
+ input_image = generated_image.convert("RGBA")
86
 
87
  # Remove background if needed
88
+ input_image = bg_remover.process(input_image.convert("RGB"))
 
89
 
90
  # Auto crop
91
  input_image = spar3d_utils.foreground_crop(
92
  input_image,
93
+ crop_ratio=1.3,
94
  newsize=(COND_WIDTH, COND_HEIGHT),
95
  no_crop=False
96
  )
 
101
 
102
  # Generate mesh
103
  with torch.no_grad():
104
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
105
+ trimesh_mesh, _ = spar3d_model.generate_mesh(
106
  batch,
107
+ 1024, # texture_resolution
108
  remesh="none",
109
  vertex_count=-1,
110
  estimate_illumination=True
 
115
  temp_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
116
  trimesh_mesh.export(temp_file.name, file_type="glb", include_normals=True)
117
 
118
+ return temp_file.name, generated_image
119
 
120
  except Exception as e:
121
+ return str(e), None
122
 
123
  # Create Gradio interface
124
+ examples = [
125
+ "a tiny astronaut hatching from an egg on the moon",
126
+ "a cat holding a sign that says hello world",
127
+ "an anime illustration of a wiener schnitzel",
128
+ ]
129
+
130
  demo = gr.Interface(
131
+ fn=generate_and_process_3d,
132
+ inputs=[
133
+ gr.Text(
134
+ label="Enter your prompt",
135
+ placeholder="Describe what you want to generate..."
136
+ ),
137
+ gr.Slider(
138
+ label="Seed",
139
+ minimum=0,
140
+ maximum=np.iinfo(np.int32).max,
141
+ step=1,
142
+ value=42
143
+ ),
144
+ gr.Slider(
145
+ label="Width",
146
+ minimum=256,
147
+ maximum=2048,
148
+ step=32,
149
+ value=1024
150
+ ),
151
+ gr.Slider(
152
+ label="Height",
153
+ minimum=256,
154
+ maximum=2048,
155
+ step=32,
156
+ value=1024
157
+ )
158
+ ],
159
+ outputs=[
160
+ gr.File(
161
+ label="Download GLB",
162
+ file_types=[".glb"],
163
+ ),
164
+ gr.Image(
165
+ label="Generated Image",
166
+ type="pil"
167
+ )
168
+ ],
169
+ title="Text to 3D Model Generator",
170
+ description="Enter a text prompt to generate an image that will be converted into a 3D model",
171
+ examples=examples
172
  )
173
 
174
  if __name__ == "__main__":