blanchon's picture
Update image component
efe8765
raw
history blame
6.81 kB
import os
import numpy as np
from typing import cast
from pydantic import NonNegativeInt
import torch
from PIL import Image, ImageOps
from diffusers import DiffusionPipeline
import gradio as gr
from gradio.components.image_editor import EditorValue
import spaces
# Device the diffusion pipeline is moved to.
DEVICE = "cuda"


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        ValueError: ``"<name> is not set"`` when the variable is absent.
            An empty-string value is accepted and returned unchanged.
    """
    value = os.getenv(name)
    if value is None:
        raise ValueError(f"{name} is not set")
    return value


# Repository passed to DiffusionPipeline.from_pretrained as the base model.
MAIN_MODEL_REPO_ID = _require_env("MAIN_MODEL_REPO_ID")
# Repository used both as the custom pipeline code and as the source of the
# extra weights that predict() loads via pipeline.load(...).
SUB_MODEL_REPO_ID = _require_env("SUB_MODEL_REPO_ID")
# Subfolder of SUB_MODEL_REPO_ID passed to pipeline.load(subfolder=...).
SUB_MODEL_SUBFOLDER = _require_env("SUB_MODEL_SUBFOLDER")

# Load the base model in bfloat16 with the custom pipeline code, then move it
# to DEVICE. This runs once, at import time.
pipeline = DiffusionPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    custom_pipeline=SUB_MODEL_REPO_ID,
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to(DEVICE)
def crop_divisible_by_16(image: Image.Image) -> Image.Image:
    """Crop *image* so that its width and height are both multiples of 16.

    Only the right and bottom edges are trimmed; the crop box is anchored at
    the top-left corner (0, 0). A side shorter than 16 px is cropped to 0.
    """
    width, height = image.size
    return image.crop((0, 0, width // 16 * 16, height // 16 * 16))
def _editor_arrays(
    image_and_mask: EditorValue,
) -> tuple[np.ndarray, np.ndarray] | None:
    """Extract the background image and a binary mask from an ImageMask value.

    Args:
        image_and_mask: Value produced by the ``gr.ImageMask`` editor; its
            ``"background"`` and ``"layers"`` entries are read.

    Returns:
        ``(image_np, mask_np)``. ``mask_np`` is a ``uint8`` array that is 255
        wherever the first layer's alpha channel is non-zero (painted) and 0
        elsewhere. Returns ``None`` after showing a ``gr.Info`` message when
        the background is missing or all zeros, or when nothing is painted.
    """
    background = image_and_mask.get("background")
    # Reject a missing (None) background as well as an all-black one, instead
    # of letting None reach the layer/mask code below and fail there.
    if background is None or not np.any(background):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, background)

    layers = image_and_mask.get("layers") or []
    # No layer at all means nothing was painted; avoid an IndexError on [0].
    if not layers:
        gr.Info("Please mark the areas you want to remove")
        return None
    alpha_channel = cast(np.ndarray, layers[0])
    # Channel 3 is alpha: every non-transparent brush pixel counts as masked.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None
    return image_np, mask_np


def _fit_to_model(image: Image.Image, max_dimension: int) -> Image.Image:
    """Return a resized copy of *image* suitable for the pipeline.

    The copy is shrunk with ``thumbnail`` so neither side exceeds
    ``max_dimension`` (aspect ratio kept, never enlarged). It is then cropped
    so both sides are multiples of 16 (for the VAE). The input image is not
    modified.
    """
    fitted = image.copy()
    fitted.thumbnail((max_dimension, max_dimension))
    return crop_divisible_by_16(fitted)


# NOTE(review): presumably reserves a ZeroGPU slot for up to 150 s per call.
@spaces.GPU(duration=150)
def predict(
    image_and_mask: EditorValue | NonNegativeInt,
    furniture_reference: Image.Image | None,
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 704,
    condition_scale: float = 1.0,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image | None:
    """Run the furniture-inpainting pipeline on a room image and a painted mask.

    The painted area of the room image is blacked out. The result is passed
    to the pipeline as ``condition_image``, together with the furniture
    reference image.

    Args:
        image_and_mask: ``gr.ImageMask`` value holding the room image
            (``"background"``) and the user's brush strokes (``"layers"``).
        furniture_reference: Reference image of the furniture; it is not
            modified.
        seed: Seed for the CPU ``torch.Generator``.
        num_inference_steps: Number of denoising steps.
        max_dimension: Upper bound for each side of all images before they
            are cropped to multiples of 16.
        condition_scale: Forwarded unchanged to the pipeline.
        progress: Not referenced in the body; declared so Gradio can attach a
            progress tracker (``track_tqdm=True``).

    Returns:
        The first generated image. Returns ``None`` after showing a
        ``gr.Info`` message when an input is missing or empty.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if not furniture_reference:
        gr.Info("Please upload a furniture reference image")
        return None

    arrays = _editor_arrays(cast(EditorValue, image_and_mask))
    if arrays is None:
        return None
    image_np, mask_np = arrays

    pipeline.load(
        SUB_MODEL_REPO_ID,
        subfolder=SUB_MODEL_SUBFOLDER,
    )

    # Slider values may arrive as floats; PIL sizes, the torch seed and the
    # step count all expect ints.
    max_dimension = int(max_dimension)

    image = _fit_to_model(Image.fromarray(image_np), max_dimension)
    mask = _fit_to_model(Image.fromarray(mask_np), max_dimension)
    # Invert so unpainted pixels are 255. Pasting through this mask keeps the
    # room outside the painted area and leaves the painted area black.
    mask = ImageOps.invert(mask)
    image_masked = Image.new("RGB", image.size, (0, 0, 0))
    image_masked.paste(image, (0, 0), mask)

    # Work on a resized copy so the caller's PIL image is not mutated.
    reference = _fit_to_model(furniture_reference, max_dimension)

    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    return pipeline(
        condition_image=image_masked,
        reference_image=reference,
        condition_scale=condition_scale,
        prompt="",
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        max_sequence_length=512,
        latent_lora=True,
    ).images[0]
# Markdown header rendered at the top of the page by gr.Markdown(intro_markdown).
intro_markdown = r"""
# Furniture Inpainting Demo
"""
# Page stylesheet passed to gr.Blocks(css=css). It centers the elements with
# ids "col-left" / "col-right" (the two layout columns, max 650px each) and
# "col-showcase" (max 1100px).
# NOTE(review): no component in this file uses elem_id="col-showcase";
# possibly a leftover rule.
css = r"""
#col-left {
margin: 0 auto;
max-width: 650px;
}
#col-right {
margin: 0 auto;
max-width: 650px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
"""
def _step_header(text: str) -> gr.HTML:
    """Render a centered, 20px instruction banner inside the current layout."""
    return gr.HTML(
        f"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
{text}
</div>
</div>
""",
        max_height=50,
    )


with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)

    with gr.Row() as content:
        # Left column: the room image with its brush mask, plus the reference.
        with gr.Column(elem_id="col-left"):
            _step_header("Step 1. Upload a room image ⬇️")
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                sources=["upload"],
                interactive=True,
                layers=False,
                transforms=[],
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                height="full",
                width="full",
                show_fullscreen_button=False,
                show_download_button=False,
            )
            furniture_reference = gr.Image(
                label="Furniture Reference",
                sources=["upload"],
                type="pil",
                image_mode="RGB",
            )

        # Right column: result, run button and the tuning knobs.
        with gr.Column(elem_id="col-right"):
            _step_header("Step 2. Press Run to launch")
            result = gr.Image(label="Result")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    value=0,
                    minimum=0,
                    maximum=100_000,
                    step=1,
                    label="Seed",
                )
                condition_scale = gr.Slider(
                    value=1.0,
                    minimum=-10.0,
                    maximum=10.0,
                    step=0.10,
                    label="Condition Scale",
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        value=704,
                        minimum=512,
                        maximum=2048,
                        step=128,
                        label="Max Dimension",
                    )
                    num_inference_steps = gr.Slider(
                        value=28,
                        minimum=1,
                        maximum=50,
                        step=1,
                        label="Number of inference steps",
                    )

    # The input order must match predict()'s positional parameters.
    run_button.click(
        fn=predict,
        inputs=[
            image_and_mask,
            furniture_reference,
            seed,
            num_inference_steps,
            max_dimension,
            condition_scale,
        ],
        outputs=[result],
    )

demo.launch()