import os
from typing import cast

import numpy as np
import torch
from PIL import Image, ImageDraw
from diffusers import DiffusionPipeline
import gradio as gr
from gradio.components.image_editor import EditorValue
import spaces

DEVICE = "cuda"

# All three model settings must be provided via environment variables.
MAIN_MODEL_REPO_ID = os.getenv("MAIN_MODEL_REPO_ID", None)
SUB_MODEL_REPO_ID = os.getenv("SUB_MODEL_REPO_ID", None)
SUB_MODEL_SUBFOLDER = os.getenv("SUB_MODEL_SUBFOLDER", None)
if MAIN_MODEL_REPO_ID is None:
    raise ValueError("MAIN_MODEL_REPO_ID is not set")
if SUB_MODEL_REPO_ID is None:
    raise ValueError("SUB_MODEL_REPO_ID is not set")
if SUB_MODEL_SUBFOLDER is None:
    raise ValueError("SUB_MODEL_SUBFOLDER is not set")

pipeline = DiffusionPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    torch_dtype=torch.bfloat16,
    custom_pipeline=SUB_MODEL_REPO_ID,
).to(DEVICE)
pipeline.post_init()


def crop_divisible_by_16(image: Image.Image) -> Image.Image:
    # Crop so both dimensions are divisible by 16.
    w, h = image.size
    w = w - w % 16
    h = h - h % 16
    return image.crop((0, 0, w, h))


@spaces.GPU(duration=150)
def predict(
    image_and_mask: EditorValue,
    condition_image: Image.Image | None,
    seed: int = 0,
    num_inference_steps: int = 28,
    condition_size: int = 512,
    target_size: int = 512,
    condition_scale: float = 1.0,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image | None:
    # ) -> tuple[Image.Image, Image.Image] | None:
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if not condition_image:
        gr.Info("Please upload a furniture reference image")
        return None

    pipeline.load(
        SUB_MODEL_REPO_ID,
        subfolder=SUB_MODEL_SUBFOLDER,
    )

    image_np = image_and_mask["background"]
    image_np = cast(np.ndarray, image_np)

    # If the image is empty, return None
    if np.sum(image_np) == 0:
        gr.Info("Please upload an image")
        return None

    # The mask is whatever the user painted on the first editor layer:
    # every pixel with non-zero alpha becomes part of the mask.
    alpha_channel = image_and_mask["layers"][0]
    alpha_channel = cast(np.ndarray, alpha_channel)
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)

    # If the mask is empty, return None
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    # Resize to fit within target_size, then pad to a square canvas
    target_image = Image.fromarray(image_np).convert("RGB")
    target_image.thumbnail((target_size, target_size))
    new_target_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
    new_target_image.paste(target_image, (0, 0))
    # Save target image
    new_target_image.save("target_image.png")

    mask_image = Image.fromarray(mask_np).convert("L")
    mask_image.thumbnail((target_size, target_size))
    mask_image_bbox = mask_image.getbbox()
    # Fill the whole bbox area with 255 so the mask is rectangular
    draw = ImageDraw.Draw(mask_image)
    draw.rectangle(mask_image_bbox, fill=255)
    new_mask_image = Image.new("L", (target_size, target_size), 0)
    new_mask_image.paste(mask_image, (0, 0))
    # Save mask image
    new_mask_image.save("mask_image.png")

    # # Image masked is the image with the mask applied (black background)
    # image_masked = Image.new("RGB", image.size, (0, 0, 0))
    # image_masked.paste(image, (0, 0), mask)

    # Resize and pad the furniture reference image the same way
    condition_image = condition_image.convert("RGB")
    condition_image.thumbnail((condition_size, condition_size))
    new_condition_image = Image.new("RGB", (condition_size, condition_size), (0, 0, 0))
    new_condition_image.paste(condition_image, (0, 0))
    # Save condition image
    new_condition_image.save("condition_image.png")

    generator = torch.Generator(device="cpu").manual_seed(seed)
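    # Run the custom inpainting pipeline. The keyword arguments below are not
    # standard diffusers parameters; they are defined by the custom pipeline
    # loaded from SUB_MODEL_REPO_ID. The attention/LoRA flags appear to control
    # how the condition image is injected, and condition_scale weights its
    # influence on the output.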
    final_image = pipeline(
        condition_image=new_condition_image,
        prompt="",
        image=new_target_image,
        mask_image=new_mask_image,
        num_inference_steps=num_inference_steps,
        height=target_size,
        width=target_size,
        union_cond_attn=True,
        add_cond_attn=False,
        latent_lora=False,
        default_lora=False,
        condition_scale=condition_scale,
        generator=generator,
        max_sequence_length=512,
    ).images[0]

    final_image_crop = final_image.crop((0, 0, target_size, target_size))
    return final_image_crop


intro_markdown = r"""
# Furniture Inpainting Demo
"""

css = r"""
#col-left {
    margin: 0 auto;
    max-width: 650px;
}
#col-right {
    margin: 0 auto;
    max-width: 650px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)

    with gr.Row() as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                """