tokenid commited on
Commit
74465a1
β€’
1 Parent(s): 6681628

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -0
app.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from torchvision.transforms import v2
8
+ from pytorch_lightning import seed_everything
9
+ from omegaconf import OmegaConf
10
+ from einops import rearrange, repeat
11
+ from tqdm import tqdm
12
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
13
+
14
+ from src.utils.train_util import instantiate_from_config
15
+ from src.utils.camera_util import (
16
+ FOV_to_intrinsics,
17
+ get_zero123plus_input_cameras,
18
+ get_circular_camera_poses,
19
+ )
20
+ from src.utils.mesh_util import save_obj
21
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
22
+
23
+ import tempfile
24
+ from functools import partial
25
+
26
+ from huggingface_hub import hf_hub_download
27
+
28
+ import gradio as gr
29
+ import spaces
30
+
31
+
32
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
33
+ """
34
+ Get the rendering camera parameters.
35
+ """
36
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
37
+ if is_flexicubes:
38
+ cameras = torch.linalg.inv(c2ws)
39
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
40
+ else:
41
+ extrinsics = c2ws.flatten(-2)
42
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
43
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
44
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
45
+ return cameras
46
+
47
+
48
+ def images_to_video(images, output_path, fps=30):
49
+ # images: (N, C, H, W)
50
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
51
+ frames = []
52
+ for i in range(images.shape[0]):
53
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
54
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
55
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
56
+ assert frame.min() >= 0 and frame.max() <= 255, \
57
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
58
+ frames.append(frame)
59
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
60
+
61
+
62
+ ###############################################################################
63
+ # Configuration.
64
+ ###############################################################################
65
+
66
+ config_path = 'configs/instant-mesh-large.yaml'
67
+ config = OmegaConf.load(config_path)
68
+ config_name = os.path.basename(config_path).replace('.yaml', '')
69
+ model_config = config.model_config
70
+ infer_config = config.infer_config
71
+
72
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
73
+
74
+ device = torch.device('cuda')
75
+
76
+ # load diffusion model
77
+ print('Loading diffusion model ...')
78
+ pipeline = DiffusionPipeline.from_pretrained(
79
+ "sudo-ai/zero123plus-v1.2",
80
+ custom_pipeline="zero123plus",
81
+ torch_dtype=torch.float16,
82
+ )
83
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
84
+ pipeline.scheduler.config, timestep_spacing='trailing'
85
+ )
86
+
87
+ # load custom white-background UNet
88
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
89
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
90
+ pipeline.unet.load_state_dict(state_dict, strict=True)
91
+
92
+ pipeline = pipeline.to(device)
93
+
94
+ # load reconstruction model
95
+ print('Loading reconstruction model ...')
96
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
97
+ model = instantiate_from_config(model_config)
98
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
99
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
100
+ model.load_state_dict(state_dict, strict=True)
101
+
102
+ model = model.to(device)
103
+ if IS_FLEXICUBES:
104
+ model.init_flexicubes_geometry(device)
105
+ model = model.eval()
106
+
107
+ print('Loading Finished!')
108
+
109
+
110
+ def check_input_image(input_image):
111
+ if input_image is None:
112
+ raise gr.Error("No image uploaded!")
113
+
114
+
115
+ def preprocess(input_image, do_remove_background):
116
+
117
+ rembg_session = rembg.new_session() if do_remove_background else None
118
+
119
+ if do_remove_background:
120
+ input_image = remove_background(input_image, rembg_session)
121
+ input_image = resize_foreground(input_image, 0.85)
122
+
123
+ return input_image
124
+
125
+
126
+ def generate_mvs(input_image, sample_steps, sample_seed):
127
+
128
+ seed_everything(sample_seed)
129
+
130
+ # sampling
131
+ z123_image = pipeline(
132
+ input_image,
133
+ num_inference_steps=sample_steps
134
+ ).images[0]
135
+
136
+ show_image = np.asarray(z123_image, dtype=np.uint8)
137
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
138
+ show_image = rearrange(show_image, '(n h) (m w) c -> (m h) (n w) c', n=3, m=2)
139
+ show_image = Image.fromarray(show_image.numpy())
140
+
141
+ return z123_image, show_image
142
+
143
+ def make_mesh(mesh_fpath, planes):
144
+
145
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
146
+ mesh_dirname = os.path.dirname(mesh_fpath)
147
+ mesh_vis_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
148
+
149
+ with torch.no_grad():
150
+
151
+ # get mesh
152
+ mesh_out = model.extract_mesh(
153
+ planes,
154
+ use_texture_map=False,
155
+ **infer_config,
156
+ )
157
+
158
+ vertices, faces, vertex_colors = mesh_out
159
+ vertices = vertices[:, [0, 2, 1]]
160
+ vertices[:, -1] *= -1
161
+
162
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
163
+
164
+ print(f"Mesh saved to {mesh_fpath}")
165
+
166
+ return mesh_fpath
167
+
168
+ @spaces.GPU
169
+ def make3d(input_image, sample_steps, sample_seed):
170
+
171
+ images, show_images = generate_mvs(input_image, sample_steps, sample_seed)
172
+
173
+ images = np.asarray(images, dtype=np.float32) / 255.0
174
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
175
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
176
+
177
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device)
178
+ render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
179
+
180
+ images = images.unsqueeze(0).to(device)
181
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
182
+
183
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
184
+ print(mesh_fpath)
185
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
186
+ mesh_dirname = os.path.dirname(mesh_fpath)
187
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
188
+
189
+ with torch.no_grad():
190
+ # get triplane
191
+ planes = model.forward_planes(images, input_cameras)
192
+
193
+ # get video
194
+ chunk_size = 20 if IS_FLEXICUBES else 1
195
+ render_size = 384
196
+
197
+ frames = []
198
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
199
+ if IS_FLEXICUBES:
200
+ frame = model.forward_geometry(
201
+ planes,
202
+ render_cameras[:, i:i+chunk_size],
203
+ render_size=render_size,
204
+ )['img']
205
+ else:
206
+ frame = model.synthesizer(
207
+ planes,
208
+ cameras=render_cameras[:, i:i+chunk_size],
209
+ render_size=render_size,
210
+ )['images_rgb']
211
+ frames.append(frame)
212
+ frames = torch.cat(frames, dim=1)
213
+
214
+ images_to_video(
215
+ frames[0],
216
+ video_fpath,
217
+ fps=30,
218
+ )
219
+
220
+ print(f"Video saved to {video_fpath}")
221
+
222
+ mesh_fpath = make_mesh(mesh_fpath, planes)
223
+
224
+ return video_fpath, mesh_fpath, show_images
225
+
226
+
227
+ _HEADER_ = '''
228
+ <h2><b>Official πŸ€— Gradio demo for</b>
229
+ <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>
230
+ <b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b>
231
+ </a>.
232
+ </h2>
233
+ '''
234
+
235
+ _LINKS_ = '''
236
+ <h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
237
+ <h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
238
+ '''
239
+
240
+ _CITE_ = r"""
241
+ ```bibtex
242
+ @article{xu2024instantmesh,
243
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
244
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
245
+ journal={arXiv preprint arXiv:2404.07191},
246
+ year={2024}
247
+ }
248
+ ```
249
+ """
250
+
251
+
252
+ with gr.Blocks() as demo:
253
+ gr.Markdown(_HEADER_)
254
+ with gr.Row(variant="panel"):
255
+ with gr.Column():
256
+ with gr.Row():
257
+ input_image = gr.Image(
258
+ label="Input Image",
259
+ image_mode="RGBA",
260
+ sources="upload",
261
+ width=256,
262
+ height=256,
263
+ type="pil",
264
+ elem_id="content_image",
265
+ )
266
+ processed_image = gr.Image(
267
+ label="Processed Image",
268
+ image_mode="RGBA",
269
+ width=256,
270
+ height=256,
271
+ type="pil",
272
+ interactive=False
273
+ )
274
+ with gr.Row():
275
+ with gr.Group():
276
+ do_remove_background = gr.Checkbox(
277
+ label="Remove Background", value=True
278
+ )
279
+ sample_seed = gr.Number(value=42, label="Seed (Try a different value if the result is unsatisfying)", precision=0)
280
+
281
+ sample_steps = gr.Slider(
282
+ label="Sample Steps",
283
+ minimum=30,
284
+ maximum=75,
285
+ value=75,
286
+ step=5
287
+ )
288
+
289
+ with gr.Row():
290
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
291
+
292
+ with gr.Row(variant="panel"):
293
+ gr.Examples(
294
+ examples=[
295
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
296
+ ],
297
+ inputs=[input_image],
298
+ label="Examples",
299
+ examples_per_page=15
300
+ )
301
+
302
+ with gr.Column():
303
+
304
+ with gr.Row():
305
+
306
+ with gr.Column():
307
+ mv_show_images = gr.Image(
308
+ label="Generated Multi-views",
309
+ type="pil",
310
+ width=379,
311
+ interactive=False
312
+ )
313
+
314
+ with gr.Column():
315
+ output_video = gr.Video(
316
+ label="video", format="mp4",
317
+ width=379,
318
+ autoplay=True,
319
+ interactive=False
320
+ )
321
+
322
+ with gr.Row():
323
+ output_model_obj = gr.Model3D(
324
+ label="Output Model (OBJ Format)",
325
+ width=768,
326
+ interactive=False,
327
+ )
328
+ gr.Markdown(_LINKS_)
329
+ gr.Markdown(_CITE_)
330
+
331
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
332
+ fn=preprocess,
333
+ inputs=[input_image, do_remove_background],
334
+ outputs=[processed_image],
335
+ ).success(
336
+ fn=make3d,
337
+ inputs=[processed_image, sample_steps, sample_seed],
338
+ outputs=[output_video, output_model_obj, mv_show_images]
339
+ )
340
+
341
+ demo.launch()