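"""Furniture inpainting demo.

Gradio app that takes a room photo with a user-drawn mask plus a furniture
reference image, and inpaints the furniture into the masked region using a
diffusers custom pipeline configured through environment variables.
"""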
import os
import numpy as np
from typing import cast
import torch
from PIL import Image, ImageDraw
from diffusers import DiffusionPipeline
import gradio as gr
from gradio.components.image_editor import EditorValue
import spaces
DEVICE = "cuda"

# Model repositories are configured through environment variables.
MAIN_MODEL_REPO_ID = os.getenv("MAIN_MODEL_REPO_ID", None)
SUB_MODEL_REPO_ID = os.getenv("SUB_MODEL_REPO_ID", None)
SUB_MODEL_SUBFOLDER = os.getenv("SUB_MODEL_SUBFOLDER", None)

if MAIN_MODEL_REPO_ID is None:
    raise ValueError("MAIN_MODEL_REPO_ID is not set")
if SUB_MODEL_REPO_ID is None:
    raise ValueError("SUB_MODEL_REPO_ID is not set")
if SUB_MODEL_SUBFOLDER is None:
    raise ValueError("SUB_MODEL_SUBFOLDER is not set")

# Load the base model, using the custom pipeline implementation from the sub-model repo.
pipeline = DiffusionPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    torch_dtype=torch.bfloat16,
    custom_pipeline=SUB_MODEL_REPO_ID,
).to(DEVICE)
pipeline.post_init()


def crop_divisible_by_16(image: Image.Image) -> Image.Image:
    """Crop an image so that both dimensions are divisible by 16."""
    w, h = image.size
    w = w - w % 16
    h = h - h % 16
    return image.crop((0, 0, w, h))


@spaces.GPU(duration=150)
def predict(
    image_and_mask: EditorValue,
    condition_image: Image.Image | None,
    seed: int = 0,
    num_inference_steps: int = 28,
    condition_size: int = 512,
    target_size: int = 512,
    condition_scale: float = 1.0,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image | None:
    # ) -> tuple[Image.Image, Image.Image] | None:
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if not condition_image:
        gr.Info("Please upload a furniture reference image")
        return None
    # Load the sub-model weights used by the custom pipeline.
    pipeline.load(
        SUB_MODEL_REPO_ID,
        subfolder=SUB_MODEL_SUBFOLDER,
    )

    image_np = image_and_mask["background"]
    image_np = cast(np.ndarray, image_np)
    # If the image is empty, return None
    if np.sum(image_np) == 0:
        gr.Info("Please upload an image")
        return None

    # The drawn mask is the first editor layer; any non-transparent pixel counts as masked.
    alpha_channel = image_and_mask["layers"][0]
    alpha_channel = cast(np.ndarray, alpha_channel)
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    # If the mask is empty, return None
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None
    # Resize the room image to fit the target size and paste it onto a square black canvas.
    target_image = Image.fromarray(image_np).convert("RGB")
    target_image.thumbnail((target_size, target_size))
    new_target_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
    new_target_image.paste(target_image, (0, 0))
    # Save target image (debug output)
    new_target_image.save("target_image.png")

    # Resize the mask the same way and expand it to the bounding box of the drawn strokes.
    mask_image = Image.fromarray(mask_np).convert("L")
    mask_image.thumbnail((target_size, target_size))
    mask_image_bbox = mask_image.getbbox()
    # Fill the whole bounding-box area with 255
    draw = ImageDraw.Draw(mask_image)
    draw.rectangle(mask_image_bbox, fill=255)
    new_mask_image = Image.new("L", (target_size, target_size), 0)
    new_mask_image.paste(mask_image, (0, 0))
    # Save mask image (debug output)
    new_mask_image.save("mask_image.png")

    # # Image masked is the image with the mask applied (black background)
    # image_masked = Image.new("RGB", image.size, (0, 0, 0))
    # image_masked.paste(image, (0, 0), mask)

    # Resize the furniture reference and paste it onto a square black canvas as well.
    condition_image = condition_image.convert("RGB")
    condition_image.thumbnail((condition_size, condition_size))
    new_condition_image = Image.new("RGB", (condition_size, condition_size), (0, 0, 0))
    new_condition_image.paste(condition_image, (0, 0))
    # Save condition image (debug output)
    new_condition_image.save("condition_image.png")
    # Use a CPU generator so the result is reproducible for a given seed.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    final_image = pipeline(
        condition_image=new_condition_image,
        prompt="",
        image=new_target_image,
        mask_image=new_mask_image,
        num_inference_steps=num_inference_steps,
        height=target_size,
        width=target_size,
        # Options specific to the custom pipeline (attention / LoRA conditioning flags).
        union_cond_attn=True,
        add_cond_attn=False,
        latent_lora=False,
        default_lora=False,
        condition_scale=condition_scale,
        generator=generator,
        max_sequence_length=512,
    ).images[0]

    final_image_crop = final_image.crop((0, 0, target_size, target_size))
    return final_image_crop


intro_markdown = r"""
# Furniture Inpainting Demo
"""

css = r"""
#col-left {
    margin: 0 auto;
    max-width: 650px;
}
#col-right {
    margin: 0 auto;
    max-width: 650px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)
    with gr.Row() as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                """
                <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                    <div>
                        Step 1. Upload a room image ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                """
                <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                    <div>
                        Step 2. Press Run to launch
                    </div>
                </div>
                """,
                max_height=50,
            )
            # image_slider = ImageSlider(
            #     label="Result",
            #     interactive=False,
            # )
            result = gr.Image(label="Result")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    step=1,
                    value=0,
                )
                condition_scale = gr.Slider(
                    label="Condition Scale",
                    minimum=-10.0,
                    maximum=10.0,
                    step=0.10,
                    value=1.0,
                )
                with gr.Column():
                    condition_size = gr.Slider(
                        label="Condition Size",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    target_size = gr.Slider(
                        label="Target Size",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )
    run_button.click(
        fn=predict,
        inputs=[
            image_and_mask,
            condition_image,
            seed,
            num_inference_steps,
            condition_size,
            target_size,
            condition_scale,
        ],
        # outputs=[image_slider],
        outputs=[result],
    )

demo.launch()