# Spaces: Paused  (page-status text captured together with this file; not program code)
import os | |
import random | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image, ImageFilter | |
from transformers import CLIPTextModel | |
from diffusers import UniPCMultistepScheduler | |
from model.BrushNet_CA import BrushNetModel | |
from model.diffusers_c.models import UNet2DConditionModel | |
from pipeline.pipeline_PowerPaint_Brushnet_CA import StableDiffusionPowerPaintBrushNetPipeline | |
from utils.utils import TokenizerWrapper, add_tokens | |
# ---------------------------------------------------------------------------
# One-time setup, executed at import: download the PowerPaint v2 checkpoints
# and assemble the inpainting pipeline used by predict().
# ---------------------------------------------------------------------------

base_path = "./PowerPaint_v2"

# Fetch the checkpoints with git-lfs.  The exit status of each command is not
# checked, so a failing step (tool already installed, checkout already
# present, no permission to run apt) does not stop start-up.
# `-y` answers apt's confirmation prompt, which nobody can answer here.
os.system("apt install -y git")
os.system("apt install -y git-lfs")
os.system(f"git lfs clone https://code.openxlab.org.cn/zhuangjunhao/PowerPaint_v2.git {base_path}")
# Every os.system() call runs in its own shell, so this `cd` only affects the
# `git lfs pull` chained to it; the working directory of this process never
# changes and nothing has to be undone afterwards.
os.system(f"cd {base_path} && git lfs pull")

# The app only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Module-level defaults; the UI code further down rebinds these names.
context_prompt = ""
context_negative_prompt = ""

base_model_path = f"{base_path}/realisticVisionV60B1_v51VAE/"
dtype = torch.float16

# In this file the SD 1.5 UNet is only used as the template from which the
# BrushNet branch is built; the pipeline's own UNet is loaded from
# base_model_path below.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", revision=None, torch_dtype=dtype
)
text_encoder_brushnet = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", revision=None, torch_dtype=dtype
)
brushnet = BrushNetModel.from_unet(unet)

pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    text_encoder_brushnet=text_encoder_brushnet,
    torch_dtype=dtype,
    low_cpu_mem_usage=False,
    safety_checker=None,
)
pipe.unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet", revision=None, torch_dtype=dtype)
pipe.tokenizer = TokenizerWrapper(from_pretrained=base_model_path, subfolder="tokenizer", revision=None)

# Register the three task placeholder tokens (the same strings add_task()
# returns) with the tokenizer and the BrushNet text encoder.
add_tokens(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder_brushnet,
    placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
    initialize_tokens=["a", "a", "a"],
    num_vectors_per_token=10,
)

from safetensors.torch import load_model

load_model(pipe.brushnet, f"{base_path}/PowerPaint_Brushnet/diffusion_pytorch_model.safetensors")
# SECURITY NOTE: torch.load() unpickles the file, and this checkpoint was just
# downloaded from a remote repository -- only point base_path at a trusted
# source.  map_location="cpu" lets the load succeed regardless of the device
# the checkpoint was saved from; load_state_dict() then copies the values into
# the existing module.
pipe.text_encoder_brushnet.load_state_dict(
    torch.load(f"{base_path}/PowerPaint_Brushnet/pytorch_model.bin", map_location="cpu"), strict=False
)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

# Not read anywhere in the visible code; kept as module state.
current_control = "canny"
# controlnet_conditioning_scale = 0.8 | |
def set_seed(seed):
    """Apply ``seed`` to every random number generator this app touches.

    The seeders are called in a fixed order: PyTorch (``manual_seed``,
    ``cuda.manual_seed``, ``cuda.manual_seed_all``), then NumPy's global
    generator, then Python's ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def add_task(control_type):
    """Return the task placeholder prompts for ``control_type``.

    Returns:
        A 4-tuple ``(promptA, promptB, negative_promptA, negative_promptB)``.
        Any task name that is not listed below gets ``"P_obj"`` in all four
        positions.
    """
    prompts_by_task = (
        ("object-removal", ("P_ctxt", "P_ctxt", "P_obj", "P_obj")),
        ("context-aware", ("P_ctxt", "P_ctxt", "", "")),
        ("shape-guided", ("P_shape", "P_ctxt", "P_shape", "P_ctxt")),
        ("image-outpainting", ("P_ctxt", "P_ctxt", "P_obj", "P_obj")),
    )
    # Linear scan with == keeps the comparison semantics of an if/elif chain.
    for task_name, prompts in prompts_by_task:
        if control_type == task_name:
            return prompts
    return ("P_obj", "P_obj", "P_obj", "P_obj")
def _resize_short_side(image, target):
    """Return ``image`` converted to RGB and scaled so its shorter side is ``target`` px."""
    rgb = image.convert("RGB")
    width, height = rgb.size
    if width < height:
        return rgb.resize((target, int(height / width * target)))
    return rgb.resize((int(width / height * target), target))


def _expand_canvas(image, vertical_expansion_ratio, horizontal_expansion_ratio):
    """Centre ``image`` on an enlarged gray canvas and build the matching outpainting mask.

    Returns:
        ``(canvas, mask)`` as PIL images of identical size.  In the mask, 255
        marks pixels to generate and 0 marks pixels to keep.
    """
    blurry_gap = 10  # px of the original border that are also regenerated

    rgb = image.convert("RGB")
    o_w, o_h = rgb.size
    c_w = int(horizontal_expansion_ratio * o_w)
    c_h = int(vertical_expansion_ratio * o_h)
    top = int((c_h - o_h) / 2.0)
    left = int((c_w - o_w) / 2.0)

    canvas = np.full((c_h, c_w, 3), 127, dtype=np.uint8)
    canvas[top : top + o_h, left : left + o_w, :] = np.array(rgb)

    mask = np.full((c_h, c_w, 3), 255, dtype=np.uint8)
    # The kept region is shrunk by blurry_gap only along the axes that were
    # actually expanded.
    gap_y = blurry_gap if vertical_expansion_ratio != 1 else 0
    gap_x = blurry_gap if horizontal_expansion_ratio != 1 else 0
    # NOTE(review): when both ratios are exactly 1 no region is kept, so the
    # whole frame is marked for regeneration -- confirm this is intended.
    if vertical_expansion_ratio != 1 or horizontal_expansion_ratio != 1:
        mask[top + gap_y : top + o_h - gap_y, left + gap_x : left + o_w - gap_x, :] = 0

    return Image.fromarray(canvas), Image.fromarray(mask)


def predict(
    input_image,
    prompt,
    fitting_degree,
    ddim_steps,
    scale,
    seed,
    negative_prompt,
    task,
    vertical_expansion_ratio,
    horizontal_expansion_ratio,
):
    """Run one PowerPaint inpainting / outpainting request.

    Args:
        input_image: Mapping with PIL images under ``"image"`` and ``"mask"``
            (the mask entry is not read when both expansion ratios are given).
            The mapping itself is not modified.
        prompt: User prompt; a task-specific suffix is appended for the
            outpainting, context-aware and object-removal tasks.
        fitting_degree: Passed to the pipeline as ``tradoff`` and
            ``tradoff_nag``.
        ddim_steps: Number of inference steps.
        scale: Guidance scale.
        seed: Seed applied via ``set_seed`` and to the CUDA generator.
        negative_prompt: User negative prompt.
        task: Task name; selects the working resolution, the prompt suffix
            and the placeholder prompts returned by ``add_task``.
        vertical_expansion_ratio: Canvas height factor, or ``None``.
        horizontal_expansion_ratio: Canvas width factor, or ``None``.  The
            canvas is only expanded when both ratios are not ``None``.

    Returns:
        ``([result], [mask, overlay])`` -- the generated image, the RGB mask
        that was used, and the result with the masked area tinted red.

    Raises:
        gr.Error: If no input image was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before running.")

    # Outpainting starts from a smaller image because the canvas grows below.
    short_side = 512 if task == "image-outpainting" else 640
    image = _resize_short_side(input_image["image"], short_side)

    if task in ("image-outpainting", "context-aware"):
        prompt = prompt + " empty scene"
    if task == "object-removal":
        prompt = prompt + " empty scene blur"

    if vertical_expansion_ratio is not None and horizontal_expansion_ratio is not None:
        image, mask = _expand_canvas(image, vertical_expansion_ratio, horizontal_expansion_ratio)
    else:
        mask = input_image["mask"]

    promptA, promptB, negative_promptA, negative_promptB = add_task(task)

    # Round both sides down to a multiple of 8 (presumably a size requirement
    # of the diffusion pipeline -- TODO confirm).
    width, height = image.size
    width -= width % 8
    height -= height % 8
    image = image.resize((width, height))
    # Normalising the mask to RGB keeps the array maths below valid for masks
    # delivered in another mode.
    mask = mask.resize((width, height)).convert("RGB")
    mask_np = np.array(mask)

    # Black out the region to be generated before handing the image over.
    masked_np = np.array(image) * (1 - mask_np / 255.0)
    masked_image = Image.fromarray(masked_np.astype(np.uint8)).convert("RGB")

    # Both values arrive from UI sliders; make sure they are plain ints.
    seed = int(seed)
    steps = int(ddim_steps)
    set_seed(seed)

    # Keyword names (including the `tradoff` spelling) are the ones this
    # pipeline is called with throughout the file.
    result = pipe(
        promptA=promptA,
        promptB=promptB,
        promptU=prompt,
        tradoff=fitting_degree,
        tradoff_nag=fitting_degree,
        image=masked_image,
        mask=mask,
        num_inference_steps=steps,
        generator=torch.Generator("cuda").manual_seed(seed),
        brushnet_conditioning_scale=1.0,
        negative_promptA=negative_promptA,
        negative_promptB=negative_promptB,
        negative_promptU=negative_prompt,
        guidance_scale=scale,
        width=width,
        height=height,
    ).images[0]

    # Preview: blend a flat red (180, 0, 0) layer over the masked area.
    # Dividing by 512 instead of 255 caps the tint at roughly half strength.
    result_np = np.array(result).astype("float")
    red = result_np.copy()
    red[:, :, 0] = 180.0
    red[:, :, 1] = 0
    red[:, :, 2] = 0
    weight = mask_np.astype("float") / 512.0
    result_m = Image.fromarray((result_np * (1 - weight) + weight * red).astype("uint8"))

    return [result], [mask, result_m]
def infer(
    input_image,
    text_guided_prompt,
    text_guided_negative_prompt,
    shape_guided_prompt,
    shape_guided_negative_prompt,
    fitting_degree,
    ddim_steps,
    scale,
    seed,
    task,
    vertical_expansion_ratio,
    horizontal_expansion_ratio,
    outpaint_prompt,
    outpaint_negative_prompt,
    removal_prompt,
    removal_negative_prompt,
    context_prompt,
    context_negative_prompt,
):
    """Pick the prompt pair that belongs to ``task`` and forward to ``predict``.

    The expansion ratios are only forwarded for ``"image-outpainting"``; every
    other task passes ``None`` for both.  A task name that is not recognised
    (including ``None``) is handled as ``"text-guided"``.

    Returns:
        Whatever ``predict`` returns.
    """
    prompt_table = (
        ("text-guided", text_guided_prompt, text_guided_negative_prompt),
        ("shape-guided", shape_guided_prompt, shape_guided_negative_prompt),
        ("object-removal", removal_prompt, removal_negative_prompt),
        ("context-aware", context_prompt, context_negative_prompt),
        ("image-outpainting", outpaint_prompt, outpaint_negative_prompt),
    )
    for known_task, prompt, negative_prompt in prompt_table:
        if task == known_task:
            break
    else:
        # No entry matched: fall back to the text-guided task.
        task = "text-guided"
        prompt = text_guided_prompt
        negative_prompt = text_guided_negative_prompt

    if task == "image-outpainting":
        ratios = (vertical_expansion_ratio, horizontal_expansion_ratio)
    else:
        ratios = (None, None)

    return predict(input_image, prompt, fitting_degree, ddim_steps, scale, seed, negative_prompt, task, *ratios)
def select_tab_text_guided():
    """Return the task name for the text-guided object inpainting tab."""
    task_name = "text-guided"
    return task_name
def select_tab_object_removal():
    """Return the task name for the object removal inpainting tab."""
    task_name = "object-removal"
    return task_name
def select_tab_context_aware():
    """Return the task name for context-aware inpainting.

    No tab in the visible UI code is wired to this callback.
    """
    task_name = "context-aware"
    return task_name
def select_tab_image_outpainting():
    """Return the task name for the image outpainting tab."""
    task_name = "image-outpainting"
    return task_name
def select_tab_shape_guided():
    """Return the task name for the shape-guided object inpainting tab."""
    task_name = "shape-guided"
    return task_name
# ---------------------------------------------------------------------------
# Gradio UI.  Components are laid out in the order they are created inside
# the nested context managers, so statement order here is significant.
# ---------------------------------------------------------------------------
with gr.Blocks(css="style.css") as demo:
    # Page header: title, project links, usage note.
    with gr.Row():
        gr.Markdown(
            "<div align='center'><font size='18'>PowerPaint: High-Quality Versatile Image Inpainting</font></div>"  # noqa
        )
    with gr.Row():
        gr.Markdown(
            "<div align='center'><font size='5'><a href='https://powerpaint.github.io/'>Project Page</a>  "  # noqa
            "<a href='https://arxiv.org/abs/2312.03594/'>Paper</a>  "
            "<a href='https://github.com/zhuang2002/PowerPaint'>Code</a> </font></div>"  # noqa
        )
    with gr.Row():
        gr.Markdown(
            "**Note:** Due to network-related factors, the page may experience occasional bugs! If the inpainting results deviate significantly from expectations, consider toggling between task options to refresh the content."  # noqa
        )
    # Attention: Due to network-related factors, the page may experience occasional bugs. If the inpainting results deviate significantly from expectations, consider toggling between task options to refresh the content.
    with gr.Row():
        # Left column: input image, task tabs, run button, advanced options.
        with gr.Column():
            gr.Markdown("### Input image and draw mask")
            input_image = gr.Image(source="upload", tool="sketch", type="pil")

            # Hidden radio holding the current task name.  It has no initial
            # value and is only written by the tab `select` callbacks below;
            # infer() treats an unset value as "text-guided".
            task = gr.Radio(
                ["text-guided", "object-removal", "shape-guided", "image-outpainting"], show_label=False, visible=False
            )

            # Text-guided object inpainting
            with gr.Tab("Text-guided object inpainting") as tab_text_guided:
                enable_text_guided = gr.Checkbox(
                    label="Enable text-guided object inpainting", value=True, interactive=False
                )
                text_guided_prompt = gr.Textbox(label="Prompt")
                text_guided_negative_prompt = gr.Textbox(label="negative_prompt")
                tab_text_guided.select(fn=select_tab_text_guided, inputs=None, outputs=task)

            # Object removal inpainting
            with gr.Tab("Object removal inpainting") as tab_object_removal:
                # The `info` literal is continued with backslashes, so the
                # continuation lines must stay unindented: any leading
                # whitespace would become part of the displayed text.
                enable_object_removal = gr.Checkbox(
                    label="Enable object removal inpainting",
                    value=True,
                    info="The recommended configuration for the Guidance Scale is 10 or higher. \
If undesired objects appear in the masked area, \
you can address this by specifically increasing the Guidance Scale.",
                    interactive=False,
                )
                removal_prompt = gr.Textbox(label="Prompt")
                removal_negative_prompt = gr.Textbox(label="negative_prompt")
                # There are no separate context-aware textboxes: the
                # context_* names are aliases of the removal textboxes.
                context_prompt = removal_prompt
                context_negative_prompt = removal_negative_prompt
                tab_object_removal.select(fn=select_tab_object_removal, inputs=None, outputs=task)

            # Object image outpainting
            with gr.Tab("Image outpainting") as tab_image_outpainting:
                # NOTE(review): this rebinds `enable_object_removal`, which
                # was assigned in the previous tab; neither checkbox variable
                # is read later in this file.
                # Same backslash-continued `info` literal as above: keep the
                # continuation lines unindented.
                enable_object_removal = gr.Checkbox(
                    label="Enable image outpainting",
                    value=True,
                    info="The recommended configuration for the Guidance Scale is 15 or higher. \
If unwanted random objects appear in the extended image region, \
you can enhance the cleanliness of the extension area by increasing the Guidance Scale.",
                    interactive=False,
                )
                outpaint_prompt = gr.Textbox(label="Outpainting_prompt")
                outpaint_negative_prompt = gr.Textbox(label="Outpainting_negative_prompt")
                # Both ratios default to 1, i.e. no expansion.
                horizontal_expansion_ratio = gr.Slider(
                    label="horizontal expansion ratio",
                    minimum=1,
                    maximum=4,
                    step=0.05,
                    value=1,
                )
                vertical_expansion_ratio = gr.Slider(
                    label="vertical expansion ratio",
                    minimum=1,
                    maximum=4,
                    step=0.05,
                    value=1,
                )
                tab_image_outpainting.select(fn=select_tab_image_outpainting, inputs=None, outputs=task)

            # Shape-guided object inpainting
            with gr.Tab("Shape-guided object inpainting") as tab_shape_guided:
                enable_shape_guided = gr.Checkbox(
                    label="Enable shape-guided object inpainting", value=True, interactive=False
                )
                shape_guided_prompt = gr.Textbox(label="shape_guided_prompt")
                shape_guided_negative_prompt = gr.Textbox(label="shape_guided_negative_prompt")
                # Sent to infer() for every task, not only the shape-guided one.
                fitting_degree = gr.Slider(
                    label="fitting degree",
                    minimum=0.3,
                    maximum=1,
                    step=0.05,
                    value=1,
                )
                tab_shape_guided.select(fn=select_tab_shape_guided, inputs=None, outputs=task)

            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=50, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=45.0, value=12, step=0.1)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        # Right column: generated image and mask previews.
        with gr.Column():
            gr.Markdown("### Inpainting result")
            inpaint_result = gr.Gallery(label="Generated images", show_label=False, columns=2)
            gr.Markdown("### Mask")
            gallery = gr.Gallery(label="Generated masks", show_label=False, columns=2)

    # The inputs are passed positionally, so this list must follow the
    # parameter order of infer().  The removal textboxes appear twice because
    # context_prompt / context_negative_prompt alias them.
    run_button.click(
        fn=infer,
        inputs=[
            input_image,
            text_guided_prompt,
            text_guided_negative_prompt,
            shape_guided_prompt,
            shape_guided_negative_prompt,
            fitting_degree,
            ddim_steps,
            scale,
            seed,
            task,
            vertical_expansion_ratio,
            horizontal_expansion_ratio,
            outpaint_prompt,
            outpaint_negative_prompt,
            removal_prompt,
            removal_negative_prompt,
            context_prompt,
            context_negative_prompt,
        ],
        outputs=[inpaint_result, gallery],
    )

# Enable the request queue, then serve on 0.0.0.0:7860 without a share link.
demo.queue()
demo.launch(share=False, server_name="0.0.0.0", server_port=7860)