blanchon's picture
Update
de6b860
raw
history blame
8.06 kB
import os
import numpy as np
from typing import cast
import torch
from PIL import Image, ImageDraw
from diffusers import DiffusionPipeline
import gradio as gr
from gradio.components.image_editor import EditorValue
DEVICE = "cuda"


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        ValueError: if the variable is not set at all (an empty string is accepted).
    """
    value = os.getenv(name)
    if value is None:
        raise ValueError(f"{name} is not set")
    return value


# All three model locations are mandatory; import aborts on the first missing one,
# checked in this order.
MAIN_MODEL_REPO_ID = _require_env("MAIN_MODEL_REPO_ID")
SUB_MODEL_REPO_ID = _require_env("SUB_MODEL_REPO_ID")
SUB_MODEL_SUBFOLDER = _require_env("SUB_MODEL_SUBFOLDER")
# Base weights come from MAIN_MODEL_REPO_ID and the pipeline class from the custom
# pipeline published at SUB_MODEL_REPO_ID; weights are loaded in bfloat16 and moved
# to DEVICE.
pipeline = DiffusionPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    custom_pipeline=SUB_MODEL_REPO_ID,
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to(DEVICE)
# NOTE(review): `post_init` is not a standard DiffusionPipeline method; presumably
# an initialisation hook provided by the custom pipeline — confirm its contract.
pipeline.post_init()
def crop_divisible_by_16(image: Image.Image) -> Image.Image:
    """Crop *image* from its top-left corner so both sides are multiples of 16.

    Columns/rows beyond the largest multiple of 16 (at most 15 on each axis) are
    dropped from the right and bottom edges.
    """
    width, height = image.size
    box = (0, 0, (width // 16) * 16, (height // 16) * 16)
    return image.crop(box)
def _pad_to_square(
    image: Image.Image,
    size: int,
    fill: int | tuple[int, int, int],
) -> Image.Image:
    """Paste *image* at the top-left of a new ``size`` x ``size`` canvas.

    The canvas uses the same mode as *image*; the area not covered by *image* is
    filled with *fill*. *image* is expected to already fit inside the canvas
    (e.g. after ``Image.thumbnail``).
    """
    canvas = Image.new(image.mode, (size, size), fill)
    canvas.paste(image, (0, 0))
    return canvas


# @spaces.GPU(duration=150)
def predict(
    image_and_mask: EditorValue,
    condition_image: Image.Image | None,
    seed: int = 0,
    num_inference_steps: int = 28,
    condition_size: int = 512,
    target_size: int = 512,
    condition_scale: float = 1.0,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image | None:
    """Inpaint the masked area of a room image, conditioned on a furniture image.

    Args:
        image_and_mask: Value of the ``gr.ImageMask`` editor. ``background`` is
            the room image (numpy array); the first entry of ``layers`` holds the
            brush strokes, and every pixel with non-zero alpha there is masked.
        condition_image: Furniture reference image.
        seed: Seed for the CPU ``torch.Generator`` used for sampling.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        condition_size: Side of the square canvas the reference is fitted into.
        target_size: Side of the square canvas the room image is fitted into;
            also the height/width requested from the pipeline.
        condition_scale: Forwarded to the pipeline as ``condition_scale``.
        progress: Gradio progress tracker (tracks tqdm); not used directly.

    Returns:
        The generated image cropped to the region covered by the downscaled room
        image, or ``None`` (after a ``gr.Info`` notice) when an input is missing.

    Side effects:
        Writes ``target_image.png``, ``mask_image.png`` and
        ``condition_image.png`` to the current working directory on every call.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if condition_image is None:
        gr.Info("Please upload a furniture reference image")
        return None

    # The background may be None when nothing was uploaded; an all-black array
    # is treated as "no image" as well.
    background = image_and_mask.get("background")
    if background is None or np.sum(background) == 0:
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, background)

    layers = image_and_mask.get("layers") or []
    if not layers:
        gr.Info("Please mark the areas you want to remove")
        return None
    alpha_channel = cast(np.ndarray, layers[0])
    # Any pixel the brush touched (non-zero alpha, assumes an RGBA layer) is masked.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    # Gradio sliders may deliver floats; PIL sizes and torch seeds need ints.
    target_size = int(target_size)
    condition_size = int(condition_size)

    # Load the sub-model once per request, only after the inputs are validated.
    pipeline.load(
        SUB_MODEL_REPO_ID,
        subfolder=SUB_MODEL_SUBFOLDER,
    )

    # Room image: downscale (in place) to fit target_size, then pad to a square.
    target_image = Image.fromarray(image_np).convert("RGB")
    target_image.thumbnail((target_size, target_size))
    new_target_image = _pad_to_square(target_image, target_size, (0, 0, 0))
    new_target_image.save("target_image.png")

    # Mask: downscale the same way, then fill its whole bounding box so the
    # pipeline inpaints a rectangle around the strokes.
    mask_image = Image.fromarray(mask_np).convert("L")
    mask_image.thumbnail((target_size, target_size))
    mask_bbox = mask_image.getbbox()
    if mask_bbox is None:
        # Strokes can vanish when heavily downscaled; nothing left to inpaint.
        gr.Info("Please mark the areas you want to remove")
        return None
    ImageDraw.Draw(mask_image).rectangle(mask_bbox, fill=255)
    new_mask_image = _pad_to_square(mask_image, target_size, 0)
    new_mask_image.save("mask_image.png")

    # Furniture reference: downscale to fit condition_size, then pad to a square.
    condition_image = condition_image.convert("RGB")
    condition_image.thumbnail((condition_size, condition_size))
    new_condition_image = _pad_to_square(condition_image, condition_size, (0, 0, 0))
    new_condition_image.save("condition_image.png")

    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    final_image = pipeline(
        condition_image=new_condition_image,
        prompt="",
        image=new_target_image,
        mask_image=new_mask_image,
        num_inference_steps=num_inference_steps,
        height=target_size,
        width=target_size,
        union_cond_attn=True,
        add_cond_attn=False,
        latent_lora=False,
        default_lora=False,
        condition_scale=condition_scale,
        generator=generator,
        max_sequence_length=512,
    ).images[0]

    # The room image only occupies the top-left `target_image.size` region of the
    # padded square canvas; crop away the black padding.
    return final_image.crop((0, 0, *target_image.size))
intro_markdown = "\n# Furniture Inpainting Demo\n"

# Every styled column is horizontally centred and capped at a maximum width (px).
_column_max_widths = {
    "col-left": 650,
    "col-right": 650,
    "col-showcase": 1100,
}
css = "\n" + "".join(
    f"#{elem_id} {{\nmargin: 0 auto;\nmax-width: {max_width}px;\n}}\n"
    for elem_id, max_width in _column_max_widths.items()
)
def _step_header(text: str) -> gr.HTML:
    """Add a centred, 20px step banner (max height 50) to the current layout."""
    return gr.HTML(
        f"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
{text}
</div>
</div>
""",
        max_height=50,
    )


with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)

    with gr.Row() as content:
        # Left column: the inputs (room image with mask, furniture reference).
        with gr.Column(elem_id="col-left"):
            _step_header("Step 1. Upload a room image ⬇️")
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )

        # Right column: result, run button and the sampling settings.
        with gr.Column(elem_id="col-right"):
            _step_header("Step 2. Press Run to launch")
            result = gr.Image(label="Result")
            run_button = gr.Button("Run")

            # NOTE(review): nesting below was reconstructed from flattened source;
            # all settings are assumed to live inside the accordion — confirm.
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    step=1,
                    value=0,
                )
                condition_scale = gr.Slider(
                    label="Condition Scale",
                    minimum=-10.0,
                    maximum=10.0,
                    step=0.10,
                    value=1.0,
                )
                with gr.Column():
                    condition_size = gr.Slider(
                        label="Condition Size",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    target_size = gr.Slider(
                        label="Target Size",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

    # Input order must match predict's positional parameters.
    run_button.click(
        fn=predict,
        inputs=[
            image_and_mask,
            condition_image,
            seed,
            num_inference_steps,
            condition_size,
            target_size,
            condition_scale,
        ],
        outputs=[result],
    )

demo.launch()