Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,480 Bytes
edb0494 6405936 edb0494 e07d759 edb0494 6405936 e07d759 6405936 e07d759 6405936 e07d759 6405936 e07d759 6405936 80401cc 6405936 8de6c5c 6405936 e07d759 6405936 e07d759 6405936 432e607 6405936 432e607 6405936 e07d759 6405936 97567b1 6405936 97567b1 6405936 a468e28 6405936 e07d759 6405936 97567b1 6405936 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, TCDScheduler
def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs, *, cutoff_fraction=0.2):
    """Switch classifier-free guidance off part-way through denoising.

    Meant to be passed as ``callback_on_step_end``. At the step whose index
    equals ``int(pipeline.num_timesteps * cutoff_fraction)`` it drops the
    unconditional half of every CFG-batched tensor and sets the pipeline's
    guidance scale to 0, so the remaining steps run a single conditional pass.

    Args:
        pipeline: The running pipeline. ``num_timesteps`` and
            ``_guidance_scale`` are read; ``_guidance_scale`` is overwritten.
        step_index: Index of the step that just finished.
        timestep: Unused; required by the callback signature.
        callback_kwargs: Tensors requested via
            ``callback_on_step_end_tensor_inputs``. Must contain
            ``prompt_embeds``, ``add_text_embeds``, ``add_time_ids``,
            ``control_type`` and ``control_image`` (a list of tensors).
        cutoff_fraction: Fraction of the schedule after which guidance is
            dropped. Defaults to 0.2 (the previously hard-coded value).

    Returns:
        ``callback_kwargs``, updated in place.
    """
    if step_index != int(pipeline.num_timesteps * cutoff_fraction):
        return callback_kwargs
    # NOTE(review): SDXL-style pipelines only double the batch for CFG when
    # guidance_scale > 1; assumed to hold for the custom pipeline as well.
    # Without a doubled batch there is no unconditional half to drop.
    if pipeline._guidance_scale <= 1:
        return callback_kwargs

    # The CFG batch is assumed to be [unconditional, conditional] along dim 0
    # (diffusers convention). Keep the whole conditional half, not just the
    # last sample, so num_images_per_prompt > 1 keeps every image.
    for key in ("prompt_embeds", "add_text_embeds", "add_time_ids", "control_type"):
        callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]

    # control_image is a list (one tensor per control input). Slice every
    # entry, and do it in place so the change is visible whether or not the
    # pipeline reads control_image back from the returned dict.
    control_image = callback_kwargs["control_image"]
    control_image[:] = [tensor.chunk(2)[-1] for tensor in control_image]
    callback_kwargs["control_image"] = control_image

    pipeline._guidance_scale = 0.0
    return callback_kwargs
# Label shown in the "Model" dropdown -> Hugging Face Hub repository id.
MODELS = {"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning"}
# Load every model once at import time, in float16, on the GPU.
controlnet_model = ControlNetUnionModel.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
)
controlnet_model.to(device="cuda", dtype=torch.float16)

# VAE from madebyollin/sdxl-vae-fp16-fix, used in place of the checkpoint's own.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

# Resolve the checkpoint through MODELS so the dropdown entry and the loaded
# weights share one source of truth instead of a duplicated repository id.
pipe = DiffusionPipeline.from_pretrained(
    MODELS["RealVisXL V5.0 Lightning"],
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet_model,
    custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
).to("cuda")

# Replace the checkpoint's scheduler with TCD, reusing its configuration.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# The app takes no user prompt: fill_image reuses these embeddings of one
# fixed prompt on every request, so they are encoded once at import time.
prompt = "high quality"

# Nothing here needs gradients. Encoding under no_grad keeps these
# process-lifetime tensors free of autograd history (and of the text-encoder
# activations that history would retain).
with torch.no_grad():
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt, device="cuda")
@spaces.GPU(duration=16)
def fill_image(image, model_selection):
    """Regenerate the masked region of the uploaded image.

    Args:
        image: Value of the ``gr.ImageMask`` input: a dict whose
            ``"background"`` is the uploaded PIL image and whose ``"layers"``
            list holds the drawn layer(s). Only the alpha channel of the first
            layer is used as the mask.
        model_selection: Value of the "Model" dropdown. Not used here; the
            module-level ``pipe`` is always used.

    Yields:
        ``(source, result)`` for the image slider: the unmodified upload and a
        copy whose masked pixels are replaced by the generated ones.

    Raises:
        gr.Error: If no image was uploaded or no mask was drawn.
    """
    if not image or image.get("background") is None:
        raise gr.Error("Please upload an image first.")
    source = image["background"]
    layers = image.get("layers") or []
    if not layers:
        raise gr.Error("Please draw a mask over the area you want to fill.")

    # Every pixel the user painted (alpha > 0) becomes part of the mask.
    alpha_channel = layers[0].getchannel("A")
    binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)
    if binary_mask.getbbox() is None:
        # An all-zero mask would run the whole pipeline only to paste nothing back.
        raise gr.Error("Please draw a mask over the area you want to fill.")

    # Control image: the source with the masked pixels set to 0 (black).
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    generated = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        control_image=[cnet_image],
        controlnet_conditioning_scale=[1.0],
        control_mode=[7],
        num_inference_steps=8,
        guidance_scale=1.5,
        callback_on_step_end=callback_cfg_cutoff,
        callback_on_step_end_tensor_inputs=[
            "prompt_embeds",
            "add_text_embeds",
            "add_time_ids",
            "control_image",
            "control_type",
        ],
    ).images[0]

    # NOTE(review): the pipeline may return a slightly different size (e.g.
    # dimensions rounded down to a multiple of the VAE scale factor); paste()
    # below needs the pasted image to match the mask's size.
    if generated.size != source.size:
        generated = generated.resize(source.size)

    # Keep the original pixels outside the mask and the generated ones inside it.
    cnet_image.paste(generated.convert("RGBA"), (0, 0), binary_mask)
    yield source, cnet_image
def clear_result():
    """Build a Gradio update that empties the output component.

    Chained before ``fill_image`` so the previous result disappears as soon
    as a new run starts.
    """
    empty_value = None
    return gr.update(value=empty_value)
# HTML header rendered at the top of the app by gr.HTML(title): the heading,
# a usage hint, and links to the related guide and the prompt-based space.
title = """<h2 align="center">Diffusers Image Fill</h2>
<div align="center">Draw the mask over the subject you want to erase or change.</div>
<div align="center">
This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-image-fill'>Diffusers Image Fill</a>.
If you need a space where you can use prompts, please go to the <a href='https://huggingface.co/spaces/OzzyGT/diffusers-fast-inpaint'>Diffusers Fast Inpaint</a> space.
</div>
"""
# Build the UI: header, Generate button, mask editor next to a before/after
# slider, and the model dropdown, then wire the button to the fill pipeline.
with gr.Blocks() as demo:
    gr.HTML(title)
    run_button = gr.Button("Generate")

    # One row: the mask editor first, then the result slider.
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            sources=["upload"],
            height=512,
        )
        result = gr.ImageSlider(interactive=False, label="Generated Image")

    model_selection = gr.Dropdown(
        choices=[*MODELS],
        value="RealVisXL V5.0 Lightning",
        label="Model",
    )

    # First empty the slider, then run the fill; .then() keeps the two
    # handlers in that order.
    clear_event = run_button.click(fn=clear_result, inputs=None, outputs=result)
    clear_event.then(
        fn=fill_image,
        inputs=[input_image, model_selection],
        outputs=result,
    )

demo.launch(share=False)
|