finish demo
- app.py +94 -40
- demo/mesh_recon.py +45 -12
- demo/relighting_gen.py +29 -17
- demo/render_hints.py +11 -8
- demo/rm_bg.py +0 -1
- requirements.txt +6 -5
app.py
CHANGED
@@ -1,6 +1,8 @@
 import gradio as gr
+import os
 import imageio
 import numpy as np
+from einops import rearrange

 from demo.img_gen import img_gen
 from demo.mesh_recon import mesh_reconstruction

@@ -11,7 +13,7 @@ from demo.rm_bg import rm_bg

 with gr.Blocks(title="DiLightNet Demo") as demo:
     gr.Markdown("""# DiLightNet: Fine-grained Lighting Control for Diffusion-based Image Generation
-    ## A demo for generating images under point
+    ## A demo for generating images under point/environmental lighting using DiLightNet. For full usage (video generation & arbitrary lighting conditions), please refer to our [GitHub repository](https://github.com/iamNCJ/DiLightNet)""")

     with gr.Row():
         # 1. Reference Image Input / Generation

@@ -29,6 +31,11 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
         with gr.Row():
             generate_btn = gr.Button(value="Generate")
             generate_btn.click(fn=img_gen, inputs=[prompt, seed, steps, cfg, down_from_768], outputs=[input_image])
+            gr.Examples(
+                examples=[os.path.join("examples/provisional_img", i) for i in os.listdir("examples/provisional_img")],
+                inputs=[input_image],
+                examples_per_page=20,
+            )

         # 2. Background Removal
         with gr.Column(variant="panel"):

@@ -49,54 +56,98 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
             with gr.Accordion("Options", open=False):
                 with gr.Group():
                     remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
-                    fov = gr.Number(value=55., label="FOV", interactive=
+                    fov = gr.Number(value=55., label="FOV", interactive=False)
                     mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
             depth_estimation_btn = gr.Button(value="Estimate Depth")
+            def mesh_reconstruction_wrapper(image, mask, remove_edges, mask_threshold,
+                                            progress=gr.Progress(track_tqdm=True)):
+                return mesh_reconstruction(image, mask, remove_edges, None, mask_threshold)
             depth_estimation_btn.click(
-                fn=
-                inputs=[
-                outputs=[mesh]
+                fn=mesh_reconstruction_wrapper,
+                inputs=[input_image, mask, remove_edges, mask_threshold],
+                outputs=[mesh, fov],
             )

     gr.Markdown("## Step 4. Render Hints")
     with gr.Row():
         with gr.Column():
-            hint_image = gr.Image(label="Hint Image")
+            hint_image = gr.Image(label="Hint Image", height=512, width=512)
         with gr.Column():
-            pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
-            pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
-            pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
-            power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
-            render_btn = gr.Button(value="Render Hints")
             res_folder_path = gr.Textbox("", visible=False)
+            is_env_lighting = gr.Checkbox(label="Use Environmental Lighting", value=True, interactive=False, visible=False)
+            with gr.Tab("Environmental Lighting"):
+                env_map_preview = gr.Image(label="Environment Map Preview", height=256, width=512, interactive=False, show_download_button=False)
+                env_map_path = gr.Text(interactive=False, visible=False, value="examples/env_map/grace.exr")
+                env_rotation = gr.Slider(value=0., label="Environment Rotation", minimum=0., maximum=360., step=0.5)
+                env_examples = gr.Examples(
+                    examples=[[os.path.join("examples/env_map_preview", i), os.path.join("examples/env_map", i).replace("png", "exr")] for i in os.listdir("examples/env_map_preview")],
+                    inputs=[env_map_preview, env_map_path],
+                    examples_per_page=20,
+                )
+                render_btn_env = gr.Button(value="Render Hints")

+            def render_wrapper_env(mesh, fov, env_map_path, env_rotation, progress=gr.Progress(track_tqdm=True)):
+                env_map_path = os.path.abspath(env_map_path)
+                res_path = render_hint_images_btn_func(mesh, float(fov), [(0, 0, 0)], env_map=env_map_path, env_start_azi=env_rotation / 360.)
+                hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.05.png", "_ggx0.13.png", "_ggx0.34.png"]]
+                hints = []
+                for hint_file in hint_files:
+                    hint = imageio.v3.imread(hint_file)
+                    hints.append(hint)
+                hints = rearrange(np.stack(hints), '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=2, n2=2)
+                return hints, res_path, True
+            render_btn_env.click(
+                fn=render_wrapper_env,
+                inputs=[mesh, fov, env_map_path, env_rotation],
+                outputs=[hint_image, res_folder_path, is_env_lighting]
+            )
+
+            with gr.Tab("Point Lighting"):
+                pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
+                pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
+                pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
+                power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
+                render_btn_pl = gr.Button(value="Render Hints")
+
+            def render_wrapper_pl(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power,
+                                  progress=gr.Progress(track_tqdm=True)):
+                res_path = render_hint_images_btn_func(mesh, float(fov), [(pl_pos_x, pl_pos_y, pl_pos_z)], power)
+                hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.05.png", "_ggx0.13.png", "_ggx0.34.png"]]
+                hints = []
+                for hint_file in hint_files:
+                    hint = imageio.v3.imread(hint_file)
+                    hints.append(hint)
+                hints = rearrange(np.stack(hints), '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=2, n2=2)
+                return hints, res_path, False

+            render_btn_pl.click(
+                fn=render_wrapper_pl,
+                inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
+                outputs=[hint_image, res_folder_path, is_env_lighting]
+            )
+
+    gr.Markdown("## Step 5. Control Lighting!")
     with gr.Row():
         res_image = gr.Image(label="Result Image")
         with gr.Column():
             with gr.Group():
+                with gr.Row():
+                    relighting_prompt = gr.Textbox(value="", label="Appearance Text Prompt", lines=3,
+                                                   placeholder="Input prompt here",
+                                                   interactive=True)
+                with gr.Row():
+                    # several example prompts
+                    metallic_prompt_btn = gr.Button(value="Metallic", size="sm")
+                    specular_prompt_btn = gr.Button(value="Specular", size="sm")
+                    very_specular_prompt_btn = gr.Button(value="Very Specular", size="sm")
+                    clear_prompt_btn = gr.Button(value="Clear", size="sm")
+                    metallic_prompt_btn.click(fn=lambda x: x + " metallic", inputs=[relighting_prompt], outputs=[relighting_prompt])
+                    specular_prompt_btn.click(fn=lambda x: x + " specular", inputs=[relighting_prompt], outputs=[relighting_prompt])
+                    very_specular_prompt_btn.click(fn=lambda x: x + " very specular", inputs=[relighting_prompt], outputs=[relighting_prompt])
+                    clear_prompt_btn.click(fn=lambda x: "", inputs=[relighting_prompt], outputs=[relighting_prompt])
+                with gr.Row():
+                    reuse_btn = gr.Button(value="Reuse Provisional Image Generation Prompt")
+                    reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
             with gr.Accordion("Options", open=False):
                 with gr.Row():
                     relighting_seed = gr.Number(value=3407, label="Seed", interactive=True)

@@ -106,7 +157,7 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
             relighting_generate_btn = gr.Button(value="Generate")

     def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
-                             relighting_steps, relighting_cfg,
+                             relighting_steps, relighting_cfg, do_env_inpainting,
                              progress=gr.Progress(track_tqdm=True)):
         relighting_gen(
             masked_ref_img=masked_image,

@@ -118,16 +169,19 @@ with gr.Blocks(title="DiLightNet Demo") as demo:
             seed=int(relighting_seed),
             cfg=relighting_cfg
         )
+        relit_img = imageio.v3.imread(res_folder_path + '/relighting00.png')
+        if do_env_inpainting:
+            bg = imageio.v3.imread(res_folder_path + f'/bg00.png') / 255.
+            relit_img = relit_img / 255.
+            mask_for_bg = imageio.v3.imread(res_folder_path + '/hint00_diffuse.png')[..., -1:] / 255.
+            relit_img = relit_img * mask_for_bg + bg * (1. - mask_for_bg)
+            relit_img = (relit_img * 255).clip(0, 255).astype(np.uint8)
+        return relit_img

     relighting_generate_btn.click(fn=gen_relighting_image,
                                   inputs=[masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
-                                          relighting_steps, relighting_cfg],
+                                          relighting_steps, relighting_cfg, is_env_lighting],
                                   outputs=[res_image])
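The hint preview returned by render_wrapper_env and render_wrapper_pl is simply the four rendered shading hints (diffuse plus three GGX roughness levels) tiled into a 2x2 grid with einops. A minimal sketch of that tiling, with dummy arrays standing in for the rendered hint files:

```python
# Sketch of the 2x2 hint-preview tiling used by render_wrapper_env / render_wrapper_pl.
# Dummy uint8 arrays stand in for hint00_diffuse.png, hint00_ggx0.05.png, etc.
import numpy as np
from einops import rearrange

hints = [np.full((512, 512, 3), v, dtype=np.uint8) for v in (32, 96, 160, 224)]
grid = rearrange(np.stack(hints), '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=2, n2=2)
# Top row holds hints[0] and hints[1]; bottom row holds hints[2] and hints[3].
assert grid.shape == (1024, 1024, 3)
```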
demo/mesh_recon.py
CHANGED
@@ -1,18 +1,36 @@
 import tempfile
+from typing import Optional

 import numpy as np
+import cv2
 import torch
 import trimesh
 import spaces
+from dust3r.model import AsymmetricCroCo3DStereo
+from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
+from dust3r.inference import inference
+from dust3r.image_pairs import make_pairs
+
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+model = AsymmetricCroCo3DStereo.from_pretrained("naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt").to(device).eval()


+import torchvision.transforms as tvf
+import PIL.Image
+import numpy as np
+ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+

+def load_single_image(img_array):
+    imgs = []
+    for i in range(2):
+        img = PIL.Image.fromarray(img_array)
+        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
+            [img.size[::-1]]), idx=i, instance=str(len(imgs))))
+
+    return imgs


 def get_intrinsics(H, W, fov=55.):

@@ -106,14 +124,29 @@ def mesh_reconstruction(
     masked_image: np.ndarray,
     mask: np.ndarray,
     remove_edges: bool = True,
-    fov: float =
+    fov: Optional[float] = None,
     mask_threshold: float = 25.,
 ):
+    masked_image = cv2.resize(masked_image, (512, 512))
+    mask = cv2.resize(mask, (512, 512))
+    images = load_single_image(masked_image)
+    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
+    output = inference(pairs, model, device, batch_size=1)
+    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
+    if fov is not None:
+        # do not optimize focal length if fov is provided
+        focal = scene.imshapes[0][1] / (2 * np.tan(0.5 * fov * np.pi / 180.))
+        scene.preset_focal([focal, focal])
+    _loss = scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)
+    if fov is None:
+        # get the focal length from the optimized parameters
+        focals = scene.get_focals()
+        fov = 2 * (np.arctan((scene.imshapes[0][1] / (focals[0] + focals[1])).detach().cpu().numpy()) * 180 / np.pi)[0]
+    depth = scene.get_depthmaps()[0].detach().cpu().numpy()
+    if device.type == 'cuda':
+        torch.cuda.empty_cache()
+
     rgb = masked_image[..., :3].transpose(2, 0, 1) / 255.
-    sample = torch.from_numpy(rgb).to(device).unsqueeze(0).float()
-    with torch.no_grad():
-        depth = model.infer(sample)
-        depth = depth.squeeze().cpu().numpy()

     pts3d = depth_to_points(depth[None], fov=fov)
     pts3d = pts3d.reshape(-1, 3)

@@ -132,4 +165,4 @@ def mesh_reconstruction(
     mesh_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
     mesh_file_path = mesh_file.name
     mesh.export(mesh_file_path)
-    return mesh_file_path
+    return mesh_file_path, fov
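The preset_focal / FOV recovery logic above uses the standard pinhole relation between horizontal field of view and focal length in pixels (the committed code averages DUSt3R's two per-view focals, so W / (f0 + f1) equals W / (2 * mean focal)). A minimal sketch of that conversion:

```python
# Pinhole FOV <-> focal-length-in-pixels conversion mirroring the formulas in mesh_reconstruction.
import numpy as np

def fov_to_focal(fov_deg: float, width: int) -> float:
    # focal = W / (2 * tan(fov / 2)), as used when presetting the DUSt3R focal
    return width / (2. * np.tan(0.5 * np.deg2rad(fov_deg)))

def focal_to_fov(focal_px: float, width: int) -> float:
    # fov = 2 * arctan(W / (2 * focal)), as used when recovering FOV after optimization
    return float(np.rad2deg(2. * np.arctan(width / (2. * focal_px))))

f = fov_to_focal(55., 512)                     # ~491.7 px for the demo's default 55-degree FOV
assert abs(focal_to_fov(f, 512) - 55.) < 1e-6  # the two conversions round-trip
```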
demo/relighting_gen.py
CHANGED
@@ -2,7 +2,7 @@ import imageio
 import numpy as np
 import spaces
 import torch
-from diffusers import UniPCMultistepScheduler, StableDiffusionControlNetPipeline
+from diffusers import UniPCMultistepScheduler, StableDiffusionControlNetPipeline, StableDiffusionInpaintPipeline, ConsistencyDecoderVAE
 from diffusers.utils import get_class_from_dynamic_module

 from tqdm import tqdm

@@ -19,37 +19,49 @@ NeuralTextureControlNetModel = get_class_from_dynamic_module(
     "NeuralTextureControlNetModel"
 )
 controlnet = NeuralTextureControlNetModel.from_pretrained(
-    "
+    "DiLightNet/DiLightNet",
     torch_dtype=dtype,
 )
+vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
 pipe = StableDiffusionControlNetPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-2-1",
+    "stabilityai/stable-diffusion-2-1",
+    vae=vae,
+    controlnet=controlnet,
+    torch_dtype=dtype
 ).to(device)
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 pipe.set_progress_bar_config(disable=True)

+inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-inpainting",
+    torch_dtype=dtype
+).to(device)
+inpainting_pipe.set_progress_bar_config(disable=True)
+

 @spaces.GPU
-def relighting_gen(masked_ref_img, mask, cond_path, frames, prompt, steps, seed, cfg):
+def relighting_gen(masked_ref_img, mask, cond_path, frames, prompt, steps, seed, cfg, inpaint=False):
     mask = mask[..., :1] / 255.
     for i in tqdm(range(frames)):
         source_image = masked_ref_img[..., :3] / 255.
-        hint = np.concatenate([mask, source_image, cond_diffuse, cond_ggx005, cond_ggx013, cond_ggx034], axis=2).astype(np.float32)[None]
+
+        hint_types = ['diffuse', 'ggx0.05', 'ggx0.13', 'ggx0.34']
+        images = [mask, source_image]
+        for hint_type in hint_types:
+            image_path = f'{cond_path}/hint{i:02d}_{hint_type}.png'
+            image = imageio.v3.imread(image_path) / 255.
+            if image.shape[-1] == 4:  # Check if the image has an alpha channel
+                image = image[..., :3] * image[..., 3:]  # Premultiply RGB by Alpha
+            images.append(image)
+
+        hint = np.concatenate(images, axis=2).astype(np.float32)[None]
+
         hint = torch.from_numpy(hint).to(dtype).permute(0, 3, 1, 2).to(device)
         generator = torch.manual_seed(seed)
         image = pipe(
             prompt, num_inference_steps=steps, generator=generator, image=hint, num_images_per_prompt=1, guidance_scale=cfg, output_type='np',
         ).images[0]  # [H, W, C]
+        if inpaint:
+            mask_image = (1. - mask)[None]
+            image = inpainting_pipe(prompt=prompt, image=image[None], mask_image=mask_image, generator=generator, output_type='np', cfg=3.0, strength=1.0).images[0]
         imageio.imwrite(f'{cond_path}/relighting{i:02d}.png', (image * 255).clip(0, 255).astype(np.uint8))
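For orientation, the conditioning tensor that relighting_gen feeds to the NeuralTextureControlNetModel stacks one mask channel, the three reference-image channels, and the four shading hints (three channels each) into a 16-channel NCHW tensor. A minimal sketch of that layout, with dummy arrays in place of the mask, reference image, and hint renders:

```python
# Sketch of the 16-channel conditioning layout: 1 (mask) + 3 (reference) + 4 hints * 3 = 16.
import numpy as np
import torch

h = w = 512
mask = np.ones((h, w, 1), dtype=np.float32)                        # foreground mask in [0, 1]
source_image = np.zeros((h, w, 3), dtype=np.float32)               # masked reference image in [0, 1]
hints = [np.zeros((h, w, 3), dtype=np.float32) for _ in range(4)]  # diffuse, ggx0.05, ggx0.13, ggx0.34

hint = np.concatenate([mask, source_image, *hints], axis=2)[None]  # (1, H, W, 16)
hint = torch.from_numpy(hint).permute(0, 3, 1, 2)                  # (1, 16, H, W) for the ControlNet
assert hint.shape == (1, 16, h, w)
```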
demo/render_hints.py
CHANGED
@@ -7,7 +7,7 @@ from tqdm import tqdm


 def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
-                       env_map: Optional[str] = None, resolution=512, use_gpu=False):
+                       env_map: Optional[str] = None, env_start_azi=0., resolution=512, use_gpu=False):
     import bpy
     import numpy as np

@@ -73,7 +73,7 @@ def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output
         output_folder = tempfile.mkdtemp()
     for i in tqdm(range(len(pls)), desc='Rendering Hints'):
         if env_map:
-            z_angle = i / len(pls) * np.pi * 2.
+            z_angle = (i / len(pls) + env_start_azi) * np.pi * 2.
             set_env_light(env_map, rotation_euler=[0, 0, z_angle])
         else:
             pl_pos = pls[i]

@@ -85,7 +85,7 @@ def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output
     return output_folder


-def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Optional[str] = None, resolution=512):
+def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Optional[str] = None, env_start_azi=0., resolution=512):
     import bpy
     import numpy as np

@@ -126,7 +126,7 @@ def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Opt
     if output_folder is None:
         output_folder = tempfile.mkdtemp()
     for i in tqdm(range(len(pls)), desc='Rendering Env Backgrounds'):
-        z_angle = i / len(pls) * np.pi * 2.
+        z_angle = (i / len(pls) + env_start_azi) * np.pi * 2.
         set_env_light(env_map, rotation_euler=[0, 0, z_angle])

     with stdout_redirected():

@@ -135,16 +135,19 @@ def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Opt
     return output_folder


-def render_hint_images_wrapper(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution, return_dict):
-    output_folder = render_hint_images(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution)
+def render_hint_images_wrapper(model_path, fov, pls, power, geo_smooth, output_folder, env_map, env_start_azi, resolution, return_dict):
+    output_folder = render_hint_images(model_path, fov, pls, power, geo_smooth, output_folder, env_map, env_start_azi, resolution)
+    if env_map is not None:
+        bg_output_folder = render_bg_images(fov, pls, output_folder, env_map, env_start_azi, resolution)
+        return_dict['bg_output_folder'] = bg_output_folder
     return_dict['output_folder'] = output_folder


 def render_hint_images_btn_func(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
-                                env_map: Optional[str] = None, resolution=512):
+                                env_map: Optional[str] = None, env_start_azi=0., resolution=512):
     manager = multiprocessing.Manager()
     return_dict = manager.dict()
-    p = Process(target=render_hint_images_wrapper, args=(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution, return_dict))
+    p = Process(target=render_hint_images_wrapper, args=(model_path, fov, pls, power, geo_smooth, output_folder, env_map, env_start_azi, resolution, return_dict))
     p.start()
     p.join()
     return return_dict['output_folder']
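The new env_start_azi parameter is the UI rotation slider normalized to a fraction of a full turn (app.py passes env_rotation / 360.), which the renderer adds to the per-frame azimuth before converting to radians. A small sketch of that arithmetic:

```python
# Sketch of the environment-map azimuth math used in render_hint_images / render_bg_images.
import numpy as np

def env_z_angle(frame_idx: int, num_frames: int, env_start_azi: float) -> float:
    # Frames are spread over a full turn, offset by the normalized start azimuth.
    return (frame_idx / num_frames + env_start_azi) * np.pi * 2.

env_rotation_deg = 90.                   # value of the "Environment Rotation" slider
env_start_azi = env_rotation_deg / 360.  # normalized fraction, as passed from app.py
assert np.isclose(env_z_angle(0, 1, env_start_azi), np.pi / 2.)
```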
demo/rm_bg.py
CHANGED
@@ -8,7 +8,6 @@ def rm_bg(img, use_sam=False):
     img = img.resize((512, 512))
     output = rembg.remove(img)
     mask = np.array(output)[:, :, 3]
-    print(mask.shape)

     # use sam for mask refinement
     if use_sam:
requirements.txt
CHANGED
@@ -1,17 +1,18 @@
 numpy==1.26.4
 scipy==1.13.0
-diffusers==0.
-transformers==4.
+diffusers==0.29.1
+transformers==4.41.2
 accelerate==0.29.3
-timm==0.6.12 # must use this version, required by MiDaS
 rembg==2.0.56
 trimesh==4.3.1
 opencv-contrib-python==4.9.0.80
 tqdm==4.66.2
 bpy==3.6.0
-bpy-helper==0.0.
-gradio==4.
+bpy-helper==0.0.1
+gradio==4.36.1
 einops==0.7.0
 imageio[ffmpeg]==2.34.0
 torch==2.0.1
 torchvision==0.15.2
+git+https://github.com/naver/croco/#subdirectory=models/curope
+git+https://github.com/iamNCJ/dust3r.git