abreza commited on
Commit
4d6f443
·
1 Parent(s): 39b4836

refactor the code

Browse files
Files changed (1) hide show
  1. app.py +154 -160
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import os
2
  import shutil
3
  import tempfile
4
- from functools import partial
5
 
6
  import gradio as gr
7
  import numpy as np
8
  import rembg
 
9
  import torch
10
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
11
  from einops import rearrange
@@ -16,8 +16,9 @@ from pytorch_lightning import seed_everything
16
  from torchvision.transforms import v2
17
  from tqdm import tqdm
18
 
19
- from src.utils.camera_util import FOV_to_intrinsics, get_circular_camera_poses, get_zero123plus_input_cameras
20
- from src.utils.infer_util import images_to_video, remove_background, resize_foreground
 
21
  from src.utils.mesh_util import save_glb, save_obj
22
  from src.utils.train_util import instantiate_from_config
23
 
@@ -42,50 +43,13 @@ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexi
42
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
43
  else:
44
  extrinsics = c2ws.flatten(-2)
45
- intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
 
46
  cameras = torch.cat([extrinsics, intrinsics], dim=-1)
47
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
48
  return cameras
49
 
50
 
51
- def load_models(config_path):
52
- config = OmegaConf.load(config_path)
53
- config_name = os.path.basename(config_path).replace('.yaml', '')
54
- model_config = config.model_config
55
- infer_config = config.infer_config
56
-
57
- is_flexicubes = config_name.startswith('instant-mesh')
58
-
59
- device = torch.device('cuda')
60
-
61
- pipeline = DiffusionPipeline.from_pretrained(
62
- "sudo-ai/zero123plus-v1.2",
63
- custom_pipeline="zero123plus",
64
- torch_dtype=torch.float16,
65
- )
66
- pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
67
- pipeline.scheduler.config, timestep_spacing='trailing'
68
- )
69
-
70
- unet_ckpt_path = hf_hub_download(
71
- repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
72
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
73
- pipeline.unet.load_state_dict(state_dict, strict=True)
74
-
75
- pipeline = pipeline.to(device)
76
-
77
- model_ckpt_path = hf_hub_download(
78
- repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
79
- model = instantiate_from_config(model_config)
80
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
81
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
82
- model.load_state_dict(state_dict, strict=True)
83
-
84
- model = model.to(device)
85
-
86
- return pipeline, model, is_flexicubes, infer_config
87
-
88
-
89
  def check_input_image(input_image):
90
  if input_image is None:
91
  raise gr.Error("No image uploaded!")
@@ -101,27 +65,28 @@ def preprocess(input_image, do_remove_background):
101
  return input_image
102
 
103
 
104
- def generate_mvs(input_image, sample_steps, sample_seed, pipeline):
 
105
  seed_everything(sample_seed)
106
 
107
  z123_image = pipeline(
108
- input_image,
109
- num_inference_steps=sample_steps
110
- ).images[0]
111
 
112
  show_image = np.asarray(z123_image, dtype=np.uint8)
113
  show_image = torch.from_numpy(show_image)
114
- show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
115
- show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
 
 
116
  show_image = Image.fromarray(show_image.numpy())
117
 
118
  return z123_image, show_image
119
 
120
 
121
- def make3d(images, model, is_flexicubes, infer_config):
122
- device = torch.device('cuda')
123
-
124
- if is_flexicubes:
125
  model.init_flexicubes_geometry(device, use_renderer=False)
126
  model = model.eval()
127
 
@@ -129,20 +94,25 @@ def make3d(images, model, is_flexicubes, infer_config):
129
  images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
130
  images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
131
 
132
- input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
133
- render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=is_flexicubes).to(device)
 
 
134
 
135
  images = images.unsqueeze(0).to(device)
136
- images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
 
137
 
138
  mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
 
139
  mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
140
  mesh_dirname = os.path.dirname(mesh_fpath)
141
  mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
142
 
143
  with torch.no_grad():
144
  planes = model.forward_planes(images, input_cameras)
145
- mesh_out = model.extract_mesh(planes, use_texture_map=False, **infer_config)
 
146
 
147
  vertices, faces, vertex_colors = mesh_out
148
  vertices = vertices[:, [1, 2, 0]]
@@ -150,115 +120,139 @@ def make3d(images, model, is_flexicubes, infer_config):
150
  save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
151
  save_obj(vertices, faces, vertex_colors, mesh_fpath)
152
 
 
 
153
  return mesh_fpath, mesh_glb_fpath
154
 
155
 
156
- def launch_demo(config_path):
157
- cuda_path = find_cuda()
158
- if cuda_path:
159
- print(f"CUDA installation found at: {cuda_path}")
160
- else:
161
- print("CUDA installation not found")
162
-
163
- pipeline, model, is_flexicubes, infer_config = load_models(config_path)
164
-
165
- with gr.Blocks() as demo:
166
- with gr.Row(variant="panel"):
167
- with gr.Column():
168
- with gr.Row():
169
- input_image = gr.Image(
170
- label="Input Image",
171
- image_mode="RGBA",
172
- sources="upload",
173
- type="pil",
174
- elem_id="content_image",
175
- )
176
- processed_image = gr.Image(
177
- label="Processed Image",
178
- image_mode="RGBA",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  type="pil",
 
180
  interactive=False
181
  )
182
- with gr.Row():
183
- with gr.Group():
184
- do_remove_background = gr.Checkbox(
185
- label="Remove Background", value=True
186
- )
187
- sample_seed = gr.Number(
188
- value=42, label="Seed Value", precision=0)
189
-
190
- sample_steps = gr.Slider(
191
- label="Sample Steps",
192
- minimum=30,
193
- maximum=75,
194
- value=75,
195
- step=5
196
- )
197
-
198
- with gr.Row():
199
- submit = gr.Button(
200
- "Generate", elem_id="generate", variant="primary")
201
-
202
- with gr.Row(variant="panel"):
203
- gr.Examples(
204
- examples=[
205
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
206
- ],
207
- inputs=[input_image],
208
- label="Examples",
209
- cache_examples=False,
210
- examples_per_page=16
211
- )
212
 
213
- with gr.Column():
214
- with gr.Row():
215
- with gr.Column():
216
- mv_show_images = gr.Image(
217
- label="Generated Multi-views",
218
- type="pil",
219
- width=379,
220
- interactive=False
221
- )
222
-
223
- with gr.Row():
224
- with gr.Tab("OBJ"):
225
- output_model_obj = gr.Model3D(
226
- label="Output Model (OBJ Format)",
227
- interactive=False,
228
- )
229
- gr.Markdown(
230
- "Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
231
- with gr.Tab("GLB"):
232
- output_model_glb = gr.Model3D(
233
- label="Output Model (GLB Format)",
234
- interactive=False,
235
- )
236
- gr.Markdown(
237
- "Note: The model shown here has a darker appearance. Download to get correct results.")
238
-
239
- with gr.Row():
240
  gr.Markdown(
241
- '''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
242
-
243
- mv_images = gr.State()
244
-
245
- submit.click(fn=check_input_image, inputs=[input_image]).success(
246
- fn=preprocess,
247
- inputs=[input_image, do_remove_background],
248
- outputs=[processed_image],
249
- ).success(
250
- fn=generate_mvs,
251
- inputs=[processed_image, sample_steps, sample_seed, pipeline],
252
- outputs=[mv_images, mv_show_images]
253
- ).success(
254
- fn=make3d,
255
- inputs=[mv_images, model, is_flexicubes, infer_config],
256
- outputs=[output_model_obj, output_model_glb]
257
- )
258
-
259
- demo.launch()
260
-
261
-
262
- if __name__ == "__main__":
263
- config_path = 'configs/instant-mesh-large.yaml'
264
- launch_demo(config_path)
 
1
  import os
2
  import shutil
3
  import tempfile
 
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import rembg
8
+ import spaces
9
  import torch
10
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
11
  from einops import rearrange
 
16
  from torchvision.transforms import v2
17
  from tqdm import tqdm
18
 
19
+ from src.utils.camera_util import (FOV_to_intrinsics, get_circular_camera_poses,
20
+ get_zero123plus_input_cameras)
21
+ from src.utils.infer_util import (remove_background, resize_foreground)
22
  from src.utils.mesh_util import save_glb, save_obj
23
  from src.utils.train_util import instantiate_from_config
24
 
 
43
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
44
  else:
45
  extrinsics = c2ws.flatten(-2)
46
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(
47
+ 0).repeat(M, 1, 1).float().flatten(-2)
48
  cameras = torch.cat([extrinsics, intrinsics], dim=-1)
49
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
50
  return cameras
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def check_input_image(input_image):
54
  if input_image is None:
55
  raise gr.Error("No image uploaded!")
 
65
  return input_image
66
 
67
 
68
+ @spaces.GPU
69
+ def generate_mvs(input_image, sample_steps, sample_seed):
70
  seed_everything(sample_seed)
71
 
72
  z123_image = pipeline(
73
+ input_image, num_inference_steps=sample_steps).images[0]
 
 
74
 
75
  show_image = np.asarray(z123_image, dtype=np.uint8)
76
  show_image = torch.from_numpy(show_image)
77
+ show_image = rearrange(
78
+ show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
79
+ show_image = rearrange(
80
+ show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
81
  show_image = Image.fromarray(show_image.numpy())
82
 
83
  return z123_image, show_image
84
 
85
 
86
+ @spaces.GPU
87
+ def make3d(images):
88
+ global model
89
+ if IS_FLEXICUBES:
90
  model.init_flexicubes_geometry(device, use_renderer=False)
91
  model = model.eval()
92
 
 
94
  images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
95
  images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
96
 
97
+ input_cameras = get_zero123plus_input_cameras(
98
+ batch_size=1, radius=4.0).to(device)
99
+ render_cameras = get_render_cameras(
100
+ batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
101
 
102
  images = images.unsqueeze(0).to(device)
103
+ images = v2.functional.resize(
104
+ images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
105
 
106
  mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
107
+ print(mesh_fpath)
108
  mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
109
  mesh_dirname = os.path.dirname(mesh_fpath)
110
  mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
111
 
112
  with torch.no_grad():
113
  planes = model.forward_planes(images, input_cameras)
114
+ mesh_out = model.extract_mesh(
115
+ planes, use_texture_map=False, **infer_config)
116
 
117
  vertices, faces, vertex_colors = mesh_out
118
  vertices = vertices[:, [1, 2, 0]]
 
120
  save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
121
  save_obj(vertices, faces, vertex_colors, mesh_fpath)
122
 
123
+ print(f"Mesh saved to {mesh_fpath}")
124
+
125
  return mesh_fpath, mesh_glb_fpath
126
 
127
 
128
+ # Configuration
129
+ cuda_path = find_cuda()
130
+ config_path = 'configs/instant-mesh-large.yaml'
131
+ config = OmegaConf.load(config_path)
132
+ config_name = os.path.basename(config_path).replace('.yaml', '')
133
+ model_config = config.model_config
134
+ infer_config = config.infer_config
135
+
136
+ IS_FLEXICUBES = config_name.startswith('instant-mesh')
137
+ device = torch.device('cuda')
138
+
139
+ # Load diffusion model
140
+ print('Loading diffusion model ...')
141
+ pipeline = DiffusionPipeline.from_pretrained(
142
+ "sudo-ai/zero123plus-v1.2",
143
+ custom_pipeline="zero123plus",
144
+ torch_dtype=torch.float16,
145
+ )
146
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
147
+ pipeline.scheduler.config, timestep_spacing='trailing'
148
+ )
149
+
150
+ unet_ckpt_path = hf_hub_download(
151
+ repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
152
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
153
+ pipeline.unet.load_state_dict(state_dict, strict=True)
154
+
155
+ pipeline = pipeline.to(device)
156
+
157
+ # Load reconstruction model
158
+ print('Loading reconstruction model ...')
159
+ model_ckpt_path = hf_hub_download(
160
+ repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
161
+ model = instantiate_from_config(model_config)
162
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
163
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith(
164
+ 'lrm_generator.') and 'source_camera' not in k}
165
+ model.load_state_dict(state_dict, strict=True)
166
+
167
+ model = model.to(device)
168
+
169
+ print('Loading Finished!')
170
+
171
+ # Gradio UI
172
+ with gr.Blocks() as demo:
173
+ with gr.Row(variant="panel"):
174
+ with gr.Column():
175
+ with gr.Row():
176
+ input_image = gr.Image(
177
+ label="Input Image",
178
+ image_mode="RGBA",
179
+ sources="upload",
180
+ type="pil",
181
+ elem_id="content_image",
182
+ )
183
+ processed_image = gr.Image(
184
+ label="Processed Image",
185
+ image_mode="RGBA",
186
+ type="pil",
187
+ interactive=False
188
+ )
189
+ with gr.Row():
190
+ with gr.Group():
191
+ do_remove_background = gr.Checkbox(
192
+ label="Remove Background", value=True)
193
+ sample_seed = gr.Number(
194
+ value=42, label="Seed Value", precision=0)
195
+ sample_steps = gr.Slider(
196
+ label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
197
+
198
+ with gr.Row():
199
+ submit = gr.Button(
200
+ "Generate", elem_id="generate", variant="primary")
201
+
202
+ with gr.Row(variant="panel"):
203
+ gr.Examples(
204
+ examples=[os.path.join("examples", img_name)
205
+ for img_name in sorted(os.listdir("examples"))],
206
+ inputs=[input_image],
207
+ label="Examples",
208
+ cache_examples=False,
209
+ examples_per_page=16
210
+ )
211
+
212
+ with gr.Column():
213
+ with gr.Row():
214
+ with gr.Column():
215
+ mv_show_images = gr.Image(
216
+ label="Generated Multi-views",
217
  type="pil",
218
+ width=379,
219
  interactive=False
220
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
+ with gr.Row():
223
+ with gr.Tab("OBJ"):
224
+ output_model_obj = gr.Model3D(
225
+ label="Output Model (OBJ Format)",
226
+ interactive=False,
227
+ )
228
+ gr.Markdown(
229
+ "Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
230
+ with gr.Tab("GLB"):
231
+ output_model_glb = gr.Model3D(
232
+ label="Output Model (GLB Format)",
233
+ interactive=False,
234
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  gr.Markdown(
236
+ "Note: The model shown here has a darker appearance. Download to get correct results.")
237
+
238
+ with gr.Row():
239
+ gr.Markdown(
240
+ '''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
241
+
242
+ mv_images = gr.State()
243
+
244
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
245
+ fn=preprocess,
246
+ inputs=[input_image, do_remove_background],
247
+ outputs=[processed_image],
248
+ ).success(
249
+ fn=generate_mvs,
250
+ inputs=[processed_image, sample_steps, sample_seed],
251
+ outputs=[mv_images, mv_show_images]
252
+ ).success(
253
+ fn=make3d,
254
+ inputs=[mv_images],
255
+ outputs=[output_model_obj, output_model_glb]
256
+ )
257
+
258
+ demo.launch()