"""Gradio demo for DiLightNet: fine-grained lighting control for diffusion-based image generation.

The UI walks through five steps:
  1. upload or generate a reference image,
  2. remove its background,
  3. estimate depth and reconstruct a mesh,
  4. render lighting hint images for a point light,
  5. generate the relit image conditioned on those hints.
"""
import os

import gradio as gr
import imageio
import numpy as np

from demo.img_gen import img_gen
from demo.mesh_recon import mesh_reconstruction
from demo.relighting_gen import relighting_gen
from demo.render_hints import render_hint_images_btn_func
from demo.rm_bg import rm_bg

with gr.Blocks(title="DiLightNet Demo") as demo:
    gr.Markdown("""# DiLightNet: Fine-grained Lighting Control for Diffusion-based Image Generation
## A demo for generating images under point lights using DiLightNet. For full usage, please refer to our [GitHub repository](TBD)""")

    with gr.Row():
        # 1. Reference Image Input / Generation
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 1. Input or Generate Reference Image")
            input_image = gr.Image(height=512, width=512, label="Input Image", interactive=True)
            with gr.Accordion("Generate Image", open=False):
                with gr.Group():
                    prompt = gr.Textbox(value="", label="Prompt", lines=3, placeholder="Input prompt here")
                    with gr.Row():
                        # precision=0 makes Gradio deliver ints; without it img_gen would
                        # receive floats for seed/steps (Step 5 works around the same
                        # issue with explicit int() casts).
                        seed = gr.Number(value=42, label="Seed", interactive=True, precision=0)
                        steps = gr.Number(value=20, label="Steps", interactive=True, precision=0)
                        cfg = gr.Number(value=7.5, label="CFG", interactive=True)
                    down_from_768 = gr.Checkbox(label="Downsample from 768", value=True)
                with gr.Row():
                    generate_btn = gr.Button(value="Generate")
                    generate_btn.click(
                        fn=img_gen,
                        inputs=[prompt, seed, steps, cfg, down_from_768],
                        outputs=[input_image],
                    )

        # 2. Background Removal
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 2. Remove Background")
            with gr.Tab("Masked Image"):
                masked_image = gr.Image(height=512, width=512, label="Masked Image", interactive=True)
            with gr.Tab("Mask"):
                mask = gr.Image(height=512, width=512, label="Mask", interactive=False)
            use_sam = gr.Checkbox(label="Use SAM for Refinement", value=False)
            rm_bg_btn = gr.Button(value="Remove Background")
            rm_bg_btn.click(fn=rm_bg, inputs=[input_image, use_sam], outputs=[masked_image, mask])

        # 3. Depth Estimation & Mesh Reconstruction
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 3. Depth Estimation & Mesh Reconstruction")
            mesh = gr.Model3D(label="Mesh Reconstruction", clear_color=(1.0, 1.0, 1.0, 1.0), interactive=True)
            with gr.Column():
                with gr.Accordion("Options", open=False):
                    with gr.Group():
                        remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
                        fov = gr.Number(value=55., label="FOV", interactive=True)
                        mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
                depth_estimation_btn = gr.Button(value="Estimate Depth")
                depth_estimation_btn.click(
                    fn=mesh_reconstruction,
                    inputs=[masked_image, mask, remove_edges, fov, mask_threshold],
                    outputs=[mesh],
                )

    # 4. Hint rendering
    gr.Markdown("## Step 4. Render Hints")
    with gr.Row():
        with gr.Column():
            hint_image = gr.Image(label="Hint Image")
        with gr.Column():
            pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
            pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
            pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
            power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
            render_btn = gr.Button(value="Render Hints")
            # Hidden state: folder where Step 4 wrote its hint images; read by Step 5.
            res_folder_path = gr.Textbox("", visible=False)

    def render_wrapper(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power, progress=gr.Progress(track_tqdm=True)):
        """Render hint images for one point light and return them for display.

        Args:
            mesh: value of the Model3D component (reconstructed mesh from Step 3).
            fov: camera field of view used for reconstruction.
            pl_pos_x, pl_pos_y, pl_pos_z: point-light position.
            power: point-light power.
            progress: Gradio progress tracker (tracks tqdm bars inside rendering).

        Returns:
            (hints, res_path): the diffuse and GGX hint images concatenated side by
            side (axis=1), and the folder containing the rendered hint files.

        Raises:
            gr.Error: if no mesh has been reconstructed yet.
        """
        if not mesh:
            raise gr.Error("Please reconstruct a mesh (Step 3) before rendering hints.")
        res_path = render_hint_images_btn_func(mesh, fov, [(pl_pos_x, pl_pos_y, pl_pos_z)], power)
        # A single light position is passed, so only the index-00 hint files are read.
        hint_files = [os.path.join(res_path, 'hint00' + mat) for mat in ["_diffuse.png", "_ggx0.34.png"]]
        hints = np.concatenate([imageio.v3.imread(hint_file) for hint_file in hint_files], axis=1)
        return hints, res_path

    render_btn.click(
        fn=render_wrapper,
        inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
        outputs=[hint_image, res_folder_path],
    )

    # 5. Lighting-controlled generation
    gr.Markdown("## Step 5. Lighting Control Generation")
    with gr.Row():
        res_image = gr.Image(label="Result Image")
        with gr.Column():
            with gr.Group():
                relighting_prompt = gr.Textbox(value="", label="Lighting-Control Text Prompt", lines=3,
                                               placeholder="Input prompt here", interactive=True)
                reuse_btn = gr.Button(value="Reuse Image Generation Prompt")
                # Copy the Step 1 prompt into the relighting prompt box.
                reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
            with gr.Accordion("Options", open=False):
                with gr.Row():
                    relighting_seed = gr.Number(value=3407, label="Seed", interactive=True, precision=0)
                    relighting_steps = gr.Number(value=20, label="Steps", interactive=True, precision=0)
                    relighting_cfg = gr.Number(value=3.0, label="CFG", interactive=True)
            with gr.Row():
                relighting_generate_btn = gr.Button(value="Generate")

    def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
                             relighting_steps, relighting_cfg, progress=gr.Progress(track_tqdm=True)):
        """Generate the relit image from the Step 4 hints and return it masked to the object.

        Args:
            masked_image: background-removed reference image (Step 2).
            mask: foreground mask (Step 2).
            res_folder_path: folder containing the Step 4 hint images; the generated
                image is read back from the same folder.
            relighting_prompt: text prompt for generation.
            relighting_seed, relighting_steps, relighting_cfg: sampler settings.
            progress: Gradio progress tracker (tracks tqdm bars inside generation).

        Returns:
            uint8 image array: the generated image multiplied by the last channel of
            the diffuse hint (background suppressed).

        Raises:
            gr.Error: if hints have not been rendered yet.
        """
        if not res_folder_path:
            # Without this guard the paths below resolve against the filesystem root.
            raise gr.Error("Please render hints (Step 4) before generating.")
        relighting_gen(
            masked_ref_img=masked_image,
            mask=mask,
            cond_path=res_folder_path,
            frames=1,
            prompt=relighting_prompt,
            steps=int(relighting_steps),
            seed=int(relighting_seed),
            cfg=relighting_cfg,
        )
        # Last channel of the diffuse hint, normalized to [0, 1], acts as the object mask
        # (presumably the alpha channel of an RGBA render — confirm against the renderer).
        mask_for_bg = imageio.v3.imread(os.path.join(res_folder_path, 'hint00_diffuse.png'))[..., -1:] / 255.
        res = imageio.v3.imread(os.path.join(res_folder_path, 'relighting00.png')) / 255.
        res = res * mask_for_bg  # + bg * (1. - mask_for_bg)
        res = (res * 255).clip(0, 255).astype(np.uint8)
        return res

    relighting_generate_btn.click(
        fn=gen_relighting_image,
        inputs=[masked_image, mask, res_folder_path, relighting_prompt,
                relighting_seed, relighting_steps, relighting_cfg],
        outputs=[res_image],
    )

if __name__ == '__main__':
    demo.queue().launch(server_name="0.0.0.0", share=True)