pengHTYX commited on
Commit
ad7ddbe
·
1 Parent(s): f11b5f9
Files changed (1) hide show
  1. app.py +432 -0
app.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import fire
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from functools import partial
7
+
8
+ import cv2
9
+ import time
10
+ import numpy as np
11
+ from rembg import remove
12
+ from segment_anything import sam_model_registry, SamPredictor
13
+
14
+ import os
15
+ import sys
16
+ import numpy
17
+ import torch
18
+ import rembg
19
+ import threading
20
+ import urllib.request
21
+ from PIL import Image
22
+ from typing import Dict, Optional, Tuple, List
23
+ from dataclasses import dataclass
24
+ import streamlit as st
25
+ import huggingface_hub
26
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
27
+ from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel
28
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
29
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
30
+ from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
31
+ from einops import rearrange
32
+ import numpy as np
33
+ import subprocess
34
+ from datetime import datetime
35
+
36
+ def save_image(tensor):
37
+ ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
38
+ # pdb.set_trace()
39
+ im = Image.fromarray(ndarr)
40
+ return ndarr
41
+
42
+
43
+ def save_image_to_disk(tensor, fp):
44
+ ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
45
+ # pdb.set_trace()
46
+ im = Image.fromarray(ndarr)
47
+ im.save(fp)
48
+ return ndarr
49
+
50
+
51
+ def save_image_numpy(ndarr, fp):
52
+ im = Image.fromarray(ndarr)
53
+ im.save(fp)
54
+
55
+
56
+ weight_dtype = torch.float16
57
+
58
+ _TITLE = '''Era3D: High-Resolution Multiview Diffusion using Efficient Row-wise Attention'''
59
+ _DESCRIPTION = '''
60
+ <div>
61
+ Generate consistent high-resolution multi-view normals maps and color images.
62
+ <a style="display:inline-block; margin-left: .5em" href='https://github.com/pengHTYX/Era3D'></a>
63
+ </div>
64
+ <div>
65
+ The demo does not include the mesh reconstruction part, please visit <a href="https://github.com/pengHTYX/Era3D">our github repo</a> to get a textured mesh.
66
+ </div>
67
+ '''
68
+ _GPU_ID = 0
69
+
70
+
71
+ if not hasattr(Image, 'Resampling'):
72
+ Image.Resampling = Image
73
+
74
+
75
+ def sam_init():
76
+ sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
77
+ model_type = "vit_h"
78
+
79
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
80
+ predictor = SamPredictor(sam)
81
+ return predictor
82
+
83
+
84
+ def sam_segment(predictor, input_image, *bbox_coords):
85
+ bbox = np.array(bbox_coords)
86
+ image = np.asarray(input_image)
87
+
88
+ start_time = time.time()
89
+ predictor.set_image(image)
90
+
91
+ masks_bbox, scores_bbox, logits_bbox = predictor.predict(box=bbox, multimask_output=True)
92
+
93
+ print(f"SAM Time: {time.time() - start_time:.3f}s")
94
+ out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
95
+ out_image[:, :, :3] = image
96
+ out_image_bbox = out_image.copy()
97
+ out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
98
+ torch.cuda.empty_cache()
99
+ return Image.fromarray(out_image_bbox, mode='RGBA')
100
+
101
+
102
+ def expand2square(pil_img, background_color):
103
+ width, height = pil_img.size
104
+ if width == height:
105
+ return pil_img
106
+ elif width > height:
107
+ result = Image.new(pil_img.mode, (width, width), background_color)
108
+ result.paste(pil_img, (0, (width - height) // 2))
109
+ return result
110
+ else:
111
+ result = Image.new(pil_img.mode, (height, height), background_color)
112
+ result.paste(pil_img, ((height - width) // 2, 0))
113
+ return result
114
+
115
+
116
+ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
117
+ RES = 1024
118
+ input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
119
+ if chk_group is not None:
120
+ segment = "Background Removal" in chk_group
121
+ rescale = "Rescale" in chk_group
122
+ if segment:
123
+ image_rem = input_image.convert('RGBA')
124
+ image_nobg = remove(image_rem, alpha_matting=True)
125
+ arr = np.asarray(image_nobg)[:, :, -1]
126
+ x_nonzero = np.nonzero(arr.sum(axis=0))
127
+ y_nonzero = np.nonzero(arr.sum(axis=1))
128
+ x_min = int(x_nonzero[0].min())
129
+ y_min = int(y_nonzero[0].min())
130
+ x_max = int(x_nonzero[0].max())
131
+ y_max = int(y_nonzero[0].max())
132
+ input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
133
+ # Rescale and recenter
134
+ if rescale:
135
+ image_arr = np.array(input_image)
136
+ in_w, in_h = image_arr.shape[:2]
137
+ out_res = min(RES, max(in_w, in_h))
138
+ ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
139
+ x, y, w, h = cv2.boundingRect(mask)
140
+ max_size = max(w, h)
141
+ ratio = 0.75
142
+ side_len = int(max_size / ratio)
143
+ padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
144
+ center = side_len // 2
145
+ padded_image[center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w] = image_arr[y : y + h, x : x + w]
146
+ rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS)
147
+
148
+ rgba_arr = np.array(rgba) / 255.0
149
+ rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
150
+ input_image = Image.fromarray((rgb * 255).astype(np.uint8))
151
+ else:
152
+ input_image = expand2square(input_image, (127, 127, 127, 0))
153
+ return input_image, input_image.resize((768, 768), Image.Resampling.LANCZOS)
154
+
155
+
156
+ def load_era3d_pipeline(cfg):
157
+ # Load scheduler, tokenizer and models.
158
+
159
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
160
+ '../MacLab-Era3D-512-6view',
161
+ torch_dtype=weight_dtype
162
+ )
163
+
164
+ # pipeline.to('cuda:0')
165
+ pipeline.unet.enable_xformers_memory_efficient_attention()
166
+
167
+
168
+ if torch.cuda.is_available():
169
+ pipeline.to('cuda:0')
170
+ # sys.main_lock = threading.Lock()
171
+ return pipeline
172
+
173
+
174
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset
175
+
176
+
177
+ def prepare_data(single_image, crop_size):
178
+ dataset = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white', crop_size=crop_size, single_image=single_image)
179
+ return dataset[0]
180
+
181
+ scene = 'scene'
182
+
183
+ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, chk_group=None):
184
+ import pdb
185
+ global scene
186
+ # pdb.set_trace()
187
+
188
+ if chk_group is not None:
189
+ write_image = "Write Results" in chk_group
190
+
191
+ batch = prepare_data(single_image, crop_size)
192
+
193
+ pipeline.set_progress_bar_config(disable=True)
194
+ seed = int(seed)
195
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed)
196
+
197
+
198
+ imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
199
+ num_views = imgs_in.shape[1]
200
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
201
+
202
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
203
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
204
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
205
+
206
+
207
+ out = pipeline(
208
+ imgs_in,
209
+ None,
210
+ prompt_embeds=prompt_embeddings,
211
+ generator=generator,
212
+ guidance_scale=guidance_scale,
213
+ output_type='pt',
214
+ num_images_per_prompt=1,
215
+ return_elevation_focal=cfg.log_elevation_focal_length,
216
+ **cfg.pipe_validation_kwargs
217
+ ).images
218
+
219
+ bsz = out.shape[0] // 2
220
+ normals_pred = out[:bsz]
221
+ images_pred = out[bsz:]
222
+ num_views = 6
223
+ if write_image:
224
+ VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
225
+ cur_dir = os.path.join("./mv_res", f"cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}")
226
+
227
+ scene = 'scene'+datetime.now().strftime('@%Y%m%d-%H%M%S')
228
+ scene_dir = os.path.join(cur_dir, scene)
229
+ os.makedirs(scene_dir, exist_ok=True)
230
+
231
+ for j in range(num_views):
232
+ view = VIEWS[j]
233
+ normal = normals_pred[j]
234
+ color = images_pred[j]
235
+
236
+ normal_filename = f"normals_{view}_masked.png"
237
+ color_filename = f"color_{view}_masked.png"
238
+ normal = save_image_to_disk(normal, os.path.join(scene_dir, normal_filename))
239
+ color = save_image_to_disk(color, os.path.join(scene_dir, color_filename))
240
+
241
+
242
+ normals_pred = [save_image(normals_pred[i]) for i in range(bsz)]
243
+ images_pred = [save_image(images_pred[i]) for i in range(bsz)]
244
+
245
+ out = images_pred + normals_pred
246
+ return out
247
+
248
+
249
+ def process_3d(mode, data_dir, guidance_scale, crop_size):
250
+ dir = None
251
+ global scene
252
+
253
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
254
+
255
+ subprocess.run(
256
+ f'cd instant-nsr-pl && bash run.sh 0 {scene} exp_demo && cd ..',
257
+ shell=True,
258
+ )
259
+ import glob
260
+ # import pdb
261
+
262
+ # pdb.set_trace()
263
+
264
+ obj_files = glob.glob(f'{cur_dir}/instant-nsr-pl/exp_demo/{scene}/*/save/*.obj', recursive=True)
265
+ print(obj_files)
266
+ if obj_files:
267
+ dir = obj_files[0]
268
+ return dir
269
+
270
+
271
+ @dataclass
272
+ class TestConfig:
273
+ pretrained_model_name_or_path: str
274
+ pretrained_unet_path:str
275
+ revision: Optional[str]
276
+ validation_dataset: Dict
277
+ save_dir: str
278
+ seed: Optional[int]
279
+ validation_batch_size: int
280
+ dataloader_num_workers: int
281
+ # save_single_views: bool
282
+ save_mode: str
283
+ local_rank: int
284
+
285
+ pipe_kwargs: Dict
286
+ pipe_validation_kwargs: Dict
287
+ unet_from_pretrained_kwargs: Dict
288
+ validation_guidance_scales: List[float]
289
+ validation_grid_nrow: int
290
+ camera_embedding_lr_mult: float
291
+
292
+ num_views: int
293
+ camera_embedding_type: str
294
+
295
+ pred_type: str # joint, or ablation
296
+ regress_elevation: bool
297
+ enable_xformers_memory_efficient_attention: bool
298
+
299
+ cond_on_normals: bool
300
+ cond_on_colors: bool
301
+
302
+ regress_elevation: bool
303
+ regress_focal_length: bool
304
+
305
+
306
+
307
+ def run_demo():
308
+ from utils.misc import load_config
309
+ from omegaconf import OmegaConf
310
+
311
+ # parse YAML config to OmegaConf
312
+ cfg = load_config("./configs/test_unclip-512-6view.yaml")
313
+ # print(cfg)
314
+ schema = OmegaConf.structured(TestConfig)
315
+ cfg = OmegaConf.merge(schema, cfg)
316
+
317
+ pipeline = load_era3d_pipeline(cfg)
318
+ torch.set_grad_enabled(False)
319
+ pipeline.to(f'cuda:{_GPU_ID}')
320
+
321
+ predictor = sam_init()
322
+
323
+ custom_theme = gr.themes.Soft(primary_hue="blue").set(
324
+ button_secondary_background_fill="*neutral_100", button_secondary_background_fill_hover="*neutral_200"
325
+ )
326
+ custom_css = '''#disp_image {
327
+ text-align: center; /* Horizontally center the content */
328
+ }'''
329
+
330
+ with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
331
+ with gr.Row():
332
+ with gr.Column(scale=1):
333
+ gr.Markdown('# ' + _TITLE)
334
+ gr.Markdown(_DESCRIPTION)
335
+ with gr.Row(variant='panel'):
336
+ with gr.Column(scale=1):
337
+ input_image = gr.Image(type='pil', image_mode='RGBA', height=768, label='Input image')
338
+
339
+ with gr.Column(scale=1):
340
+ processed_image = gr.Image(
341
+ type='pil',
342
+ label="Processed Image",
343
+ interactive=False,
344
+ height=768,
345
+ image_mode='RGBA',
346
+ elem_id="disp_image",
347
+ visible=True,
348
+ )
349
+ # with gr.Column(scale=1):
350
+ # ## add 3D Model
351
+ # obj_3d = gr.Model3D(
352
+ # # clear_color=[0.0, 0.0, 0.0, 0.0],
353
+ # label="3D Model", height=320,
354
+ # # camera_position=[0,0,2.0]
355
+ # )
356
+ processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
357
+ with gr.Row(variant='panel'):
358
+ with gr.Column(scale=1):
359
+ example_folder = os.path.join(os.path.dirname(__file__), "./examples")
360
+ example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
361
+ gr.Examples(
362
+ examples=example_fns,
363
+ inputs=[input_image],
364
+ outputs=[input_image],
365
+ cache_examples=False,
366
+ label='Examples (click one of the images below to start)',
367
+ examples_per_page=30,
368
+ )
369
+ with gr.Column(scale=1):
370
+ with gr.Accordion('Advanced options', open=True):
371
+ with gr.Row():
372
+ with gr.Column():
373
+ input_processing = gr.CheckboxGroup(
374
+ ['Background Removal'],
375
+ label='Input Image Preprocessing',
376
+ value=['Background Removal'],
377
+ info='untick this, if masked image with alpha channel',
378
+ )
379
+ with gr.Column():
380
+ output_processing = gr.CheckboxGroup(
381
+ ['Write Results'], label='write the results in ./outputs folder', value=['Write Results']
382
+ )
383
+ with gr.Row():
384
+ with gr.Column():
385
+ scale_slider = gr.Slider(1, 5, value=3, step=1, label='Classifier Free Guidance Scale')
386
+ with gr.Column():
387
+ steps_slider = gr.Slider(15, 100, value=40, step=1, label='Number of Diffusion Inference Steps')
388
+ with gr.Row():
389
+ with gr.Column():
390
+ seed = gr.Number(600, label='Seed')
391
+ with gr.Column():
392
+ crop_size = gr.Number(420, label='Crop size')
393
+
394
+ mode = gr.Textbox('train', visible=False)
395
+ data_dir = gr.Textbox('outputs', visible=False)
396
+ # with gr.Row():
397
+ # method = gr.Radio(choices=['instant-nsr-pl', 'NeuS'], label='Method (Default: instant-nsr-pl)', value='instant-nsr-pl')
398
+ run_btn = gr.Button('Generate Normals and Colors', variant='primary', interactive=True)
399
+ # recon_btn = gr.Button('Reconstruct 3D model', variant='primary', interactive=True)
400
+ # gr.Markdown("<span style='color:red'>First click Generate button, then click Reconstruct button. Reconstruction may cost several minutes.</span>")
401
+
402
+ with gr.Row():
403
+ view_1 = gr.Image(interactive=False, height=512, show_label=False)
404
+ view_2 = gr.Image(interactive=False, height=512, show_label=False)
405
+ view_3 = gr.Image(interactive=False, height=512, show_label=False)
406
+ view_4 = gr.Image(interactive=False, height=512, show_label=False)
407
+ view_5 = gr.Image(interactive=False, height=512, show_label=False)
408
+ view_6 = gr.Image(interactive=False, height=512, show_label=False)
409
+ with gr.Row():
410
+ normal_1 = gr.Image(interactive=False, height=512, show_label=False)
411
+ normal_2 = gr.Image(interactive=False, height=512, show_label=False)
412
+ normal_3 = gr.Image(interactive=False, height=512, show_label=False)
413
+ normal_4 = gr.Image(interactive=False, height=512, show_label=False)
414
+ normal_5 = gr.Image(interactive=False, height=512, show_label=False)
415
+ normal_6 = gr.Image(interactive=False, height=512, show_label=False)
416
+
417
+ run_btn.click(
418
+ fn=partial(preprocess, predictor), inputs=[input_image, input_processing], outputs=[processed_image_highres, processed_image], queue=True
419
+ ).success(
420
+ fn=partial(run_pipeline, pipeline, cfg),
421
+ inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size, output_processing],
422
+ outputs=[view_1, view_2, view_3, view_4, view_5, view_6, normal_1, normal_2, normal_3, normal_4, normal_5, normal_6],
423
+ )
424
+ # recon_btn.click(
425
+ # process_3d, inputs=[mode, data_dir, scale_slider, crop_size], outputs=[obj_3d]
426
+ # )
427
+
428
+ demo.queue().launch(share=True, max_threads=80)
429
+
430
+
431
+ if __name__ == '__main__':
432
+ fire.Fire(run_demo)