File size: 3,265 Bytes
8f570a9
 
0d89801
7e66050
 
 
0d89801
 
 
 
031c42b
83cae6c
 
 
0d89801
 
f0e8d1f
0d89801
 
f0e8d1f
0d89801
 
8f570a9
 
 
 
 
 
 
 
 
 
 
 
 
 
bca1af7
 
8f570a9
 
 
 
 
 
 
3e075bb
0d89801
 
 
 
 
 
7e66050
0d89801
 
 
 
 
7e66050
0d89801
 
 
8f570a9
7e66050
 
867296e
0d89801
 
7e66050
 
867296e
 
8f570a9
 
0d89801
 
 
 
 
 
 
 
 
 
 
 
 
3e075bb
 
 
 
 
 
 
 
0d89801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Tuple

import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import FluxInpaintPipeline

# Introductory text rendered at the top of the demo page.
MARKDOWN = """
# FLUX.1 Inpainting 🔥

Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for 
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos) 
for taking it to the next level by enabling inpainting with the FLUX.
"""

# Use the GPU when torch reports one as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the FLUX.1-schnell inpainting pipeline in bfloat16, then move it to
# the selected device.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to(DEVICE)


def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = 2048
) -> Tuple[int, int]:
    """Scale a resolution so that its longer side equals ``maximum_dimension``.

    The aspect ratio is preserved up to rounding. Smaller inputs are scaled
    up and larger inputs are scaled down. Each side is then rounded down to a
    multiple of 32 and clamped to the range ``[32, maximum_dimension]``.

    Args:
        original_resolution_wh: Original ``(width, height)`` in pixels.
        maximum_dimension: Target length of the longer side, in pixels.
            Values below 32 are not meaningful because each side is kept at
            a minimum of 32.

    Returns:
        The new ``(width, height)``.

    Raises:
        ValueError: If the original width or height is not positive.
    """
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {width}x{height}."
        )

    longest_side = max(width, height)

    # Exact integer arithmetic. The float form ``int(side * (max / longest))``
    # can land just below a whole number (e.g. 49 * (2048 / 49) evaluates to
    # 2047.9999999999998), which silently drops the side a full 32-pixel step.
    new_width = width * maximum_dimension // longest_side
    new_height = height * maximum_dimension // longest_side

    # Round each side down to a multiple of 32.
    new_width -= new_width % 32
    new_height -= new_height % 32

    # With an extreme aspect ratio the short side can round down to 0, which
    # is not a usable image size; keep at least one 32-pixel step.
    new_width = max(32, min(maximum_dimension, new_width))
    new_height = max(32, min(maximum_dimension, new_height))

    return new_width, new_height


@spaces.GPU()
def process(input_image_editor, input_text, progress=gr.Progress(track_tqdm=True)):
    """Inpaint the masked region of the uploaded image according to a prompt.

    Args:
        input_image_editor: Value of the image editor component. Its
            ``'background'`` entry is used as the source image and the first
            entry of ``'layers'`` as the mask.
        input_text: Text prompt passed to the pipeline.
        progress: Gradio progress tracker. It is not referenced in the body;
            presumably declaring it lets Gradio mirror tqdm progress in the
            UI (confirm against the Gradio docs).

    Returns:
        The first image produced by the pipeline, or ``None`` (after showing
        an info message) when the prompt, the image or the mask is missing.
    """
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None

    # Validate the image before touching the layers: indexing ``layers[0]``
    # first would raise on an empty editor instead of showing the message.
    editor_value = input_image_editor or {}
    image = editor_value.get('background')
    if image is None:
        gr.Info("Please upload an image.")
        return None

    layers = editor_value.get('layers') or []
    mask = layers[0] if layers else None
    if mask is None:
        gr.Info("Please draw a mask on the image.")
        return None

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)
    # Nearest-neighbour resampling so resizing does not blend mask values.
    resized_mask = mask.resize((width, height), Image.NEAREST)

    return pipe(
        prompt=input_text,
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=0.7,
        num_inference_steps=2
    ).images[0]


# Build the demo page: intro text, an input column (image editor, prompt,
# submit button) and an output column showing the generated image.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            # Single fixed white brush used to paint the mask.
            mask_brush = gr.Brush(colors=["#FFFFFF"], color_mode="fixed")
            input_image_editor_component = gr.ImageEditor(
                label="Image",
                type="pil",
                sources=["upload", "webcam"],
                image_mode="RGB",
                layers=False,
                brush=mask_brush,
            )
            input_text_component = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            submit_button_component = gr.Button(
                value="Submit",
                variant="primary",
            )

        with gr.Column():
            output_image_component = gr.Image(
                type="pil",
                image_mode="RGB",
                label="Generated image",
            )

    # Clicking the button runs `process` on the editor value and the prompt,
    # and shows its result in the output image.
    submit_button_component.click(
        fn=process,
        inputs=[input_image_editor_component, input_text_component],
        outputs=[output_image_component],
    )

demo.launch(debug=False, show_error=True)