NCJ committed
Commit 084ab29
1 parent: 34376e6
Files changed (11)
  1. .gitignore +154 -0
  2. LICENSE +21 -0
  3. README.md +5 -3
  4. app.py +131 -0
  5. demo/__init__.py +0 -0
  6. demo/img_gen.py +26 -0
  7. demo/mesh_recon.py +129 -0
  8. demo/relighting_gen.py +53 -0
  9. demo/render_hints.py +150 -0
  10. demo/rm_bg.py +23 -0
  11. requirements.txt +15 -0
.gitignore ADDED
@@ -0,0 +1,154 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ output/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 NCJ
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,15 @@
  ---
  title: DiLightNet
- emoji: 🏆
+ emoji: 💡
  colorFrom: indigo
  colorTo: red
  sdk: gradio
  sdk_version: 4.27.0
  app_file: app.py
- pinned: false
+ pinned: true
  license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DiLightNet: Fine-grained Lighting Control for Diffusion-based Image Generation
+
+ https://arxiv.org/abs/2402.11929
app.py ADDED
@@ -0,0 +1,131 @@
+ import gradio as gr
+ import imageio
+ import numpy as np
+
+ from demo.img_gen import img_gen
+ from demo.mesh_recon import mesh_reconstruction
+ from demo.relighting_gen import relighting_gen
+ from demo.render_hints import render_hint_images_btn_func
+ from demo.rm_bg import rm_bg
+
+
+ with gr.Blocks(title="DiLightNet Demo") as demo:
+     gr.Markdown("# DiLightNet: Fine-grained Lighting Control for Image Diffusion")
+
+     with gr.Row():
+         # 1. Reference Image Input / Generation
+         with gr.Column(variant="panel"):
+             gr.Markdown("## Step 1. Input or Generate Reference Image")
+             input_image = gr.Image(height=512, width=512, label="Input Image", interactive=True)
+             with gr.Accordion("Generate Image", open=False):
+                 with gr.Group():
+                     prompt = gr.Textbox(value="", label="Prompt", lines=3, placeholder="Input prompt here")
+                     with gr.Row():
+                         seed = gr.Number(value=42, label="Seed", interactive=True)
+                         steps = gr.Number(value=20, label="Steps", interactive=True)
+                         cfg = gr.Number(value=7.5, label="CFG", interactive=True)
+                         down_from_768 = gr.Checkbox(label="Downsample from 768", value=True)
+                 with gr.Row():
+                     generate_btn = gr.Button(value="Generate")
+                     generate_btn.click(fn=img_gen, inputs=[prompt, seed, steps, cfg, down_from_768], outputs=[input_image])
+
+         # 2. Background Removal
+         with gr.Column(variant="panel"):
+             gr.Markdown("## Step 2. Remove Background")
+             with gr.Tab("Masked Image"):
+                 masked_image = gr.Image(height=512, width=512, label="Masked Image", interactive=True)
+             with gr.Tab("Mask"):
+                 mask = gr.Image(height=512, width=512, label="Mask", interactive=False)
+             use_sam = gr.Checkbox(label="Use SAM for Refinement", value=False)
+             rm_bg_btn = gr.Button(value="Remove Background")
+             rm_bg_btn.click(fn=rm_bg, inputs=[input_image, use_sam], outputs=[masked_image, mask])
+
+         # 3. Depth Estimation & Mesh Reconstruction
+         with gr.Column(variant="panel"):
+             gr.Markdown("## Step 3. Depth Estimation & Mesh Reconstruction")
+             mesh = gr.Model3D(label="Mesh Reconstruction", clear_color=(1.0, 1.0, 1.0, 1.0), interactive=True)
+             with gr.Column():
+                 with gr.Accordion("Options", open=False):
+                     with gr.Group():
+                         remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
+                         fov = gr.Number(value=55., label="FOV", interactive=True)
+                         mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
+                 depth_estimation_btn = gr.Button(value="Estimate Depth")
+                 depth_estimation_btn.click(
+                     fn=mesh_reconstruction,
+                     inputs=[masked_image, mask, remove_edges, fov, mask_threshold],
+                     outputs=[mesh]
+                 )
+
+     gr.Markdown("## Step 4. Render Hints")
+     with gr.Row():
+         with gr.Column():
+             hint_image = gr.Image(label="Hint Image")
+         with gr.Column():
+             pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
+             pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
+             pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
+             power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
+             render_btn = gr.Button(value="Render Hints")
+             res_folder_path = gr.Textbox("", visible=False)
+
+     def render_wrapper(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power,
+                        progress=gr.Progress(track_tqdm=True)):
+         res_path = render_hint_images_btn_func(mesh, fov, [(pl_pos_x, pl_pos_y, pl_pos_z)], power)
+         hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.34.png"]]
+         hints = []
+         for hint_file in hint_files:
+             hint = imageio.v3.imread(hint_file)
+             hints.append(hint)
+         hints = np.concatenate(hints, axis=1)
+         return hints, res_path
+
+     render_btn.click(
+         fn=render_wrapper,
+         inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
+         outputs=[hint_image, res_folder_path]
+     )
+
+     gr.Markdown("## Step 5. Relighting!")
+     with gr.Row():
+         res_image = gr.Image(label="Result Image")
+         with gr.Column():
+             with gr.Group():
+                 relighting_prompt = gr.Textbox(value="", label="Relighting Text Prompt", lines=3,
+                                                placeholder="Input prompt here",
+                                                interactive=True)
+                 reuse_btn = gr.Button(value="Reuse Image Generation Prompt")
+                 reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
+             with gr.Accordion("Options", open=False):
+                 with gr.Row():
+                     relighting_seed = gr.Number(value=3407, label="Seed", interactive=True)
+                     relighting_steps = gr.Number(value=20, label="Steps", interactive=True)
+                     relighting_cfg = gr.Number(value=3.0, label="CFG", interactive=True)
+             with gr.Row():
+                 relighting_generate_btn = gr.Button(value="Generate")
+
+     def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
+                              relighting_steps, relighting_cfg,
+                              progress=gr.Progress(track_tqdm=True)):
+         relighting_gen(
+             masked_ref_img=masked_image,
+             mask=mask,
+             cond_path=res_folder_path,
+             frames=1,
+             prompt=relighting_prompt,
+             steps=int(relighting_steps),
+             seed=int(relighting_seed),
+             cfg=relighting_cfg
+         )
+         res = imageio.v3.imread(res_folder_path + '/relighting00.png')
+         return res
+
+
+     relighting_generate_btn.click(fn=gen_relighting_image,
+                                   inputs=[masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
+                                           relighting_steps, relighting_cfg],
+                                   outputs=[res_image])
+
+
+ if __name__ == '__main__':
+     demo.queue().launch(server_name="0.0.0.0", share=True)
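The five `Step N` panels above map one-to-one onto the `demo/` helpers. Below is a rough headless sketch of the same pipeline, not part of the commit: the prompts and light position are invented for illustration, and the uint8/mask conversions mimic what `gr.Image` does between callbacks in the UI.

```python
import imageio
import numpy as np

from demo.img_gen import img_gen
from demo.mesh_recon import mesh_reconstruction
from demo.relighting_gen import relighting_gen
from demo.render_hints import render_hint_images_btn_func
from demo.rm_bg import rm_bg

# Step 1: generate a reference image (float [0, 1] -> uint8, as gr.Image
# would convert it before the next callback sees it).
image = img_gen("a toy robot on a desk", seed=42, steps=20, cfg=7.5)
image = (image * 255).clip(0, 255).astype(np.uint8)

# Step 2: background removal yields the RGBA cutout and its alpha mask.
masked_image, mask = rm_bg(image, use_sam=False)
masked_image = np.array(masked_image)
mask = np.repeat(mask[..., None], 3, axis=2)  # callbacks receive an HxWx3 mask

# Step 3: monocular depth -> textured mesh, saved as a temporary .glb.
mesh_path = mesh_reconstruction(masked_image, mask, remove_edges=False, fov=55.)

# Step 4: render radiance hints for one point-light position.
hints_dir = render_hint_images_btn_func(mesh_path, 55., [(3., 1., 3.)], power=1000.)

# Step 5: run the ControlNet-conditioned relighting pass and read the result.
relighting_gen(masked_image, mask, hints_dir, frames=1,
               prompt="a toy robot on a desk", steps=20, seed=3407, cfg=3.0)
result = imageio.v3.imread(hints_dir + '/relighting00.png')
```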
demo/__init__.py ADDED
File without changes
demo/img_gen.py ADDED
@@ -0,0 +1,26 @@
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+
+ model_id = "stabilityai/stable-diffusion-2-1"
+
+ device = torch.device('cpu')
+ dtype = torch.float32
+ if torch.cuda.is_available():
+     device = torch.device('cuda')
+     dtype = torch.float16
+
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to(device)
+
+
+ def img_gen(prompt, seed, steps, cfg, down_from_768=False, progress=gr.Progress(track_tqdm=True)):
+     generator = torch.Generator(device=device).manual_seed(int(seed))
+     hw = 512 if not down_from_768 else 768
+     image = pipe(prompt, generator=generator, num_inference_steps=int(steps), guidance_scale=cfg, output_type='np', height=hw, width=hw).images[0]
+     if down_from_768:
+         image = F.interpolate(torch.from_numpy(image)[None].permute(0, 3, 1, 2), size=(512, 512), mode='bilinear', align_corners=False, antialias=True).permute(0, 2, 3, 1)[0].cpu().numpy()
+     return image
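For reference, a minimal direct call (the prompt is made up). With `down_from_768=True` the image is sampled at 768×768, the native resolution of the `stable-diffusion-2-1` checkpoint, then antialias-downsampled to 512×512:

```python
from demo.img_gen import img_gen

# Sketch: sample at SD 2.1's native 768x768, then antialias-downsample to
# 512x512 (the "Downsample from 768" path in the UI).
image = img_gen("studio photo of a ceramic teapot", seed=42, steps=20,
                cfg=7.5, down_from_768=True)
print(image.shape, image.dtype)  # (512, 512, 3) float32, values in [0, 1]
```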
demo/mesh_recon.py ADDED
@@ -0,0 +1,129 @@
+ import tempfile
+
+ import numpy as np
+ import torch
+ import trimesh
+
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # use torch hub
+ model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).to(device).eval()
+
+
+ def get_intrinsics(H, W, fov=55.):
+     """
+     Intrinsics for a pinhole camera model.
+     Assume central principal point.
+     """
+     f = 0.5 * W / np.tan(0.5 * fov * np.pi / 180.0)
+     cx = 0.5 * W
+     cy = 0.5 * H
+     return np.array([[f, 0, cx],
+                      [0, f, cy],
+                      [0, 0, 1]])
+
+
+ def depth_to_points(depth, R=None, t=None, fov=55.):
+     K = get_intrinsics(depth.shape[1], depth.shape[2], fov=fov)
+     Kinv = np.linalg.inv(K)
+     if R is None:
+         R = np.eye(3)
+     if t is None:
+         t = np.zeros(3)
+
+     # M converts from your coordinate to PyTorch3D's coordinate system
+     M = np.eye(3)
+     M[0, 0] = -1.0
+     M[1, 1] = -1.0
+
+     height, width = depth.shape[1:3]
+
+     x = np.arange(width)
+     y = np.arange(height)
+     coord = np.stack(np.meshgrid(x, y), -1)
+     coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1)  # z=1
+     coord = coord.astype(np.float32)
+     coord = coord[None]  # bs, h, w, 3
+
+     D = depth[:, :, :, None, None]
+     pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None]
+     # pts3D_1 live in your coordinate system. Convert them to Py3D's
+     pts3D_1 = M[None, None, None, ...] @ pts3D_1
+     # from reference to target viewpoint
+     pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None]
+     return pts3D_2[:, :, :, :3, 0][0]
+
+
+ def create_triangles(h, w, mask=None):
+     """
+     Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68
+     Creates mesh triangle indices from a given pixel grid size.
+     This function is not and need not be differentiable as triangle indices are
+     fixed.
+     Args:
+       h: (int) denoting the height of the image.
+       w: (int) denoting the width of the image.
+     Returns:
+       triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3)
+     """
+     x, y = np.meshgrid(range(w - 1), range(h - 1))
+     tl = y * w + x
+     tr = y * w + x + 1
+     bl = (y + 1) * w + x
+     br = (y + 1) * w + x + 1
+     triangles = np.array([tl, bl, tr, br, tr, bl])
+     triangles = np.transpose(triangles, (1, 2, 0)).reshape(
+         ((w - 1) * (h - 1) * 2, 3))
+     if mask is not None:
+         mask = mask.reshape(-1)
+         triangles = triangles[mask[triangles].all(1)]
+     return triangles
+
+
+ def depth_edges_mask(depth):
+     """Returns a mask of edges in the depth map.
+     Args:
+       depth: 2D numpy array of shape (H, W) with dtype float32.
+     Returns:
+       mask: 2D numpy array of shape (H, W) with dtype bool.
+     """
+     # Compute the x and y gradients of the depth map.
+     depth_dx, depth_dy = np.gradient(depth)
+     # Compute the gradient magnitude.
+     depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
+     # Compute the edge mask.
+     mask = depth_grad > 0.05
+     return mask
+
+
+ def mesh_reconstruction(
+     masked_image: np.ndarray,
+     mask: np.ndarray,
+     remove_edges: bool = True,
+     fov: float = 55.,
+     mask_threshold: float = 25.,
+ ):
+     rgb = masked_image[..., :3].transpose(2, 0, 1) / 255.
+     sample = torch.from_numpy(rgb).to(device).unsqueeze(0).float()
+     with torch.no_grad():
+         depth = model.infer(sample)
+     depth = depth.squeeze().cpu().numpy()
+
+     pts3d = depth_to_points(depth[None], fov=fov)
+     pts3d = pts3d.reshape(-1, 3)
+     pts3d = pts3d.reshape(-1, 3)
+     verts = pts3d.reshape(-1, 3)
+     rgb = rgb.transpose(1, 2, 0)
+     mask = mask[..., 0] > mask_threshold
+     edge_mask = depth_edges_mask(depth)
+     if remove_edges:
+         mask = np.logical_and(mask, ~edge_mask)
+     triangles = create_triangles(rgb.shape[0], rgb.shape[1], mask=mask)
+     colors = rgb.reshape(-1, 3)
+     mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)
+
+     # Save as glb tmp file (obj will look inverted in ui)
+     mesh_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
+     mesh_file_path = mesh_file.name
+     mesh.export(mesh_file_path)
+     return mesh_file_path
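`get_intrinsics` is the standard pinhole relation f = (W / 2) / tan(fov / 2), with the principal point at the image center. A quick numeric check for the demo defaults (512 px wide, 55° FOV):

```python
import numpy as np

# f = (W / 2) / tan(fov / 2): for W = 512 and fov = 55 degrees,
# the focal length comes out to roughly 491.8 px.
W, fov = 512, 55.
f = 0.5 * W / np.tan(0.5 * fov * np.pi / 180.)
print(round(f, 1))  # 491.8
```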
demo/relighting_gen.py ADDED
@@ -0,0 +1,53 @@
+ import imageio
+ import numpy as np
+ import torch
+ from diffusers import UniPCMultistepScheduler, StableDiffusionControlNetPipeline
+ from diffusers.utils import get_class_from_dynamic_module
+
+ from tqdm import tqdm
+
+ device = torch.device('cpu')
+ dtype = torch.float32
+ if torch.cuda.is_available():
+     device = torch.device('cuda')
+     dtype = torch.float16
+
+ NeuralTextureControlNetModel = get_class_from_dynamic_module(
+     "dilightnet/model_helpers",
+     "neuraltexture_controlnet.py",
+     "NeuralTextureControlNetModel"
+ )
+ controlnet = NeuralTextureControlNetModel.from_pretrained(
+     "dilightnet/DiLightNet",
+     torch_dtype=dtype,
+ )
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=dtype
+ ).to(device)
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe.set_progress_bar_config(disable=True)
+
+
+ def relighting_gen(masked_ref_img, mask, cond_path, frames, prompt, steps, seed, cfg):
+     mask = mask[..., :1] / 255.
+     for i in tqdm(range(frames)):
+         source_image = masked_ref_img[..., :3] / 255.
+         cond_diffuse = imageio.v3.imread(f'{cond_path}/hint{i:02d}_diffuse.png') / 255.
+         if cond_diffuse.shape[-1] == 4:
+             cond_diffuse = cond_diffuse[..., :3] * cond_diffuse[..., 3:]
+         cond_ggx034 = imageio.v3.imread(f'{cond_path}/hint{i:02d}_ggx0.34.png') / 255.
+         if cond_ggx034.shape[-1] == 4:
+             cond_ggx034 = cond_ggx034[..., :3] * cond_ggx034[..., 3:]
+         cond_ggx013 = imageio.v3.imread(f'{cond_path}/hint{i:02d}_ggx0.13.png') / 255.
+         if cond_ggx013.shape[-1] == 4:
+             cond_ggx013 = cond_ggx013[..., :3] * cond_ggx013[..., 3:]
+         cond_ggx005 = imageio.v3.imread(f'{cond_path}/hint{i:02d}_ggx0.05.png') / 255.
+         if cond_ggx005.shape[-1] == 4:
+             cond_ggx005 = cond_ggx005[..., :3] * cond_ggx005[..., 3:]
+         hint = np.concatenate([mask, source_image, cond_diffuse, cond_ggx005, cond_ggx013, cond_ggx034], axis=2).astype(np.float32)[None]
+         hint = torch.from_numpy(hint).to(dtype).permute(0, 3, 1, 2).to(device)
+         generator = torch.manual_seed(seed)
+         image = pipe(
+             prompt, num_inference_steps=steps, generator=generator, image=hint, num_images_per_prompt=1, guidance_scale=cfg, output_type='np',
+         ).images[0]  # [H, W, C]
+         imageio.imwrite(f'{cond_path}/relighting{i:02d}.png', (image * 255).clip(0, 255).astype(np.uint8))
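The conditioning tensor handed to the ControlNet stacks the mask, the masked source image, and the four radiance hints along the channel axis, so the model sees 1 + 3 + 4×3 = 16 channels per pixel. A shape check with dummy arrays:

```python
import numpy as np

# 1 (mask) + 3 (source RGB) + 4 hints x 3 channels = 16 conditioning channels.
h = w = 512
mask = np.zeros((h, w, 1), np.float32)
source = np.zeros((h, w, 3), np.float32)
hints = [np.zeros((h, w, 3), np.float32) for _ in range(4)]  # diffuse + 3 GGX
cond = np.concatenate([mask, source, *hints], axis=2)
assert cond.shape == (h, w, 16)
```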
demo/render_hints.py ADDED
@@ -0,0 +1,150 @@
+ import multiprocessing
+ import tempfile
+ from multiprocessing import Process
+ from typing import Optional
+
+ from tqdm import tqdm
+
+
+ def render_hint_images(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
+                        env_map: Optional[str] = None, resolution=512, use_gpu=False):
+     import bpy
+     import numpy as np
+
+     from bpy_helper.camera import create_camera
+     from bpy_helper.light import set_env_light, create_point_light
+     from bpy_helper.material import create_white_diffuse_material, create_specular_ggx_material
+     from bpy_helper.scene import reset_scene, import_3d_model
+     from bpy_helper.utils import stdout_redirected
+
+     def configure_blender():
+         # Set the render resolution
+         bpy.context.scene.render.resolution_x = resolution
+         bpy.context.scene.render.resolution_y = resolution
+         bpy.context.scene.render.engine = 'CYCLES'
+         bpy.context.scene.cycles.samples = 512
+         if use_gpu:
+             bpy.context.preferences.addons["cycles"].preferences.get_devices()
+             bpy.context.scene.cycles.device = 'GPU'
+             bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
+
+         # Enable the alpha channel for GT mask
+         bpy.context.scene.render.film_transparent = True
+         bpy.context.scene.render.image_settings.color_mode = 'RGBA'
+
+     def render_rgb_and_hint(output_path):
+         MAT_DICT = {
+             '_diffuse': create_white_diffuse_material(),
+             '_ggx0.05': create_specular_ggx_material(0.05),
+             '_ggx0.13': create_specular_ggx_material(0.13),
+             '_ggx0.34': create_specular_ggx_material(0.34),
+         }
+
+         # render one pass per override material
+         for mat_name, mat in MAT_DICT.items():
+             bpy.context.scene.view_layers["ViewLayer"].material_override = mat
+
+             # save as png
+             bpy.context.scene.render.image_settings.file_format = 'PNG'
+             bpy.context.scene.render.filepath = f'{output_path}{mat_name}.png'
+             bpy.ops.render.render(animation=False, write_still=True)
+
+     # Render hints
+     reset_scene()
+     import_3d_model(model_path)
+     if geo_smooth:
+         for obj in bpy.data.objects:
+             if obj.type == 'MESH':
+                 obj.modifiers.new("Smooth", type="SMOOTH")
+                 smooth_modifier = obj.modifiers["Smooth"]
+                 smooth_modifier.factor = 0.5
+                 smooth_modifier.iterations = 150
+     configure_blender()
+
+     c2w = np.array([
+         [-1, 0, 0, 0],
+         [0, 0, 1, 0],
+         [0, 1, 0, 0],
+         [0, 0, 0, 1]
+     ])
+     camera = create_camera(c2w, fov)
+     bpy.context.scene.camera = camera
+     if output_folder is None:
+         output_folder = tempfile.mkdtemp()
+     for i in tqdm(range(len(pls)), desc='Rendering Hints'):
+         if env_map:
+             z_angle = i / len(pls) * np.pi * 2.
+             set_env_light(env_map, rotation_euler=[0, 0, z_angle])
+         else:
+             pl_pos = pls[i]
+             _point_light = create_point_light(pl_pos, power)
+
+         with stdout_redirected():
+             render_rgb_and_hint(output_folder + f'/hint{i:02d}')
+
+     return output_folder
+
+
+ def render_bg_images(fov, pls, output_folder: Optional[str] = None, env_map: Optional[str] = None, resolution=512):
+     import bpy
+     import numpy as np
+
+     from bpy_helper.camera import create_camera
+     from bpy_helper.light import set_env_light
+     from bpy_helper.scene import reset_scene
+     from bpy_helper.utils import stdout_redirected
+
+     def configure_blender():
+         # Set the render resolution
+         bpy.context.scene.render.resolution_x = resolution
+         bpy.context.scene.render.resolution_y = resolution
+         bpy.context.scene.render.engine = 'CYCLES'
+         bpy.context.scene.cycles.samples = 512
+
+         # Disable the alpha channel: the environment is rendered as background
+         bpy.context.scene.render.film_transparent = False
+         bpy.context.scene.render.image_settings.color_mode = 'RGB'
+
+     def render_env_bg(output_path):
+         bpy.context.scene.view_layers["ViewLayer"].material_override = None
+         bpy.context.scene.render.image_settings.file_format = 'PNG'
+         bpy.context.scene.render.filepath = f'{output_path}.png'
+         bpy.ops.render.render(animation=False, write_still=True)
+
+     # Render backgrounds
+     reset_scene()
+     configure_blender()
+
+     c2w = np.array([
+         [-1, 0, 0, 0],
+         [0, 0, 1, 0],
+         [0, 1, 0, 0],
+         [0, 0, 0, 1]
+     ])
+     camera = create_camera(c2w, fov)
+     bpy.context.scene.camera = camera
+     if output_folder is None:
+         output_folder = tempfile.mkdtemp()
+     for i in tqdm(range(len(pls)), desc='Rendering Env Backgrounds'):
+         z_angle = i / len(pls) * np.pi * 2.
+         set_env_light(env_map, rotation_euler=[0, 0, z_angle])
+
+         with stdout_redirected():
+             render_env_bg(output_folder + f'/bg{i:02d}')
+
+     return output_folder
+
+
+ def render_hint_images_wrapper(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution, return_dict):
+     output_folder = render_hint_images(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution)
+     return_dict['output_folder'] = output_folder
+
+
+ def render_hint_images_btn_func(model_path, fov, pls, power=500., geo_smooth=True, output_folder: Optional[str] = None,
+                                 env_map: Optional[str] = None, resolution=512):
+     manager = multiprocessing.Manager()
+     return_dict = manager.dict()
+     p = Process(target=render_hint_images_wrapper, args=(model_path, fov, pls, power, geo_smooth, output_folder, env_map, resolution, return_dict))
+     p.start()
+     p.join()
+     return return_dict['output_folder']
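`render_hint_images_btn_func` runs the render in a child process rather than calling `render_hint_images` directly, presumably because `bpy` holds global scene state per process; a fresh process per click keeps successive renders isolated. A usage sketch (the .glb path and light coordinates are illustrative):

```python
from demo.render_hints import render_hint_images_btn_func

# Render diffuse + GGX hints for two point-light positions around a mesh.
out_dir = render_hint_images_btn_func('/tmp/mesh.glb', fov=55.,
                                      pls=[(3., 1., 3.), (-3., 1., 3.)],
                                      power=1000.)
# out_dir now holds hint00_diffuse.png, hint00_ggx0.05.png, ...,
# hint01_ggx0.34.png
```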
demo/rm_bg.py ADDED
@@ -0,0 +1,23 @@
+ import numpy as np
+ import rembg
+
+
+ def rm_bg(img, use_sam=False):
+     output = rembg.remove(img)
+     mask = np.array(output)[:, :, 3]
+
+     # use sam for mask refinement
+     if use_sam:
+         session = rembg.new_session('sam', sam_model='sam_vit_h_4b8939')
+         bool_mask = mask > 0
+         y1, y2, x1, x2 = (
+             np.nonzero(bool_mask)[0].min(),
+             np.nonzero(bool_mask)[0].max(),
+             np.nonzero(bool_mask)[1].min(),
+             np.nonzero(bool_mask)[1].max()
+         )
+         output = rembg.remove(img, session=session, sam_prompt=[
+             {'type': 'rectangle', 'label': 1, 'data': [x1, y1, x2, y2]}
+         ])
+
+     return output, mask
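A direct call outside the UI (the file path is illustrative). Note that with `use_sam=True` the SAM-refined cutout is returned, but `mask` is still the initial rembg alpha channel:

```python
import imageio.v3 as iio

from demo.rm_bg import rm_bg

# Cut out the foreground of a photo; 'input.png' is an illustrative path.
img = iio.imread('input.png')[..., :3]
cutout, mask = rm_bg(img, use_sam=False)  # cutout: RGBA array, mask: HxW alpha
iio.imwrite('cutout.png', cutout)
```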
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ numpy==1.26.4
+ scipy==1.13.0
+ diffusers==0.27.2
+ transformers==4.39.3
+ accelerate==0.29.3
+ timm==0.6.12  # must use this version, required by MiDaS
+ rembg==2.0.56
+ trimesh==4.3.1
+ opencv-contrib-python==4.9.0.80
+ tqdm==4.66.2
+ bpy==3.6.0
+ bpy-helper==0.0.0
+ gradio==4.27.0
+ einops==0.7.0
+ imageio[ffmpeg]==2.34.0