NCJ commited on
Commit
823d579
·
verified ·
1 Parent(s): 68e2a74

finish demo

Browse files
Files changed (6) hide show
  1. app.py +94 -40
  2. demo/mesh_recon.py +45 -12
  3. demo/relighting_gen.py +29 -17
  4. demo/render_hints.py +11 -8
  5. demo/rm_bg.py +0 -1
  6. requirements.txt +6 -5
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
 
2
  import imageio
3
  import numpy as np
 
4
 
5
  from demo.img_gen import img_gen
6
  from demo.mesh_recon import mesh_reconstruction
@@ -11,7 +13,7 @@ from demo.rm_bg import rm_bg
11
 
12
  with gr.Blocks(title="DiLightNet Demo") as demo:
13
  gr.Markdown("""# DiLightNet: Fine-grained Lighting Control for Diffusion-based Image Generation
14
- ## A demo for generating images under point lights using DiLightNet. For full usage, please refer to our [GitHub repository](TBD)""")
15
 
16
  with gr.Row():
17
  # 1. Reference Image Input / Generation
@@ -29,6 +31,11 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
29
  with gr.Row():
30
  generate_btn = gr.Button(value="Generate")
31
  generate_btn.click(fn=img_gen, inputs=[prompt, seed, steps, cfg, down_from_768], outputs=[input_image])
 
 
 
 
 
32
 
33
  # 2. Background Removal
34
  with gr.Column(variant="panel"):
@@ -49,54 +56,98 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
49
  with gr.Accordion("Options", open=False):
50
  with gr.Group():
51
  remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
52
- fov = gr.Number(value=55., label="FOV", interactive=True)
53
  mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
54
  depth_estimation_btn = gr.Button(value="Estimate Depth")
 
 
 
55
  depth_estimation_btn.click(
56
- fn=mesh_reconstruction,
57
- inputs=[masked_image, mask, remove_edges, fov, mask_threshold],
58
- outputs=[mesh]
59
  )
60
 
61
  gr.Markdown("## Step 4. Render Hints")
62
  with gr.Row():
63
  with gr.Column():
64
- hint_image = gr.Image(label="Hint Image")
65
  with gr.Column():
66
- pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
67
- pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
68
- pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
69
- power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
70
- render_btn = gr.Button(value="Render Hints")
71
  res_folder_path = gr.Textbox("", visible=False)
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- def render_wrapper(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power,
74
- progress=gr.Progress(track_tqdm=True)):
75
- res_path = render_hint_images_btn_func(mesh, fov, [(pl_pos_x, pl_pos_y, pl_pos_z)], power)
76
- hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.34.png"]]
77
- hints = []
78
- for hint_file in hint_files:
79
- hint = imageio.v3.imread(hint_file)
80
- hints.append(hint)
81
- hints = np.concatenate(hints, axis=1)
82
- return hints, res_path
83
-
84
- render_btn.click(
85
- fn=render_wrapper,
86
- inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
87
- outputs=[hint_image, res_folder_path]
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- gr.Markdown("## Step 5. Lighting Control Generation")
 
 
 
 
 
 
91
  with gr.Row():
92
  res_image = gr.Image(label="Result Image")
93
  with gr.Column():
94
  with gr.Group():
95
- relighting_prompt = gr.Textbox(value="", label="Lighting-Control Text Prompt", lines=3,
96
- placeholder="Input prompt here",
97
- interactive=True)
98
- reuse_btn = gr.Button(value="Reuse Image Generation Prompt")
99
- reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
 
 
 
 
 
 
 
 
 
 
 
 
100
  with gr.Accordion("Options", open=False):
101
  with gr.Row():
102
  relighting_seed = gr.Number(value=3407, label="Seed", interactive=True)
@@ -106,7 +157,7 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
106
  relighting_generate_btn = gr.Button(value="Generate")
107
 
108
  def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
109
- relighting_steps, relighting_cfg,
110
  progress=gr.Progress(track_tqdm=True)):
111
  relighting_gen(
112
  masked_ref_img=masked_image,
@@ -118,16 +169,19 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
118
  seed=int(relighting_seed),
119
  cfg=relighting_cfg
120
  )
121
- mask_for_bg = imageio.v3.imread(res_folder_path + '/hint00_diffuse.png')[..., -1:] / 255.
122
- res = imageio.v3.imread(res_folder_path + '/relighting00.png') / 255.
123
- res = res * mask_for_bg # + bg * (1. - mask_for_bg)
124
- res = (res * 255).clip(0, 255).astype(np.uint8)
125
- return res
 
 
 
126
 
127
 
128
  relighting_generate_btn.click(fn=gen_relighting_image,
129
  inputs=[masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
130
- relighting_steps, relighting_cfg],
131
  outputs=[res_image])
132
 
133
 
 
1
  import gradio as gr
2
+ import os
3
  import imageio
4
  import numpy as np
5
+ from einops import rearrange
6
 
7
  from demo.img_gen import img_gen
8
  from demo.mesh_recon import mesh_reconstruction
 
13
 
14
  with gr.Blocks(title="DiLightNet Demo") as demo:
15
  gr.Markdown("""# DiLightNet: Fine-grained Lighting Control for Diffusion-based Image Generation
16
+ ## A demo for generating images under point/environmantal lighting using DiLightNet. For full usage (video generation & arbitary lighting condition), please refer to our [GitHub repository](https://github.com/iamNCJ/DiLightNet)""")
17
 
18
  with gr.Row():
19
  # 1. Reference Image Input / Generation
 
31
  with gr.Row():
32
  generate_btn = gr.Button(value="Generate")
33
  generate_btn.click(fn=img_gen, inputs=[prompt, seed, steps, cfg, down_from_768], outputs=[input_image])
34
+ gr.Examples(
35
+ examples=[os.path.join("examples/provisional_img", i) for i in os.listdir("examples/provisional_img")],
36
+ inputs=[input_image],
37
+ examples_per_page = 20,
38
+ )
39
 
40
  # 2. Background Removal
41
  with gr.Column(variant="panel"):
 
56
  with gr.Accordion("Options", open=False):
57
  with gr.Group():
58
  remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
59
+ fov = gr.Number(value=55., label="FOV", interactive=False)
60
  mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
61
  depth_estimation_btn = gr.Button(value="Estimate Depth")
62
+ def mesh_reconstruction_wrapper(image, mask, remove_edges, mask_threshold,
63
+ progress=gr.Progress(track_tqdm=True)):
64
+ return mesh_reconstruction(image, mask, remove_edges, None, mask_threshold)
65
  depth_estimation_btn.click(
66
+ fn=mesh_reconstruction_wrapper,
67
+ inputs=[input_image, mask, remove_edges, mask_threshold],
68
+ outputs=[mesh, fov],
69
  )
70
 
71
  gr.Markdown("## Step 4. Render Hints")
72
  with gr.Row():
73
  with gr.Column():
74
+ hint_image = gr.Image(label="Hint Image", height=512, width=512)
75
  with gr.Column():
 
 
 
 
 
76
  res_folder_path = gr.Textbox("", visible=False)
77
+ is_env_lighting = gr.Checkbox(label="Use Environmental Lighting", value=True, interactive=False, visible=False)
78
+ with gr.Tab("Environmental Lighting"):
79
+ env_map_preview = gr.Image(label="Environment Map Preview", height=256, width=512, interactive=False, show_download_button=False)
80
+ env_map_path = gr.Text(interactive=False, visible=False, value="examples/env_map/grace.exr")
81
+ env_rotation = gr.Slider(value=0., label="Environment Rotation", minimum=0., maximum=360., step=0.5)
82
+ env_examples = gr.Examples(
83
+ examples=[[os.path.join("examples/env_map_preview", i), os.path.join("examples/env_map", i).replace("png", "exr")] for i in os.listdir("examples/env_map_preview")],
84
+ inputs=[env_map_preview, env_map_path],
85
+ examples_per_page = 20,
86
+ )
87
+ render_btn_env = gr.Button(value="Render Hints")
88
 
89
+ def render_wrapper_env(mesh, fov, env_map_path, env_rotation, progress=gr.Progress(track_tqdm=True)):
90
+ env_map_path = os.path.abspath(env_map_path)
91
+ res_path = render_hint_images_btn_func(mesh, float(fov), [(0, 0, 0)], env_map=env_map_path, env_start_azi=env_rotation / 360.)
92
+ hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.05.png", "_ggx0.13.png", "_ggx0.34.png"]]
93
+ hints = []
94
+ for hint_file in hint_files:
95
+ hint = imageio.v3.imread(hint_file)
96
+ hints.append(hint)
97
+ hints = rearrange(np.stack(hints), '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=2, n2=2)
98
+ return hints, res_path, True
99
+ render_btn_env.click(
100
+ fn=render_wrapper_env,
101
+ inputs=[mesh, fov, env_map_path, env_rotation],
102
+ outputs=[hint_image, res_folder_path, is_env_lighting]
103
+ )
104
+
105
+ with gr.Tab("Point Lighting"):
106
+ pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
107
+ pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
108
+ pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
109
+ power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
110
+ render_btn_pl = gr.Button(value="Render Hints")
111
+
112
+ def render_wrapper_pl(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power,
113
+ progress=gr.Progress(track_tqdm=True)):
114
+ res_path = render_hint_images_btn_func(mesh, float(fov), [(pl_pos_x, pl_pos_y, pl_pos_z)], power)
115
+ hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.05.png", "_ggx0.13.png", "_ggx0.34.png"]]
116
+ hints = []
117
+ for hint_file in hint_files:
118
+ hint = imageio.v3.imread(hint_file)
119
+ hints.append(hint)
120
+ hints = rearrange(np.stack(hints), '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=2, n2=2)
121
+ return hints, res_path, False
122
 
123
+ render_btn_pl.click(
124
+ fn=render_wrapper_pl,
125
+ inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
126
+ outputs=[hint_image, res_folder_path, is_env_lighting]
127
+ )
128
+
129
+ gr.Markdown("## Step 5. Control Lighting!")
130
  with gr.Row():
131
  res_image = gr.Image(label="Result Image")
132
  with gr.Column():
133
  with gr.Group():
134
+ with gr.Row():
135
+ relighting_prompt = gr.Textbox(value="", label="Appearance Text Prompt", lines=3,
136
+ placeholder="Input prompt here",
137
+ interactive=True)
138
+ with gr.Row():
139
+ # several example prompts
140
+ metallic_prompt_btn = gr.Button(value="Metallic", size="sm")
141
+ specular_prompt_btn = gr.Button(value="Specular", size="sm")
142
+ very_specular_prompt_btn = gr.Button(value="Very Specular", size="sm")
143
+ clear_prompt_btn = gr.Button(value="Clear", size="sm")
144
+ metallic_prompt_btn.click(fn=lambda x: x + " metallic", inputs=[relighting_prompt], outputs=[relighting_prompt])
145
+ specular_prompt_btn.click(fn=lambda x: x + " specular", inputs=[relighting_prompt], outputs=[relighting_prompt])
146
+ very_specular_prompt_btn.click(fn=lambda x: x + " very specular", inputs=[relighting_prompt], outputs=[relighting_prompt])
147
+ clear_prompt_btn.click(fn=lambda x: "", inputs=[relighting_prompt], outputs=[relighting_prompt])
148
+ with gr.Row():
149
+ reuse_btn = gr.Button(value="Reuse Provisional Image Generation Prompt")
150
+ reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
151
  with gr.Accordion("Options", open=False):
152
  with gr.Row():
153
  relighting_seed = gr.Number(value=3407, label="Seed", interactive=True)
 
157
  relighting_generate_btn = gr.Button(value="Generate")
158
 
159
  def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
160
+ relighting_steps, relighting_cfg, do_env_inpainting,
161
  progress=gr.Progress(track_tqdm=True)):
162
  relighting_gen(
163
  masked_ref_img=masked_image,
 
169
  seed=int(relighting_seed),
170
  cfg=relighting_cfg
171
  )
172
+ relit_img = imageio.v3.imread(res_folder_path + '/relighting00.png')
173
+ if do_env_inpainting:
174
+ bg = imageio.v3.imread(res_folder_path + f'/bg00.png') / 255.
175
+ relit_img = relit_img / 255.
176
+ mask_for_bg = imageio.v3.imread(res_folder_path + '/hint00_diffuse.png')[..., -1:] / 255.
177
+ relit_img = relit_img * mask_for_bg + bg * (1. - mask_for_bg)
178
+ relit_img = (relit_img * 255).clip(0, 255).astype(np.uint8)
179
+ return relit_img
180
 
181
 
182
  relighting_generate_btn.click(fn=gen_relighting_image,
183
  inputs=[masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
184
+ relighting_steps, relighting_cfg, is_env_lighting],
185
  outputs=[res_image])
186
 
187
 
demo/mesh_recon.py CHANGED
@@ -1,18 +1,36 @@
1
  import tempfile
 
2
 
3
  import numpy as np
 
4
  import torch
5
  import trimesh
6
-
7
  import spaces
 
 
 
 
 
 
 
 
 
8
 
9
 
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
11
 
12
- # use torch hub
13
- # zeroGPU hack from https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/9
14
- torch.jit.script = lambda f: f
15
- model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).to(device).eval()
 
 
 
 
16
 
17
 
18
  def get_intrinsics(H, W, fov=55.):
@@ -106,14 +124,29 @@ def mesh_reconstruction(
106
  masked_image: np.ndarray,
107
  mask: np.ndarray,
108
  remove_edges: bool = True,
109
- fov: float = 55.,
110
  mask_threshold: float = 25.,
111
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  rgb = masked_image[..., :3].transpose(2, 0, 1) / 255.
113
- sample = torch.from_numpy(rgb).to(device).unsqueeze(0).float()
114
- with torch.no_grad():
115
- depth = model.infer(sample)
116
- depth = depth.squeeze().cpu().numpy()
117
 
118
  pts3d = depth_to_points(depth[None], fov=fov)
119
  pts3d = pts3d.reshape(-1, 3)
@@ -132,4 +165,4 @@ def mesh_reconstruction(
132
  mesh_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
133
  mesh_file_path = mesh_file.name
134
  mesh.export(mesh_file_path)
135
- return mesh_file_path
 
1
  import tempfile
2
+ from typing import Optional
3
 
4
  import numpy as np
5
+ import cv2
6
  import torch
7
  import trimesh
 
8
  import spaces
9
+ from dust3r.model import AsymmetricCroCo3DStereo
10
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
11
+ from dust3r.inference import inference
12
+ from dust3r.image_pairs import make_pairs
13
+
14
+
15
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
16
+
17
+ model = AsymmetricCroCo3DStereo.from_pretrained("naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt").to(device).eval()
18
 
19
 
20
+ import torchvision.transforms as tvf
21
+ import PIL.Image
22
+ import numpy as np
23
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
24
+
25
 
26
+ def load_single_image(img_array):
27
+ imgs = []
28
+ for i in range(2):
29
+ img = PIL.Image.fromarray(img_array)
30
+ imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
31
+ [img.size[::-1]]), idx=i, instance=str(len(imgs))))
32
+
33
+ return imgs
34
 
35
 
36
  def get_intrinsics(H, W, fov=55.):
 
124
  masked_image: np.ndarray,
125
  mask: np.ndarray,
126
  remove_edges: bool = True,
127
+ fov: Optional[float] = None,
128
  mask_threshold: float = 25.,
129
  ):
130
+ masked_image = cv2.resize(masked_image, (512, 512))
131
+ mask = cv2.resize(mask, (512, 512))
132
+ images = load_single_image(masked_image)
133
+ pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
134
+ output = inference(pairs, model, device, batch_size=1)
135
+ scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
136
+ if fov is not None:
137
+ # do not optimize focal length if fov is provided
138
+ focal = scene.imshapes[0][1] / (2 * np.tan(0.5 * fov * np.pi / 180.))
139
+ scene.preset_focal([focal, focal])
140
+ _loss = scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)
141
+ if fov is None:
142
+ # get the focal length from the optimized parameters
143
+ focals = scene.get_focals()
144
+ fov = 2 * (np.arctan((scene.imshapes[0][1] / (focals[0] + focals[1])).detach().cpu().numpy()) * 180 / np.pi)[0]
145
+ depth = scene.get_depthmaps()[0].detach().cpu().numpy()
146
+ if device.type == 'cuda':
147
+ torch.cuda.empty_cache()
148
+
149
  rgb = masked_image[..., :3].transpose(2, 0, 1) / 255.
 
 
 
 
150
 
151
  pts3d = depth_to_points(depth[None], fov=fov)
152
  pts3d = pts3d.reshape(-1, 3)
 
165
  mesh_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
166
  mesh_file_path = mesh_file.name
167
  mesh.export(mesh_file_path)
168
+ return mesh_file_path, fov
demo/relighting_gen.py CHANGED
@@ -2,7 +2,7 @@ import imageio
2
  import numpy as np
3
  import spaces
4
  import torch
5
- from diffusers import UniPCMultistepScheduler, StableDiffusionControlNetPipeline
6
  from diffusers.utils import get_class_from_dynamic_module
7
 
8
  from tqdm import tqdm
@@ -19,37 +19,49 @@ NeuralTextureControlNetModel = get_class_from_dynamic_module(
19
  "NeuralTextureControlNetModel"
20
  )
21
  controlnet = NeuralTextureControlNetModel.from_pretrained(
22
- "dilightnet/DiLightNet",
23
  torch_dtype=dtype,
24
  )
 
25
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
26
- "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=dtype
 
 
 
27
  ).to(device)
28
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
29
  pipe.set_progress_bar_config(disable=True)
30
 
 
 
 
 
 
 
31
 
32
  @spaces.GPU
33
- def relighting_gen(masked_ref_img, mask, cond_path, frames, prompt, steps, seed, cfg):
34
  mask = mask[..., :1] / 255.
35
  for i in tqdm(range(frames)):
36
  source_image = masked_ref_img[..., :3] / 255.
37
- cond_diffuse = imageio.v3.imread(f'{cond_path}/hint{i:02d}_diffuse.png') / 255.
38
- if cond_diffuse.shape[-1] == 4:
39
- cond_diffuse = cond_diffuse[..., :3] * cond_diffuse[..., 3:]
40
- cond_ggx034 = imageio.v3.imread(f'{cond_path}/hint{i:02d}_ggx0.34.png') / 255.
41
- if cond_ggx034.shape[-1] == 4:
42
- cond_ggx034 = cond_ggx034[..., :3] * cond_ggx034[..., 3:]
43
- cond_ggx013 = imageio.v3.imread(f'{cond_path}/hint{i:02d}_ggx0.13.png') / 255.
44
- if cond_ggx013.shape[-1] == 4:
45
- cond_ggx013 = cond_ggx013[..., :3] * cond_ggx013[..., 3:]
46
- cond_ggx005 = imageio.v3.imread(f'{cond_path}/hint{i:02d}_ggx0.05.png') / 255.
47
- if cond_ggx005.shape[-1] == 4:
48
- cond_ggx005 = cond_ggx005[..., :3] * cond_ggx005[..., 3:]
49
- hint = np.concatenate([mask, source_image, cond_diffuse, cond_ggx005, cond_ggx013, cond_ggx034], axis=2).astype(np.float32)[None]
50
  hint = torch.from_numpy(hint).to(dtype).permute(0, 3, 1, 2).to(device)
51
  generator = torch.manual_seed(seed)
52
  image = pipe(
53
  prompt, num_inference_steps=steps, generator=generator, image=hint, num_images_per_prompt=1, guidance_scale=cfg, output_type='np',
54
  ).images[0] # [H, W, C]
 
 
 
55
  imageio.imwrite(f'{cond_path}/relighting{i:02d}.png', (image * 255).clip(0, 255).astype(np.uint8))
 
2
  import numpy as np
3
  import spaces
4
  import torch
5
+ from diffusers import UniPCMultistepScheduler, StableDiffusionControlNetPipeline, StableDiffusionInpaintPipeline, ConsistencyDecoderVAE
6
  from diffusers.utils import get_class_from_dynamic_module
7
 
8
  from tqdm import tqdm
 
19
  "NeuralTextureControlNetModel"
20
  )
21
  controlnet = NeuralTextureControlNetModel.from_pretrained(
22
+ "DiLightNet/DiLightNet",
23
  torch_dtype=dtype,
24
  )
25
+ vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
26
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
27
+ "stabilityai/stable-diffusion-2-1",
28
+ vae=vae,
29
+ controlnet=controlnet,
30
+ torch_dtype=dtype
31
  ).to(device)
32
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
33
  pipe.set_progress_bar_config(disable=True)
34
 
35
+ inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained(
36
+ "stabilityai/stable-diffusion-2-inpainting",
37
+ torch_dtype=dtype
38
+ ).to(device)
39
+ inpainting_pipe.set_progress_bar_config(disable=True)
40
+
41
 
42
  @spaces.GPU
43
+ def relighting_gen(masked_ref_img, mask, cond_path, frames, prompt, steps, seed, cfg, inpaint=False):
44
  mask = mask[..., :1] / 255.
45
  for i in tqdm(range(frames)):
46
  source_image = masked_ref_img[..., :3] / 255.
47
+
48
+ hint_types = ['diffuse', 'ggx0.05', 'ggx0.13', 'ggx0.34']
49
+ images = [mask, source_image]
50
+ for hint_type in hint_types:
51
+ image_path = f'{cond_path}/hint{i:02d}_{hint_type}.png'
52
+ image = imageio.v3.imread(image_path) / 255.
53
+ if image.shape[-1] == 4: # Check if the image has an alpha channel
54
+ image = image[..., :3] * image[..., 3:] # Premultiply RGB by Alpha
55
+ images.append(image)
56
+
57
+ hint = np.concatenate(images, axis=2).astype(np.float32)[None]
58
+
 
59
  hint = torch.from_numpy(hint).to(dtype).permute(0, 3, 1, 2).to(device)
60
  generator = torch.manual_seed(seed)
61
  image = pipe(
62
  prompt, num_inference_steps=steps, generator=generator, image=hint, num_images_per_prompt=1, guidance_scale=cfg, output_type='np',
63
  ).images[0] # [H, W, C]
64
+ if inpaint:
65
+ mask_image = (1. - mask)[None]
66
+ image = inpainting_pipe(prompt=prompt, image=image[None], mask_image=mask_image, generator=generator, output_type='np', cfg=3.0, strength=1.0).images[0]
67
  imageio.imwrite(f'{cond_path}/relighting{i:02d}.png', (image * 255).clip(0, 255).astype(np.uint8))
demo/render_hints.py CHANGED
@@ -7,7 +7,7 @@ from tqdm import tqdm
7
 
8
 
9
  def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
10
- env_map: Optional[str] = None, resolution=512, use_gpu=False):
11
  import bpy
12
  import numpy as np
13
 
@@ -73,7 +73,7 @@ def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output
73
  output_folder = tempfile.mkdtemp()
74
  for i in tqdm(range(len(pls)), desc='Rendering Hints'):
75
  if env_map:
76
- z_angle = i / len(pls) * np.pi * 2.
77
  set_env_light(env_map, rotation_euler=[0, 0, z_angle])
78
  else:
79
  pl_pos = pls[i]
@@ -85,7 +85,7 @@ def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output
85
  return output_folder
86
 
87
 
88
- def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Optional[str] = None, resolution=512):
89
  import bpy
90
  import numpy as np
91
 
@@ -126,7 +126,7 @@ def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Opt
126
  if output_folder is None:
127
  output_folder = tempfile.mkdtemp()
128
  for i in tqdm(range(len(pls)), desc='Rendering Env Backgrounds'):
129
- z_angle = i / len(pls) * np.pi * 2.
130
  set_env_light(env_map, rotation_euler=[0, 0, z_angle])
131
 
132
  with stdout_redirected():
@@ -135,16 +135,19 @@ def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Opt
135
  return output_folder
136
 
137
 
138
- def render_hint_images_wrapper(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution, return_dict):
139
- output_folder = render_hint_images(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution)
 
 
 
140
  return_dict['output_folder'] = output_folder
141
 
142
 
143
  def render_hint_images_btn_func(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
144
- env_map: Optional[str] = None, resolution=512):
145
  manager = multiprocessing.Manager()
146
  return_dict = manager.dict()
147
- p = Process(target=render_hint_images_wrapper, args=(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution, return_dict))
148
  p.start()
149
  p.join()
150
  return return_dict['output_folder']
 
7
 
8
 
9
  def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
10
+ env_map: Optional[str] = None, env_start_azi=0., resolution=512, use_gpu=False):
11
  import bpy
12
  import numpy as np
13
 
 
73
  output_folder = tempfile.mkdtemp()
74
  for i in tqdm(range(len(pls)), desc='Rendering Hints'):
75
  if env_map:
76
+ z_angle = (i / len(pls) + env_start_azi) * np.pi * 2.
77
  set_env_light(env_map, rotation_euler=[0, 0, z_angle])
78
  else:
79
  pl_pos = pls[i]
 
85
  return output_folder
86
 
87
 
88
+ def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Optional[str] = None, env_start_azi=0., resolution=512):
89
  import bpy
90
  import numpy as np
91
 
 
126
  if output_folder is None:
127
  output_folder = tempfile.mkdtemp()
128
  for i in tqdm(range(len(pls)), desc='Rendering Env Backgrounds'):
129
+ z_angle = (i / len(pls) + env_start_azi) * np.pi * 2.
130
  set_env_light(env_map, rotation_euler=[0, 0, z_angle])
131
 
132
  with stdout_redirected():
 
135
  return output_folder
136
 
137
 
138
+ def render_hint_images_wrapper(model_path, fov, pls, power, geo_smooth, output_folder, env_map, env_start_azi, resolution, return_dict):
139
+ output_folder = render_hint_images(model_path, fov, pls, power, geo_smooth, output_folder, env_map, env_start_azi, resolution)
140
+ if env_map is not None:
141
+ bg_output_folder = render_bg_images(fov, pls, output_folder, env_map, env_start_azi, resolution)
142
+ return_dict['bg_output_folder'] = bg_output_folder
143
  return_dict['output_folder'] = output_folder
144
 
145
 
146
  def render_hint_images_btn_func(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
147
+ env_map: Optional[str] = None, env_start_azi=0., resolution=512):
148
  manager = multiprocessing.Manager()
149
  return_dict = manager.dict()
150
+ p = Process(target=render_hint_images_wrapper, args=(model_path, fov, pls, power, geo_smooth, output_folder, env_map, env_start_azi, resolution, return_dict))
151
  p.start()
152
  p.join()
153
  return return_dict['output_folder']
demo/rm_bg.py CHANGED
@@ -8,7 +8,6 @@ def rm_bg(img, use_sam=False):
8
  img = img.resize((512, 512))
9
  output = rembg.remove(img)
10
  mask = np.array(output)[:, :, 3]
11
- print(mask.shape)
12
 
13
  # use sam for mask refinement
14
  if use_sam:
 
8
  img = img.resize((512, 512))
9
  output = rembg.remove(img)
10
  mask = np.array(output)[:, :, 3]
 
11
 
12
  # use sam for mask refinement
13
  if use_sam:
requirements.txt CHANGED
@@ -1,17 +1,18 @@
1
  numpy==1.26.4
2
  scipy==1.13.0
3
- diffusers==0.27.2
4
- transformers==4.39.3
5
  accelerate==0.29.3
6
- timm==0.6.12 # must use this version, required by MiDaS
7
  rembg==2.0.56
8
  trimesh==4.3.1
9
  opencv-contrib-python==4.9.0.80
10
  tqdm==4.66.2
11
  bpy==3.6.0
12
- bpy-helper==0.0.0
13
- gradio==4.27.0
14
  einops==0.7.0
15
  imageio[ffmpeg]==2.34.0
16
  torch==2.0.1
17
  torchvision==0.15.2
 
 
 
1
  numpy==1.26.4
2
  scipy==1.13.0
3
+ diffusers==0.29.1
4
+ transformers==4.41.2
5
  accelerate==0.29.3
 
6
  rembg==2.0.56
7
  trimesh==4.3.1
8
  opencv-contrib-python==4.9.0.80
9
  tqdm==4.66.2
10
  bpy==3.6.0
11
+ bpy-helper==0.0.1
12
+ gradio==4.36.1
13
  einops==0.7.0
14
  imageio[ffmpeg]==2.34.0
15
  torch==2.0.1
16
  torchvision==0.15.2
17
+ git+https://github.com/naver/croco/#subdirectory=models/curope
18
+ git+https://github.com/iamNCJ/dust3r.git