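# Hugging Face Space "Diffusers Image Fill": draw a mask over a region to erase or
# change it; SDXL (RealVisXL V5.0 Lightning) + ControlNet Union fill the masked area.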
import gradio as gr
import spaces
import torch

from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, TCDScheduler


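# Callback that cuts classifier-free guidance off after roughly 20% of the denoising
# steps: the unconditional half of each batched input is dropped and guidance_scale
# is set to 0, so the remaining steps run on the conditional batch only.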
def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs):
    if step_index == int(pipeline.num_timesteps * 0.2):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        prompt_embeds = prompt_embeds[-1:]

        add_text_embeds = callback_kwargs["add_text_embeds"]
        add_text_embeds = add_text_embeds[-1:]

        add_time_ids = callback_kwargs["add_time_ids"]
        add_time_ids = add_time_ids[-1:]

        control_image = callback_kwargs["control_image"]
        control_image[0] = control_image[0][-1:]

        control_type = callback_kwargs["control_type"]
        control_type = control_type[-1:]

        pipeline._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
        callback_kwargs["add_text_embeds"] = add_text_embeds
        callback_kwargs["add_time_ids"] = add_time_ids
        callback_kwargs["control_image"] = control_image
        callback_kwargs["control_type"] = control_type

    return callback_kwargs


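# Models exposed in the UI dropdown; only the default Lightning checkpoint is
# actually loaded below, so the selection is currently informational.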
MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}

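# ControlNet Union (promax) for SDXL and the fp16-safe SDXL VAE, both in half precision on GPU.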
controlnet_model = ControlNetUnionModel.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
)
controlnet_model.to(device="cuda", dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")

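# SDXL pipeline with the custom ControlNet Union community pipeline; the TCD
# scheduler is used for the few-step sampling of the Lightning checkpoint.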
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet_model,
    custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
).to("cuda")

pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

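# Encode the fixed prompt once at startup so the embeddings can be reused for every request.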
prompt = "high quality"
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt, device="cuda")


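# Main handler: builds the control image from the user's mask and runs an 8-step fill.
# model_selection is part of the Gradio inputs but is unused, since a single model is loaded.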
@spaces.GPU(duration=16)
def fill_image(image, model_selection):
    source = image["background"]
    mask = image["layers"][0]

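    # Turn the drawn layer's alpha channel into a hard 0/255 mask and black out the
    # masked region in a copy of the source image; this copy is the control image.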
    alpha_channel = mask.split()[3]
    binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

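    # Run the pipeline with the precomputed embeddings; callback_cfg_cutoff drops CFG
    # partway through, and the listed tensor inputs are forwarded to it at each step.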
    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        control_image=[cnet_image],
        controlnet_conditioning_scale=[1.0],
        control_mode=[7],
        num_inference_steps=8,
        guidance_scale=1.5,
        callback_on_step_end=callback_cfg_cutoff,
        callback_on_step_end_tensor_inputs=[
            "prompt_embeds",
            "add_text_embeds",
            "add_time_ids",
            "control_image",
            "control_type",
        ],
    ).images[0]

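    # Paste the generated pixels back through the mask so only the masked area changes.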
    image = image.convert("RGBA")
    cnet_image.paste(image, (0, 0), binary_mask)

    yield source, cnet_image


def clear_result():
    return gr.update(value=None)


title = """<h2 align="center">Diffusers Image Fill</h2>
<div align="center">Draw the mask over the subject you want to erase or change.</div>
<div align="center">
    This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-image-fill'>Diffusers Image Fill</a>. 
    If you need a space where you can use prompts, please go to the <a href='https://huggingface.co/spaces/OzzyGT/diffusers-fast-inpaint'>Diffusers Fast Inpaint</a> space.
</div>
"""

with gr.Blocks() as demo:
    gr.HTML(title)

    run_button = gr.Button("Generate")

    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            sources=["upload"],
            height=512,
        )

        result = gr.ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    model_selection = gr.Dropdown(
        choices=list(MODELS.keys()),
        value="RealVisXL V5.0 Lightning",
        label="Model",
    )

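    # Clear the previous result first, then stream (source, filled image) into the slider.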
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=fill_image,
        inputs=[input_image, model_selection],
        outputs=result,
    )


demo.launch(share=False)