DiLightNet / app.py
NCJ's picture
init
084ab29 verified
raw
history blame
6.88 kB
import gradio as gr
import imageio
import numpy as np
from demo.img_gen import img_gen
from demo.mesh_recon import mesh_reconstruction
from demo.relighting_gen import relighting_gen
from demo.render_hints import render_hint_images_btn_func
from demo.rm_bg import rm_bg
# Gradio UI for the five-step DiLightNet pipeline: (1) provide or generate a
# reference image, (2) remove its background, (3) reconstruct a mesh,
# (4) render lighting hints for a point light, (5) generate the relit image.
#
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file whose indentation had been stripped; confirm the
# exact layout against the running demo.
with gr.Blocks(title="DiLightNet Demo") as demo:
    gr.Markdown("# DiLightNet: Fine-grained Lighting Control for Image Diffusion")

    with gr.Row():
        # 1. Reference Image Input / Generation
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 1. Input or Generate Reference Image")
            input_image = gr.Image(height=512, width=512, label="Input Image", interactive=True)
            with gr.Accordion("Generate Image", open=False):
                with gr.Group():
                    prompt = gr.Textbox(value="", label="Prompt", lines=3, placeholder="Input prompt here")
                    with gr.Row():
                        seed = gr.Number(value=42, label="Seed", interactive=True)
                        steps = gr.Number(value=20, label="Steps", interactive=True)
                        cfg = gr.Number(value=7.5, label="CFG", interactive=True)
                    down_from_768 = gr.Checkbox(label="Downsample from 768", value=True)
                with gr.Row():
                    generate_btn = gr.Button(value="Generate")
                    # The generated image is written into the Step 1 image slot.
                    generate_btn.click(
                        fn=img_gen,
                        inputs=[prompt, seed, steps, cfg, down_from_768],
                        outputs=[input_image],
                    )

        # 2. Background Removal
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 2. Remove Background")
            with gr.Tab("Masked Image"):
                masked_image = gr.Image(height=512, width=512, label="Masked Image", interactive=True)
            with gr.Tab("Mask"):
                mask = gr.Image(height=512, width=512, label="Mask", interactive=False)
            use_sam = gr.Checkbox(label="Use SAM for Refinement", value=False)
            rm_bg_btn = gr.Button(value="Remove Background")
            rm_bg_btn.click(fn=rm_bg, inputs=[input_image, use_sam], outputs=[masked_image, mask])

        # 3. Depth Estimation & Mesh Reconstruction
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 3. Depth Estimation & Mesh Reconstruction")
            mesh = gr.Model3D(label="Mesh Reconstruction", clear_color=(1.0, 1.0, 1.0, 1.0), interactive=True)
            with gr.Column():
                with gr.Accordion("Options", open=False):
                    with gr.Group():
                        remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
                        fov = gr.Number(value=55., label="FOV", interactive=True)
                        mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
                depth_estimation_btn = gr.Button(value="Estimate Depth")
                depth_estimation_btn.click(
                    fn=mesh_reconstruction,
                    inputs=[masked_image, mask, remove_edges, fov, mask_threshold],
                    outputs=[mesh]
                )

    # 4. Hint Rendering
    gr.Markdown("## Step 4. Render Hints")
    with gr.Row():
        with gr.Column():
            hint_image = gr.Image(label="Hint Image")
        with gr.Column():
            pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
            pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
            pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
            power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
            render_btn = gr.Button(value="Render Hints")
            # Hidden field carrying the hint output folder from Step 4 to Step 5.
            # It stays "" until hints have been rendered at least once.
            res_folder_path = gr.Textbox("", visible=False)

            def render_wrapper(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power,
                               progress=gr.Progress(track_tqdm=True)):
                """Render hint images for a single point light and tile them.

                Calls ``render_hint_images_btn_func`` with one light position,
                reads the ``hint00_diffuse.png`` and ``hint00_ggx0.34.png``
                files from the folder it returns, and concatenates them along
                axis 1 (side by side, assuming H x W[ x C] arrays).

                Args:
                    mesh: Value of the Step 3 ``Model3D`` component.
                    fov: Value of the "FOV" field.
                    pl_pos_x, pl_pos_y, pl_pos_z: Point light position.
                    power: Point light power.
                    progress: Not referenced in the body; the
                        ``gr.Progress(track_tqdm=True)`` default asks Gradio
                        to surface tqdm progress for this callback.

                Returns:
                    Tuple ``(hints, res_path)``: the concatenated hint image
                    and the folder path returned by the renderer.
                """
                res_path = render_hint_images_btn_func(mesh, fov, [(pl_pos_x, pl_pos_y, pl_pos_z)], power)
                hint_files = [res_path + '/hint00' + mat for mat in ["_diffuse.png", "_ggx0.34.png"]]
                hints = [imageio.v3.imread(hint_file) for hint_file in hint_files]
                return np.concatenate(hints, axis=1), res_path

            render_btn.click(
                fn=render_wrapper,
                inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
                outputs=[hint_image, res_folder_path]
            )

    # 5. Relighting
    gr.Markdown("## Step 5. Relighting!")
    with gr.Row():
        res_image = gr.Image(label="Result Image")
        with gr.Column():
            with gr.Group():
                relighting_prompt = gr.Textbox(value="", label="Relighting Text Prompt", lines=3,
                                               placeholder="Input prompt here",
                                               interactive=True)
                reuse_btn = gr.Button(value="Reuse Image Generation Prompt")
                # Copy the Step 1 prompt verbatim into the relighting prompt box.
                reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
            with gr.Accordion("Options", open=False):
                with gr.Row():
                    relighting_seed = gr.Number(value=3407, label="Seed", interactive=True)
                    relighting_steps = gr.Number(value=20, label="Steps", interactive=True)
                    relighting_cfg = gr.Number(value=3.0, label="CFG", interactive=True)
            with gr.Row():
                relighting_generate_btn = gr.Button(value="Generate")

                def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
                                         relighting_steps, relighting_cfg,
                                         progress=gr.Progress(track_tqdm=True)):
                    """Run relighting for one frame and return the resulting image.

                    Args:
                        masked_image: Value of the Step 2 "Masked Image" component.
                        mask: Value of the Step 2 "Mask" component.
                        res_folder_path: Hint folder produced by Step 4; passed
                            to ``relighting_gen`` as ``cond_path``.
                        relighting_prompt: Text prompt for relighting.
                        relighting_seed: Seed; cast to ``int`` before use.
                        relighting_steps: Step count; cast to ``int`` before use.
                        relighting_cfg: CFG value, passed through unchanged.
                        progress: Not referenced in the body; the
                            ``gr.Progress(track_tqdm=True)`` default asks
                            Gradio to surface tqdm progress for this callback.

                    Returns:
                        The image read from ``<res_folder_path>/relighting00.png``.

                    Raises:
                        gr.Error: If no hint folder is available yet, i.e.
                            Step 4 has not been run.
                    """
                    # The hidden path field is "" until hints are rendered. Without
                    # this guard the code below would use cond_path="" and then try
                    # to read '/relighting00.png' from the filesystem root.
                    if not res_folder_path:
                        raise gr.Error("Please render hints (Step 4) before relighting.")

                    relighting_gen(
                        masked_ref_img=masked_image,
                        mask=mask,
                        cond_path=res_folder_path,
                        frames=1,
                        prompt=relighting_prompt,
                        steps=int(relighting_steps),
                        seed=int(relighting_seed),
                        cfg=relighting_cfg
                    )
                    # Presumably relighting_gen writes its output into cond_path;
                    # the first (only) frame is read back from there.
                    return imageio.v3.imread(res_folder_path + '/relighting00.png')

                relighting_generate_btn.click(fn=gen_relighting_image,
                                              inputs=[masked_image, mask, res_folder_path, relighting_prompt,
                                                      relighting_seed, relighting_steps, relighting_cfg],
                                              outputs=[res_image])
if __name__ == '__main__':
    # Script entry point: turn on the request queue first, then start the
    # server bound to all interfaces with share=True.
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", share=True)