tokenid commited on
Commit
6681628
·
verified ·
1 Parent(s): 397413a

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -341
app.py DELETED
@@ -1,341 +0,0 @@
1
- import os
2
- import imageio
3
- import numpy as np
4
- import torch
5
- import rembg
6
- from PIL import Image
7
- from torchvision.transforms import v2
8
- from pytorch_lightning import seed_everything
9
- from omegaconf import OmegaConf
10
- from einops import rearrange, repeat
11
- from tqdm import tqdm
12
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
13
-
14
- from src.utils.train_util import instantiate_from_config
15
- from src.utils.camera_util import (
16
- FOV_to_intrinsics,
17
- get_zero123plus_input_cameras,
18
- get_circular_camera_poses,
19
- )
20
- from src.utils.mesh_util import save_obj
21
- from src.utils.infer_util import remove_background, resize_foreground, images_to_video
22
-
23
- import tempfile
24
- from functools import partial
25
-
26
- from huggingface_hub import hf_hub_download
27
-
28
- import gradio as gr
29
- import spaces
30
-
31
-
32
- def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
33
- """
34
- Get the rendering camera parameters.
35
- """
36
- c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
37
- if is_flexicubes:
38
- cameras = torch.linalg.inv(c2ws)
39
- cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
40
- else:
41
- extrinsics = c2ws.flatten(-2)
42
- intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
43
- cameras = torch.cat([extrinsics, intrinsics], dim=-1)
44
- cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
45
- return cameras
46
-
47
-
48
- def images_to_video(images, output_path, fps=30):
49
- # images: (N, C, H, W)
50
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
51
- frames = []
52
- for i in range(images.shape[0]):
53
- frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
54
- assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
55
- f"Frame shape mismatch: {frame.shape} vs {images.shape}"
56
- assert frame.min() >= 0 and frame.max() <= 255, \
57
- f"Frame value out of range: {frame.min()} ~ {frame.max()}"
58
- frames.append(frame)
59
- imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
60
-
61
-
62
- ###############################################################################
63
- # Configuration.
64
- ###############################################################################
65
-
66
- config_path = 'configs/instant-mesh-large.yaml'
67
- config = OmegaConf.load(config_path)
68
- config_name = os.path.basename(config_path).replace('.yaml', '')
69
- model_config = config.model_config
70
- infer_config = config.infer_config
71
-
72
- IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
73
-
74
- device = torch.device('cuda')
75
-
76
- # load diffusion model
77
- print('Loading diffusion model ...')
78
- pipeline = DiffusionPipeline.from_pretrained(
79
- "sudo-ai/zero123plus-v1.2",
80
- custom_pipeline="zero123plus",
81
- torch_dtype=torch.float16,
82
- )
83
- pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
84
- pipeline.scheduler.config, timestep_spacing='trailing'
85
- )
86
-
87
- # load custom white-background UNet
88
- unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
89
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
90
- pipeline.unet.load_state_dict(state_dict, strict=True)
91
-
92
- pipeline = pipeline.to(device)
93
-
94
- # load reconstruction model
95
- print('Loading reconstruction model ...')
96
- model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
97
- model = instantiate_from_config(model_config)
98
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
99
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
100
- model.load_state_dict(state_dict, strict=True)
101
-
102
- model = model.to(device)
103
- if IS_FLEXICUBES:
104
- model.init_flexicubes_geometry(device)
105
- model = model.eval()
106
-
107
- print('Loading Finished!')
108
-
109
-
110
- def check_input_image(input_image):
111
- if input_image is None:
112
- raise gr.Error("No image uploaded!")
113
-
114
-
115
- def preprocess(input_image, do_remove_background):
116
-
117
- rembg_session = rembg.new_session() if do_remove_background else None
118
-
119
- if do_remove_background:
120
- input_image = remove_background(input_image, rembg_session)
121
- input_image = resize_foreground(input_image, 0.85)
122
-
123
- return input_image
124
-
125
-
126
- def generate_mvs(input_image, sample_steps, sample_seed):
127
-
128
- seed_everything(sample_seed)
129
-
130
- # sampling
131
- z123_image = pipeline(
132
- input_image,
133
- num_inference_steps=sample_steps
134
- ).images[0]
135
-
136
- show_image = np.asarray(z123_image, dtype=np.uint8)
137
- show_image = torch.from_numpy(show_image) # (960, 640, 3)
138
- show_image = rearrange(show_image, '(n h) (m w) c -> (m h) (n w) c', n=3, m=2)
139
- show_image = Image.fromarray(show_image.numpy())
140
-
141
- return z123_image, show_image
142
-
143
- def make_mesh(mesh_fpath, planes):
144
-
145
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
146
- mesh_dirname = os.path.dirname(mesh_fpath)
147
- mesh_vis_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
148
-
149
- with torch.no_grad():
150
-
151
- # get mesh
152
- mesh_out = model.extract_mesh(
153
- planes,
154
- use_texture_map=False,
155
- **infer_config,
156
- )
157
-
158
- vertices, faces, vertex_colors = mesh_out
159
- vertices = vertices[:, [0, 2, 1]]
160
- vertices[:, -1] *= -1
161
-
162
- save_obj(vertices, faces, vertex_colors, mesh_fpath)
163
-
164
- print(f"Mesh saved to {mesh_fpath}")
165
-
166
- return mesh_fpath
167
-
168
- @spaces.GPU
169
- def make3d(input_image, sample_steps, sample_seed):
170
-
171
- images, show_images = generate_mvs(input_image, sample_steps, sample_seed)
172
-
173
- images = np.asarray(images, dtype=np.float32) / 255.0
174
- images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
175
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
176
-
177
- input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device)
178
- render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
179
-
180
- images = images.unsqueeze(0).to(device)
181
- images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
182
-
183
- mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
184
- print(mesh_fpath)
185
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
186
- mesh_dirname = os.path.dirname(mesh_fpath)
187
- video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
188
-
189
- with torch.no_grad():
190
- # get triplane
191
- planes = model.forward_planes(images, input_cameras)
192
-
193
- # get video
194
- chunk_size = 20 if IS_FLEXICUBES else 1
195
- render_size = 384
196
-
197
- frames = []
198
- for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
199
- if IS_FLEXICUBES:
200
- frame = model.forward_geometry(
201
- planes,
202
- render_cameras[:, i:i+chunk_size],
203
- render_size=render_size,
204
- )['img']
205
- else:
206
- frame = model.synthesizer(
207
- planes,
208
- cameras=render_cameras[:, i:i+chunk_size],
209
- render_size=render_size,
210
- )['images_rgb']
211
- frames.append(frame)
212
- frames = torch.cat(frames, dim=1)
213
-
214
- images_to_video(
215
- frames[0],
216
- video_fpath,
217
- fps=30,
218
- )
219
-
220
- print(f"Video saved to {video_fpath}")
221
-
222
- mesh_fpath = make_mesh(mesh_fpath, planes)
223
-
224
- return video_fpath, mesh_fpath, show_images
225
-
226
-
227
- _HEADER_ = '''
228
- <h2><b>Official 🤗 Gradio demo for</b>
229
- <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>
230
- <b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b>
231
- </a>.
232
- </h2>
233
- '''
234
-
235
- _LINKS_ = '''
236
- <h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
237
- <h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
238
- '''
239
-
240
- _CITE_ = r"""
241
- ```bibtex
242
- @article{xu2024instantmesh,
243
- title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
244
- author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
245
- journal={arXiv preprint arXiv:2404.07191},
246
- year={2024}
247
- }
248
- ```
249
- """
250
-
251
-
252
- with gr.Blocks() as demo:
253
- gr.Markdown(_HEADER_)
254
- with gr.Row(variant="panel"):
255
- with gr.Column():
256
- with gr.Row():
257
- input_image = gr.Image(
258
- label="Input Image",
259
- image_mode="RGBA",
260
- sources="upload",
261
- width=256,
262
- height=256,
263
- type="pil",
264
- elem_id="content_image",
265
- )
266
- processed_image = gr.Image(
267
- label="Processed Image",
268
- image_mode="RGBA",
269
- width=256,
270
- height=256,
271
- type="pil",
272
- interactive=False
273
- )
274
- with gr.Row():
275
- with gr.Group():
276
- do_remove_background = gr.Checkbox(
277
- label="Remove Background", value=True
278
- )
279
- sample_seed = gr.Number(value=42, label="Seed (Try a different value if the result is unsatisfying)", precision=0)
280
-
281
- sample_steps = gr.Slider(
282
- label="Sample Steps",
283
- minimum=30,
284
- maximum=75,
285
- value=75,
286
- step=5
287
- )
288
-
289
- with gr.Row():
290
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
291
-
292
- with gr.Row(variant="panel"):
293
- gr.Examples(
294
- examples=[
295
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
296
- ],
297
- inputs=[input_image],
298
- label="Examples",
299
- examples_per_page=15
300
- )
301
-
302
- with gr.Column():
303
-
304
- with gr.Row():
305
-
306
- with gr.Column():
307
- mv_show_images = gr.Image(
308
- label="Generated Multi-views",
309
- type="pil",
310
- width=379,
311
- interactive=False
312
- )
313
-
314
- with gr.Column():
315
- output_video = gr.Video(
316
- label="video", format="mp4",
317
- width=379,
318
- autoplay=True,
319
- interactive=False
320
- )
321
-
322
- with gr.Row():
323
- output_model_obj = gr.Model3D(
324
- label="Output Model (OBJ Format)",
325
- width=768,
326
- interactive=False,
327
- )
328
- gr.Markdown(_LINKS_)
329
- gr.Markdown(_CITE_)
330
-
331
- submit.click(fn=check_input_image, inputs=[input_image]).success(
332
- fn=preprocess,
333
- inputs=[input_image, do_remove_background],
334
- outputs=[processed_image],
335
- ).success(
336
- fn=make3d,
337
- inputs=[processed_image, sample_steps, sample_seed],
338
- outputs=[output_video, output_model_obj, mv_show_images]
339
- )
340
-
341
- demo.launch()