import spaces
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
# from diffusers import StableDiffusionImageVariationPipeline
from inference import InferenceModel
from pytorch_lightning import seed_everything
import numpy as np
import os
import rembg
import sys
from loguru import logger

_SAMPLE_TAB_ID_ = 0
_HIGHRES_TAB_ID_ = 1
_FOREGROUND_TAB_ID_ = 2

def set_loggers(level):
    logger.remove()
    logger.add(sys.stderr, level=level)

def on_guide_select(evt: gr.SelectData):
    # Gallery items are (image, caption) pairs; the select event exposes them
    # as {"image": {"path": ...}, "caption": ...}.
    logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
    return [evt.value["image"]["path"], f"Sample {evt.index}"]


def on_input_select(evt: gr.SelectData):
    logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
    return evt.value["image"]["path"]

def sample_fine(
    input_im,
    domain="Albedo",
    require_mask=False,
    steps=25,
    n_samples=4,
    seed=0,
    guid_img=None,
    vert_split=2,
    hor_split=2,
    overlaps=2,
    guidance_scale=2,
):
    if require_mask:
        input_im = remove_bg(input_im)
    seed_everything(int(seed))
    model = model_dict[domain]
    inp = tform(input_im).to(device).permute(1, 2, 0)
    guid_img = tform(guid_img).to(device).permute(1, 2, 0)
    images = model.generation(
        (vert_split, hor_split), overlaps,
        guid_img[..., :3], inp[..., :3], inp[..., 3:],
        dps_scale=guidance_scale, uc_score=1.0,
        ddim_steps=steps, batch_size=1, n_samples=1,
    )
    images["guid_images"] = [(guid_img.detach().cpu().numpy() * 255).astype(np.uint8)]
    output = images["out_images"][0]
    return [[(output, "High-res")], gr.Tabs(selected=_HIGHRES_TAB_ID_)]
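

# Illustrative sketch of how an overlapping tile grid could be derived from
# (vert_split, hor_split) and `overlaps`. This is an assumption for
# documentation only: the actual tiling lives inside `model.generation`, and
# the demo's `Overlaps` slider may use model-specific units rather than
# pixels. The helper is never called by the app.
def _tile_bounds(length, n_splits, overlap_px):
    """Return (start, end) ranges for `n_splits` tiles along one axis, each
    sharing roughly `overlap_px` pixels with its neighbour."""
    if n_splits <= 1:
        return [(0, length)]
    tile = (length + (n_splits - 1) * overlap_px) // n_splits
    stride = tile - overlap_px
    return [(i * stride, min(i * stride + tile, length)) for i in range(n_splits)]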
def remove_bg(input_im):
    output = rembg.remove(input_im, session=model_dict["remove_bg"])
    return output

def sampling(input_im, domain="Albedo", require_mask=False,
             steps=25, n_samples=4, seed=0):
    seed_everything(int(seed))
    model = model_dict[domain]
    if require_mask:
        input_im = remove_bg(input_im)
    inp = tform(input_im).to(device).permute(1, 2, 0)
    images = model.generation(
        (1, 1), 1, None, inp[..., :3], inp[..., 3:],
        dps_scale=0, uc_score=1, ddim_steps=steps,
        batch_size=1, n_samples=n_samples,
    )
    output = [
        [(images["input_image"][0], "Foreground Object"),
         (images["input_maskes"][0], "Foreground Mask")],
        [(img, f"Sample {idx}") for idx, img in enumerate(images["out_images"])],
        gr.Tabs(selected=_SAMPLE_TAB_ID_),
    ]
    return output
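

# Usage sketch (defined but never called): driving `sampling` outside the UI.
# The example path is hypothetical; the unpacking mirrors the return structure
# above (gallery pairs plus the tab to switch to).
def _example_albedo_run(path="examples/your_image.png"):
    img = np.array(Image.open(path).convert("RGBA"))
    foreground, samples, _tab = sampling(img, domain="Albedo",
                                         require_mask=False, steps=50,
                                         n_samples=2, seed=0)
    return samples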
title = "IntrinsicAnything: Learning Diffusion Priors for Inverse Rendering Under Unknown Illumination"
description = \
"""
#### Generate intrinsic images (Albedo, Specular Shading) from a single image.
##### Tips
- Check the "Auto Mask" box if the input image needs a foreground mask, or supply your own mask via an RGBA input.
- You can optionally generate a high-resolution sample if the input image is of high resolution. The original image is split into `Vertical Splits` x `Horizontal Splits` patches with some `Overlaps` in between. Due to the compute constraints of this online demo, we recommend keeping `Vertical Splits` x `Horizontal Splits` at 6 or below and setting `Overlaps` to 2. Set the denoising steps to at least 80 for high-resolution samples.
"""
set_loggers("INFO") | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
logger.info(f"Loading Models...") | |
model_dict = { | |
"Albedo": InferenceModel(ckpt_path="weights/albedo", | |
use_ddim=True, | |
gpu_id=0), | |
"Specular": InferenceModel(ckpt_path="weights/specular", | |
use_ddim=True, | |
gpu_id=0), | |
"remove_bg": rembg.new_session(), | |
} | |
logger.info(f"All models Loaded!") | |
tform = transforms.Compose([
    transforms.ToTensor(),
])

examples_dir = "examples"
examples = [[os.path.join(examples_dir, img_name)] for img_name in os.listdir(examples_dir)]
# theme definition
theme = gr.Theme.from_hub("NoCrypt/miku")
theme.body_background_fill = "#FFFFFF"
theme.body_background_fill_dark = "#000000"

demo = gr.Blocks(title=title, theme=theme)
with demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + title)
            gr.Markdown(description)
    with gr.Column():
        with gr.Row():
            with gr.Column(scale=0.8):
                image_input = [gr.Image(image_mode='RGBA', height=256)]
            with gr.Column():
                with gr.Tabs():
                    with gr.TabItem("Options"):
                        with gr.Column():
                            with gr.Row():
                                domain_box = gr.Radio([("Albedo", "Albedo"), ("Specular", "Specular")],
                                                      value="Albedo",
                                                      label="Type")
                                with gr.Column():
                                    gr.Markdown("### Automatic foreground segmentation")
                                    mask_box = gr.Checkbox(False, label="Auto Mask")
                            options_tab = [
                                domain_box,
                                mask_box,
                                gr.Slider(5, 200, value=50, step=5, label="Denoising Steps (more steps give better results)"),
                                gr.Slider(1, 10, value=2, step=1, label="Number of Samples"),
                                gr.Number(75424, label="Seed", precision=0),
                            ]
                    with gr.TabItem("Advanced (High-res)"):
                        with gr.Column():
                            guiding_img = gr.Image(image_mode='RGBA', label="Guiding Image", interactive=False, height=256, visible=False)
                            sample_idx = gr.Textbox(placeholder="Select one of the generated low-res samples", lines=1, interactive=False, label="Guiding Image")
                            options_advanced_tab = [
                                # high-resolution options
                                guiding_img,
                                gr.Slider(1, 4, value=2, step=1, label="Vertical Splits"),
                                gr.Slider(1, 4, value=2, step=1, label="Horizontal Splits"),
                                gr.Slider(1, 5, value=2, step=1, label="Overlaps"),
                                gr.Slider(0, 5, value=3, step=1, label="Guidance Scale"),
                            ]
            with gr.Column(scale=1.0):
                with gr.Tabs() as res_tabs:
                    with gr.TabItem("Generated Samples", id=_SAMPLE_TAB_ID_):
                        image_output = gr.Gallery(label="Generated Samples", object_fit="contain", columns=[2], rows=[2], height=512, selected_index=0)
                    with gr.TabItem("High Resolution Sample", id=_HIGHRES_TAB_ID_):
                        image_output_high = gr.Gallery(label="High Resolution Sample", object_fit="contain", columns=[1], rows=[1], height=512, selected_index=0)
                    with gr.TabItem("Foreground Object", id=_FOREGROUND_TAB_ID_):
                        foreground_output = gr.Gallery(label="Foreground Object", object_fit="contain", columns=[2], rows=[1], height=512, selected_index=0)
        with gr.Row():
            generate_button = gr.Button("Generate")
            generate_button_fine = gr.Button("Generate High-Res")
        examples_gr = gr.Examples(examples=examples, inputs=image_input,
                                  cache_examples=False, examples_per_page=30,
                                  label='Examples (Click one to start!)')
    generate_button.click(sampling, inputs=image_input + options_tab,
                          outputs=[foreground_output, image_output, res_tabs])
    generate_button_fine.click(sample_fine,
                               inputs=image_input + options_tab + options_advanced_tab,
                               outputs=[image_output_high, res_tabs])
    # Feed the selected low-res sample into the high-res options as the guiding image.
    image_output.select(on_guide_select, None, [guiding_img, sample_idx])
logger.info(f"Demo Initilized, Starting...") | |
demo.queue().launch() | |