import random

import cv2
import gradio as gr
import numpy as np
import torch
from controlnet_aux import HEDdetector, OpenposeDetector
from PIL import Image, ImageFilter
from safetensors.torch import load_model
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

from diffusers.pipelines.controlnet.pipeline_controlnet import ControlNetModel
from pipeline.pipeline_PowerPaint import StableDiffusionInpaintPipeline as Pipeline
from pipeline.pipeline_PowerPaint_ControlNet import StableDiffusionControlNetInpaintPipeline as controlnetPipeline
from utils.utils import TokenizerWrapper, add_tokens


torch.set_grad_enabled(False)
weight_dtype = torch.float16

# Load the PowerPaint inpainting pipeline, extend its tokenizer with the learnable
# task tokens (P_ctxt, P_shape, P_obj), then load the fine-tuned UNet and text encoder.
pipe = Pipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=weight_dtype)
pipe.tokenizer = TokenizerWrapper(
    from_pretrained="runwayml/stable-diffusion-v1-5", subfolder="tokenizer", revision=None
)
add_tokens(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
    initialize_tokens=["a", "a", "a"],
    num_vectors_per_token=10,
)
load_model(pipe.unet, "./models/unet/unet.safetensors")
load_model(pipe.text_encoder, "./models/unet/text_encoder.safetensors")
pipe = pipe.to("cuda")

# Auxiliary models used to build ControlNet conditioning images.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
hed = HEDdetector.from_pretrained("lllyasviel/ControlNet")

current_control = "canny"
# controlnet_conditioning_scale = 0.8


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    # Normalize the depth map to [0, 1] and replicate it to three channels;
    # it is resized to the working resolution by the caller.
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image


def add_task(prompt, negative_prompt, control_type):
    # Wrap the user prompt and negative prompt with the task token that matches the selected mode.
    if control_type == "object-removal":
        promptA = "empty scene blur " + prompt + " P_ctxt"
        promptB = "empty scene blur " + prompt + " P_ctxt"
        negative_promptA = negative_prompt + " P_obj"
        negative_promptB = negative_prompt + " P_obj"
    elif control_type == "shape-guided":
        promptA = prompt + " P_shape"
        promptB = prompt + " P_ctxt"
        negative_promptA = (
            negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry P_shape"
        )
        negative_promptB = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry P_ctxt"
    elif control_type == "image-outpainting":
        promptA = "empty scene " + prompt + " P_ctxt"
        promptB = "empty scene " + prompt + " P_ctxt"
        negative_promptA = negative_prompt + " P_obj"
        negative_promptB = negative_prompt + " P_obj"
    else:
        promptA = prompt + " P_obj"
        promptB = prompt + " P_obj"
        negative_promptA = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry, P_obj"
        negative_promptB = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry, P_obj"
    return promptA, promptB, negative_promptA, negative_promptB


def predict(
    input_image,
    prompt,
    fitting_degree,
    ddim_steps,
    scale,
    seed,
    negative_prompt,
    task,
    vertical_expansion_ratio,
    horizontal_expansion_ratio,
):
    size1, size2 = input_image["image"].convert("RGB").size
    if task != "image-outpainting":
        if size1 < size2:
            input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
        else:
            input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
    else:
        if size1 < size2:
            input_image["image"] = input_image["image"].convert("RGB").resize((512, int(size2 / size1 * 512)))
        else:
            input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 512), 512))

    if vertical_expansion_ratio is not None and horizontal_expansion_ratio is not None:
        # Outpainting: paste the input into a larger gray canvas and build a mask
        # that keeps the original content (minus a small blurry_gap border) untouched.
        o_W, o_H = input_image["image"].convert("RGB").size
        c_W = int(horizontal_expansion_ratio * o_W)
        c_H = int(vertical_expansion_ratio * o_H)

        expand_img = np.ones((c_H, c_W, 3), dtype=np.uint8) * 127
        original_img = np.array(input_image["image"])
        expand_img[
            int((c_H - o_H) / 2.0) : int((c_H - o_H) / 2.0) + o_H,
            int((c_W - o_W) / 2.0) : int((c_W - o_W) / 2.0) + o_W,
            :,
        ] = original_img

        blurry_gap = 10

        expand_mask = np.ones((c_H, c_W, 3), dtype=np.uint8) * 255
        if vertical_expansion_ratio == 1 and horizontal_expansion_ratio != 1:
            expand_mask[
                int((c_H - o_H) / 2.0) : int((c_H - o_H) / 2.0) + o_H,
                int((c_W - o_W) / 2.0) + blurry_gap : int((c_W - o_W) / 2.0) + o_W - blurry_gap,
                :,
            ] = 0
        elif vertical_expansion_ratio != 1 and horizontal_expansion_ratio != 1:
            expand_mask[
                int((c_H - o_H) / 2.0) + blurry_gap : int((c_H - o_H) / 2.0) + o_H - blurry_gap,
                int((c_W - o_W) / 2.0) + blurry_gap : int((c_W - o_W) / 2.0) + o_W - blurry_gap,
                :,
            ] = 0
        elif vertical_expansion_ratio != 1 and horizontal_expansion_ratio == 1:
            expand_mask[
                int((c_H - o_H) / 2.0) + blurry_gap : int((c_H - o_H) / 2.0) + o_H - blurry_gap,
                int((c_W - o_W) / 2.0) : int((c_W - o_W) / 2.0) + o_W,
                :,
            ] = 0

        input_image["image"] = Image.fromarray(expand_img)
        input_image["mask"] = Image.fromarray(expand_mask)

    promptA, promptB, negative_promptA, negative_promptB = add_task(prompt, negative_prompt, task)
    print(promptA, promptB, negative_promptA, negative_promptB)

    # Snap the working resolution to a multiple of 8, as required by the VAE.
    img = np.array(input_image["image"].convert("RGB"))
    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
    input_image["image"] = input_image["image"].resize((H, W))
    input_image["mask"] = input_image["mask"].resize((H, W))
    set_seed(seed)

    result = pipe(
        promptA=promptA,
        promptB=promptB,
        tradoff=fitting_degree,
        tradoff_nag=fitting_degree,
        negative_promptA=negative_promptA,
        negative_promptB=negative_promptB,
        image=input_image["image"].convert("RGB"),
        mask_image=input_image["mask"].convert("RGB"),
        width=H,
        height=W,
        guidance_scale=scale,
        num_inference_steps=ddim_steps,
    ).images[0]

    # Visualization: overlay the mask in red on the result, and paste the generated
    # content back onto the input image using a Gaussian-blurred mask edge.
    mask_np = np.array(input_image["mask"].convert("RGB"))
    red = np.array(result).astype("float")
    red[:, :, 0] = 180.0
    red[:, :, 2] = 0
    red[:, :, 1] = 0
    result_m = np.array(result)
    result_m = Image.fromarray(
        (
            result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
            + mask_np.astype("float") / 512.0 * red
        ).astype("uint8")
    )
    m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=3))
    m_img = np.asarray(m_img) / 255.0
    img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
    ours_np = np.asarray(result) / 255.0
    ours_np = ours_np * m_img + (1 - m_img) * img_np
    result_paste = Image.fromarray(np.uint8(ours_np * 255))

    dict_res = [input_image["mask"].convert("RGB"), result_m]
    dict_out = [input_image["image"].convert("RGB"), result_paste]
    return dict_out, dict_res


def predict_controlnet(
    input_image,
    input_control_image,
    control_type,
    prompt,
    ddim_steps,
    scale,
    seed,
    negative_prompt,
    controlnet_conditioning_scale,
):
    promptA = prompt + " P_obj"
    promptB = prompt + " P_obj"
    negative_promptA = negative_prompt
    negative_promptB = negative_prompt

    size1, size2 = input_image["image"].convert("RGB").size
    if size1 < size2:
        input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
    else:
        input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
    img = np.array(input_image["image"].convert("RGB"))
    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
    input_image["image"] = input_image["image"].resize((H, W))
    input_image["mask"] = input_image["mask"].resize((H, W))

    global current_control

    # Build a ControlNet inpainting pipeline that reuses the already-loaded PowerPaint
    # components, then swap in the ControlNet that matches the requested control_type.
    base_control = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype)
    control_pipe = controlnetPipeline(
        pipe.vae, pipe.text_encoder, pipe.tokenizer, pipe.unet, base_control, pipe.scheduler, None, None, False
    )
    control_pipe = control_pipe.to("cuda")

    current_control = "canny"
    if current_control != control_type:
        if control_type == "canny" or control_type is None:
            control_pipe.controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype
            )
        elif control_type == "pose":
            control_pipe.controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-openpose", torch_dtype=weight_dtype
            )
        elif control_type == "depth":
            control_pipe.controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-depth", torch_dtype=weight_dtype
            )
        else:
            control_pipe.controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-hed", torch_dtype=weight_dtype
            )
        control_pipe = control_pipe.to("cuda")
        current_control = control_type

    # Derive the conditioning image for the selected control type.
    controlnet_image = input_control_image
    if current_control == "canny":
        controlnet_image = controlnet_image.resize((H, W))
        controlnet_image = np.array(controlnet_image)
        controlnet_image = cv2.Canny(controlnet_image, 100, 200)
        controlnet_image = controlnet_image[:, :, None]
        controlnet_image = np.concatenate([controlnet_image, controlnet_image, controlnet_image], axis=2)
        controlnet_image = Image.fromarray(controlnet_image)
    elif current_control == "pose":
        controlnet_image = openpose(controlnet_image)
    elif current_control == "depth":
        controlnet_image = controlnet_image.resize((H, W))
        controlnet_image = get_depth_map(controlnet_image)
    else:
        controlnet_image = hed(controlnet_image)

    mask_np = np.array(input_image["mask"].convert("RGB"))
    controlnet_image = controlnet_image.resize((H, W))
    set_seed(seed)
    result = control_pipe(
        promptA=promptB,
        promptB=promptA,
        tradoff=1.0,
        tradoff_nag=1.0,
        negative_promptA=negative_promptA,
        negative_promptB=negative_promptB,
        image=input_image["image"].convert("RGB"),
        mask_image=input_image["mask"].convert("RGB"),
        control_image=controlnet_image,
        width=H,
        height=W,
        guidance_scale=scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=ddim_steps,
    ).images[0]

    red = np.array(result).astype("float")
    red[:, :, 0] = 180.0
    red[:, :, 2] = 0
    red[:, :, 1] = 0
    result_m = np.array(result)
    result_m = Image.fromarray(
        (
            result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
            + mask_np.astype("float") / 512.0 * red
        ).astype("uint8")
    )
    mask_np = np.array(input_image["mask"].convert("RGB"))
    m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=4))
    m_img = np.asarray(m_img) / 255.0
    img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
    ours_np = np.asarray(result) / 255.0
    ours_np = ours_np * m_img + (1 - m_img) * img_np
    result_paste = Image.fromarray(np.uint8(ours_np * 255))
    return [input_image["image"].convert("RGB"), result_paste], [controlnet_image, result_m]


def infer(
    input_image,
    text_guided_prompt,
    text_guided_negative_prompt,
    shape_guided_prompt,
    shape_guided_negative_prompt,
    fitting_degree,
    ddim_steps,
    scale,
    seed,
    task,
    enable_control,
    input_control_image,
    control_type,
    vertical_expansion_ratio,
    horizontal_expansion_ratio,
    outpaint_prompt,
    outpaint_negative_prompt,
    controlnet_conditioning_scale,
    removal_prompt,
    removal_negative_prompt,
):
    # Pick the prompt pair that belongs to the selected task tab.
    if task == "text-guided":
        prompt = text_guided_prompt
        negative_prompt = text_guided_negative_prompt
    elif task == "shape-guided":
        prompt = shape_guided_prompt
        negative_prompt = shape_guided_negative_prompt
    elif task == "object-removal":
        prompt = removal_prompt
        negative_prompt = removal_negative_prompt
    elif task == "image-outpainting":
        prompt = outpaint_prompt
        negative_prompt = outpaint_negative_prompt
        return predict(
            input_image,
            prompt,
            fitting_degree,
            ddim_steps,
            scale,
            seed,
            negative_prompt,
            task,
            vertical_expansion_ratio,
            horizontal_expansion_ratio,
        )
    else:
        task = "text-guided"
        prompt = text_guided_prompt
        negative_prompt = text_guided_negative_prompt

    if enable_control and task == "text-guided":
        return predict_controlnet(
            input_image,
            input_control_image,
            control_type,
            prompt,
            ddim_steps,
            scale,
            seed,
            negative_prompt,
            controlnet_conditioning_scale,
        )
    else:
        return predict(input_image, prompt, fitting_degree, ddim_steps, scale, seed, negative_prompt, task, None, None)


def select_tab_text_guided():
    return "text-guided"


def select_tab_object_removal():
    return "object-removal"


def select_tab_image_outpainting():
    return "image-outpainting"


def select_tab_shape_guided():
    return "shape-guided"


with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        gr.Markdown("## PowerPaint: High-Quality Versatile Image Inpainting")
    with gr.Row():
        gr.Markdown("Project Page &ensp; Paper &ensp; Code")
    with gr.Row():
        gr.Markdown(
            "**Note:** Due to network-related factors, the page may experience occasional bugs! "
            "If the inpainting results deviate significantly from expectations, "
            "consider toggling between task options to refresh the content."
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input image and draw mask")
            input_image = gr.Image(source="upload", tool="sketch", type="pil")

            task = gr.Radio(
                ["text-guided", "object-removal", "shape-guided", "image-outpainting"],
                show_label=False,
                visible=False,
            )

            # Text-guided object inpainting
            with gr.Tab("Text-guided object inpainting") as tab_text_guided:
                enable_text_guided = gr.Checkbox(
                    label="Enable text-guided object inpainting", value=True, interactive=False
                )
                text_guided_prompt = gr.Textbox(label="Prompt")
                text_guided_negative_prompt = gr.Textbox(label="negative_prompt")
                gr.Markdown("### Controlnet setting")
                enable_control = gr.Checkbox(
                    label="Enable controlnet", info="Enable this if you want to use controlnet"
                )
                controlnet_conditioning_scale = gr.Slider(
                    label="controlnet conditioning scale",
                    minimum=0,
                    maximum=1,
                    step=0.05,
                    value=0.5,
                )
                control_type = gr.Radio(["canny", "pose", "depth", "hed"], label="Control type")
                input_control_image = gr.Image(source="upload", type="pil")
            tab_text_guided.select(fn=select_tab_text_guided, inputs=None, outputs=task)

            # Object removal inpainting
            with gr.Tab("Object removal inpainting") as tab_object_removal:
                enable_object_removal = gr.Checkbox(
                    label="Enable object removal inpainting",
                    value=True,
                    info="The recommended configuration for the Guidance Scale is 10 or higher. "
                    "If undesired objects appear in the masked area, "
                    "you can address this by specifically increasing the Guidance Scale.",
                    interactive=False,
                )
                removal_prompt = gr.Textbox(label="Prompt")
                removal_negative_prompt = gr.Textbox(label="negative_prompt")
            tab_object_removal.select(fn=select_tab_object_removal, inputs=None, outputs=task)

            # Image outpainting
            with gr.Tab("Image outpainting") as tab_image_outpainting:
                enable_image_outpainting = gr.Checkbox(
                    label="Enable image outpainting",
                    value=True,
                    info="The recommended configuration for the Guidance Scale is 10 or higher. "
                    "If unwanted random objects appear in the extended image region, "
                    "you can enhance the cleanliness of the extension area by increasing the Guidance Scale.",
                    interactive=False,
                )
                outpaint_prompt = gr.Textbox(label="Outpainting_prompt")
                outpaint_negative_prompt = gr.Textbox(label="Outpainting_negative_prompt")
                horizontal_expansion_ratio = gr.Slider(
                    label="horizontal expansion ratio",
                    minimum=1,
                    maximum=4,
                    step=0.05,
                    value=1,
                )
                vertical_expansion_ratio = gr.Slider(
                    label="vertical expansion ratio",
                    minimum=1,
                    maximum=4,
                    step=0.05,
                    value=1,
                )
            tab_image_outpainting.select(fn=select_tab_image_outpainting, inputs=None, outputs=task)

            # Shape-guided object inpainting
            with gr.Tab("Shape-guided object inpainting") as tab_shape_guided:
                enable_shape_guided = gr.Checkbox(
                    label="Enable shape-guided object inpainting", value=True, interactive=False
                )
                shape_guided_prompt = gr.Textbox(label="shape_guided_prompt")
                shape_guided_negative_prompt = gr.Textbox(label="shape_guided_negative_prompt")
                fitting_degree = gr.Slider(
                    label="fitting degree",
                    minimum=0,
                    maximum=1,
                    step=0.05,
                    value=1,
                )
            tab_shape_guided.select(fn=select_tab_shape_guided, inputs=None, outputs=task)

            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
                scale = gr.Slider(
                    label="Guidance Scale",
                    info="For object removal and image outpainting, it is recommended to set the value at 10 or above.",
                    minimum=0.1,
                    maximum=30.0,
                    value=7.5,
                    step=0.1,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        with gr.Column():
            gr.Markdown("### Inpainting result")
            inpaint_result = gr.Gallery(label="Generated images", show_label=False, columns=2)
            gr.Markdown("### Mask")
            gallery = gr.Gallery(label="Generated masks", show_label=False, columns=2)

    run_button.click(
        fn=infer,
        inputs=[
            input_image,
            text_guided_prompt,
            text_guided_negative_prompt,
            shape_guided_prompt,
            shape_guided_negative_prompt,
            fitting_degree,
            ddim_steps,
            scale,
            seed,
            task,
            enable_control,
            input_control_image,
            control_type,
            vertical_expansion_ratio,
            horizontal_expansion_ratio,
            outpaint_prompt,
            outpaint_negative_prompt,
            controlnet_conditioning_scale,
            removal_prompt,
            removal_negative_prompt,
        ],
        outputs=[inpaint_result, gallery],
    )

demo.queue()
demo.launch(share=False, server_name="0.0.0.0", server_port=7860)