import spaces
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from inference import InferenceModel
from pytorch_lightning import seed_everything
import numpy as np
import os
import rembg
import sys
from loguru import logger
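# Tab ids used to switch the results panel to the relevant tab after generation.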
_SAMPLE_TAB_ID_ = 0
_HIGHRES_TAB_ID_ = 1
_FOREGROUND_TAB_ID_ = 2
def set_loggers(level):
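    """Route loguru output to stderr at the given verbosity level."""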
    logger.remove()
    logger.add(sys.stderr, level=level)
def on_guide_select(evt: gr.SelectData):
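    """Gallery select handler: return the chosen sample's image path (as the guiding image) and its caption."""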
logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
return [evt.value["image"]['path'], f"Sample {evt.index}"]
def on_input_select(evt: gr.SelectData):
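    """Gallery select handler: return the path of the selected image."""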
logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
return evt.value["image"]['path']
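# Request up to 120 s of ZeroGPU time for the slower tiled high-res sampling.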
@spaces.GPU(duration=120)
def sample_fine(
    input_im,
    domain="Albedo",
    require_mask=False,
    steps=25,
    n_samples=4,
    seed=0,
    guid_img=None,
    vert_split=2,
    hor_split=2,
    overlaps=2,
    guidance_scale=2,
):
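    """Tiled high-resolution sampling guided by a selected low-res sample.

    The input is split into vert_split x hor_split patches with overlapping
    regions. n_samples is accepted to match the shared options inputs, but a
    single high-res image is produced per run.
    """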
    if guid_img is None:
        raise gr.Error("Select a low-res sample first; it is used as the guiding image.")
    if require_mask:
        input_im = remove_bg(input_im)
    seed_everything(int(seed))
    model = model_dict[domain]
    inp = tform(input_im).to(device).permute(1, 2, 0)  # CxHxW -> HxWxC
    guid_img = tform(guid_img).to(device).permute(1, 2, 0)
    images = model.generation((vert_split, hor_split), overlaps, guid_img[..., :3],
                              inp[..., :3], inp[..., 3:], dps_scale=guidance_scale,
                              uc_score=1.0, ddim_steps=steps, batch_size=1, n_samples=1)
    images["guid_images"] = [(guid_img.detach().cpu().numpy() * 255).astype(np.uint8)]
    output = images["out_images"][0]
    return [[(output, "High-res")], gr.Tabs(selected=_HIGHRES_TAB_ID_)]
def remove_bg(input_im):
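    """Remove the background with rembg; the returned RGBA image carries the foreground mask in its alpha channel."""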
    output = rembg.remove(input_im, session=model_dict["remove_bg"])
    return output
@spaces.GPU()
def sampling(input_im, domain="Albedo", require_mask=False,
             steps=25, n_samples=4, seed=0):
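    """Generate n_samples low-resolution intrinsic images for the chosen domain."""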
    seed_everything(int(seed))
    model = model_dict[domain]
    if require_mask:
        input_im = remove_bg(input_im)
    inp = tform(input_im).to(device).permute(1, 2, 0)  # CxHxW -> HxWxC
    images = model.generation((1, 1), 1, None, inp[..., :3], inp[..., 3:],
                              dps_scale=0, uc_score=1, ddim_steps=steps,
                              batch_size=1, n_samples=n_samples)
    output = [
        [(images["input_image"][0], "Foreground Object"),
         (images["input_maskes"][0], "Foreground Mask")],
        [(img, f"Sample {idx}") for idx, img in enumerate(images["out_images"])],
        gr.Tabs(selected=_SAMPLE_TAB_ID_),
    ]
    return output
title = "IntrinsicAnything: Learning Diffusion Priors for Inverse Rendering Under Unknown Illumination"
description = \
"""
#### Generate intrinsic images (Albedo, Specular Shading) from a single image.
##### Tips
- Check the "Auto Mask" box if the input image needs a foreground mask, or supply your own mask via an RGBA input.
- If the input image is high resolution, you can optionally generate a high-resolution sample. The original image is split into `Vertical Splits` x `Horizontal Splits` patches with some `Overlaps` in between. Due to the compute constraints of this online demo, we recommend keeping `Vertical Splits` x `Horizontal Splits` at 6 or fewer and setting `Overlaps` to 2. Use at least 80 denoising steps for high-resolution samples.
"""
set_loggers("INFO")
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Loading Models...")
model_dict = {
    "Albedo": InferenceModel(ckpt_path="weights/albedo",
                             use_ddim=True,
                             gpu_id=0),
    "Specular": InferenceModel(ckpt_path="weights/specular",
                               use_ddim=True,
                               gpu_id=0),
    "remove_bg": rembg.new_session(),
}
logger.info(f"All models Loaded!")
tform = transforms.Compose([
    transforms.ToTensor(),
])
examples_dir = "examples"
examples = [[os.path.join(examples_dir, img_name)] for img_name in sorted(os.listdir(examples_dir))]
# theme definition
theme = gr.Theme.from_hub("NoCrypt/miku")
theme.body_background_fill = "#FFFFFF"
theme.body_background_fill_dark = "#000000"
demo = gr.Blocks(title=title, theme=theme)
with demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + title)
            gr.Markdown(description)
    with gr.Column():
        with gr.Row():
            with gr.Column(scale=4):
                image_input = [gr.Image(image_mode='RGBA', height=256)]
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("Options"):
                        with gr.Column():
                            with gr.Row():
                                domain_box = gr.Radio(["Albedo", "Specular"],
                                                      value="Albedo",
                                                      label="Type")
                            with gr.Column():
                                gr.Markdown("### Automatic foreground segmentation")
                                mask_box = gr.Checkbox(False, label="Auto Mask")
                        options_tab = [
                            domain_box,
                            mask_box,
                            gr.Slider(5, 200, value=50, step=5, label="Denoising Steps (more steps give better results)"),
                            gr.Slider(1, 10, value=2, step=1, label="Number of Samples"),
                            gr.Number(75424, label="Seed", precision=0),
                        ]
                    with gr.TabItem("Advanced (High-res)"):
                        with gr.Column():
                            guiding_img = gr.Image(image_mode='RGBA', label="Guiding Image", interactive=False, height=256, visible=False)
                            sample_idx = gr.Textbox(placeholder="Select one of the generated low-res samples", lines=1, interactive=False, label="Selected Sample")
                        options_advanced_tab = [
                            # high resolution options
                            guiding_img,
                            gr.Slider(1, 4, value=2, step=1, label="Vertical Splits"),
                            gr.Slider(1, 4, value=2, step=1, label="Horizontal Splits"),
                            gr.Slider(1, 5, value=2, step=1, label="Overlaps"),
                            gr.Slider(0, 5, value=3, step=1, label="Guidance Scale"),
                        ]
            with gr.Column(scale=5):
                with gr.Tabs() as res_tabs:
                    with gr.TabItem("Generated Samples", id=_SAMPLE_TAB_ID_):
                        image_output = gr.Gallery(label="Generated Samples", object_fit="contain", columns=[2], rows=[2], height=512, selected_index=0)
                    with gr.TabItem("High Resolution Sample", id=_HIGHRES_TAB_ID_):
                        image_output_high = gr.Gallery(label="High Resolution Sample", object_fit="contain", columns=[1], rows=[1], height=512, selected_index=0)
                    with gr.TabItem("Foreground Object", id=_FOREGROUND_TAB_ID_):
                        foreground_output = gr.Gallery(label="Foreground Object", object_fit="contain", columns=[2], rows=[1], height=512, selected_index=0)
        with gr.Row():
            generate_button = gr.Button("Generate")
            generate_button_fine = gr.Button("Generate High-Res")
        examples_gr = gr.Examples(examples=examples, inputs=image_input,
                                  cache_examples=False, examples_per_page=30,
                                  label='Examples (Click one to start!)')
    generate_button.click(sampling, inputs=image_input + options_tab,
                          outputs=[foreground_output, image_output, res_tabs])
    generate_button_fine.click(sample_fine,
                               inputs=image_input + options_tab + options_advanced_tab,
                               outputs=[image_output_high, res_tabs])
    image_output.select(on_guide_select, None, [guiding_img, sample_idx])
logger.info(f"Demo Initilized, Starting...")
demo.queue().launch()