not-lain commited on
Commit
aa16383
·
1 Parent(s): abb6b39

wip outpainting

Browse files
Files changed (2) hide show
  1. app.py +194 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -4,6 +4,9 @@ import torch
4
  from loadimg import load_img
5
  from torchvision import transforms
6
  from transformers import AutoModelForImageSegmentation
 
 
 
7
 
8
  torch.set_float32_matmul_precision(["high", "highest"][0])
9
 
@@ -20,10 +23,186 @@ transform_image = transforms.Compose(
20
  ]
21
  )
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @spaces.GPU
25
- def rmbg(image,url):
26
- if image is None :
27
  image = url
28
  image = load_img(image).convert("RGB")
29
  image_size = image.size
@@ -38,11 +217,21 @@ def rmbg(image,url):
38
  return image
39
 
40
 
41
- rmbg_tab = gr.Interface(fn=rmbg, inputs=["image","text"], outputs=["image"], api_name="rmbg")
 
 
 
 
 
 
 
 
 
 
42
 
43
  demo = gr.TabbedInterface(
44
- [rmbg_tab],
45
- ["remove background"],
46
  title="Utilities that require GPU",
47
  )
48
 
 
4
  from loadimg import load_img
5
  from torchvision import transforms
6
  from transformers import AutoModelForImageSegmentation
7
+ from diffusers import FluxFillPipeline
8
+ from PIL import Image, ImageDraw
9
+ from diffusers.utils import load_image
10
 
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
 
 
23
  ]
24
  )
25
 
26
+ pipe = FluxFillPipeline.from_pretrained(
27
+ "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
28
+ ).to("cuda")
29
+
30
+
31
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
32
+ if alignment in ("Left", "Right") and source_width >= target_width:
33
+ return False
34
+ if alignment in ("Top", "Bottom") and source_height >= target_height:
35
+ return False
36
+ return True
37
+
38
+
39
+ def prepare_image_and_mask(
40
+ image,
41
+ width,
42
+ height,
43
+ overlap_percentage,
44
+ resize_percentage,
45
+ alignment,
46
+ overlap_left,
47
+ overlap_right,
48
+ overlap_top,
49
+ overlap_bottom,
50
+ ):
51
+ target_size = (width, height)
52
+
53
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
54
+ new_width = int(image.width * scale_factor)
55
+ new_height = int(image.height * scale_factor)
56
+
57
+ source = image.resize((new_width, new_height), Image.LANCZOS)
58
+
59
+ resize_percentage = 50
60
+
61
+ # Calculate new dimensions based on percentage
62
+ resize_factor = resize_percentage / 100
63
+ new_width = int(source.width * resize_factor)
64
+ new_height = int(source.height * resize_factor)
65
+
66
+ # Ensure minimum size of 64 pixels
67
+ new_width = max(new_width, 64)
68
+ new_height = max(new_height, 64)
69
+
70
+ # Resize the image
71
+ source = source.resize((new_width, new_height), Image.LANCZOS)
72
+
73
+ # Calculate the overlap in pixels based on the percentage
74
+ overlap_x = int(new_width * (overlap_percentage / 100))
75
+ overlap_y = int(new_height * (overlap_percentage / 100))
76
+
77
+ # Ensure minimum overlap of 1 pixel
78
+ overlap_x = max(overlap_x, 1)
79
+ overlap_y = max(overlap_y, 1)
80
+
81
+ # Calculate margins based on alignment
82
+ if alignment == "Middle":
83
+ margin_x = (target_size[0] - new_width) // 2
84
+ margin_y = (target_size[1] - new_height) // 2
85
+ elif alignment == "Left":
86
+ margin_x = 0
87
+ margin_y = (target_size[1] - new_height) // 2
88
+ elif alignment == "Right":
89
+ margin_x = target_size[0] - new_width
90
+ margin_y = (target_size[1] - new_height) // 2
91
+ elif alignment == "Top":
92
+ margin_x = (target_size[0] - new_width) // 2
93
+ margin_y = 0
94
+ elif alignment == "Bottom":
95
+ margin_x = (target_size[0] - new_width) // 2
96
+ margin_y = target_size[1] - new_height
97
+
98
+ # Adjust margins to eliminate gaps
99
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
100
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
101
+
102
+ # Create a new background image and paste the resized source image
103
+ background = Image.new("RGB", target_size, (255, 255, 255))
104
+ background.paste(source, (margin_x, margin_y))
105
+
106
+ # Create the mask
107
+ mask = Image.new("L", target_size, 255)
108
+ mask_draw = ImageDraw.Draw(mask)
109
+
110
+ # Calculate overlap areas
111
+ white_gaps_patch = 2
112
+
113
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
114
+ right_overlap = (
115
+ margin_x + new_width - overlap_x
116
+ if overlap_right
117
+ else margin_x + new_width - white_gaps_patch
118
+ )
119
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
120
+ bottom_overlap = (
121
+ margin_y + new_height - overlap_y
122
+ if overlap_bottom
123
+ else margin_y + new_height - white_gaps_patch
124
+ )
125
+
126
+ if alignment == "Left":
127
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
128
+ elif alignment == "Right":
129
+ right_overlap = (
130
+ margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
131
+ )
132
+ elif alignment == "Top":
133
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
134
+ elif alignment == "Bottom":
135
+ bottom_overlap = (
136
+ margin_y + new_height - overlap_y
137
+ if overlap_bottom
138
+ else margin_y + new_height
139
+ )
140
+
141
+ # Draw the mask
142
+ mask_draw.rectangle(
143
+ [(left_overlap, top_overlap), (right_overlap, bottom_overlap)], fill=0
144
+ )
145
+
146
+ return background, mask
147
+
148
+
149
+ def inpaint(
150
+ image,
151
+ width,
152
+ height,
153
+ overlap_percentage,
154
+ num_inference_steps,
155
+ custom_resize_percentage,
156
+ prompt_input,
157
+ alignment,
158
+ overlap_left,
159
+ overlap_right,
160
+ overlap_top,
161
+ overlap_bottom,
162
+ progress=gr.Progress(track_tqdm=True),
163
+ ):
164
+ background, mask = prepare_image_and_mask(
165
+ image,
166
+ width,
167
+ height,
168
+ overlap_percentage,
169
+ custom_resize_percentage,
170
+ alignment,
171
+ overlap_left,
172
+ overlap_right,
173
+ overlap_top,
174
+ overlap_bottom,
175
+ )
176
+
177
+ if not can_expand(background.width, background.height, width, height, alignment):
178
+ alignment = "Middle"
179
+
180
+ cnet_image = background.copy()
181
+ cnet_image.paste(0, (0, 0), mask)
182
+
183
+ final_prompt = prompt_input
184
+
185
+ # generator = torch.Generator(device="cuda").manual_seed(42)
186
+
187
+ result = pipe(
188
+ prompt=final_prompt,
189
+ height=height,
190
+ width=width,
191
+ image=cnet_image,
192
+ mask_image=mask,
193
+ num_inference_steps=num_inference_steps,
194
+ guidance_scale=30,
195
+ ).images[0]
196
+
197
+ result = result.convert("RGBA")
198
+ cnet_image.paste(result, (0, 0), mask)
199
+
200
+ return cnet_image
201
+
202
 
203
  @spaces.GPU
204
+ def rmbg(image, url):
205
+ if image is None:
206
  image = url
207
  image = load_img(image).convert("RGB")
208
  image_size = image.size
 
217
  return image
218
 
219
 
220
+ def placeholder(img):
221
+ return img
222
+
223
+
224
+ rmbg_tab = gr.Interface(
225
+ fn=rmbg, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
226
+ )
227
+
228
+ outpaint_tab = gr.Interface(
229
+ fr=placeholder, inputs=["image"], outputs=["image"], api_name="outpainting"
230
+ )
231
 
232
  demo = gr.TabbedInterface(
233
+ [rmbg_tab, outpaint_tab],
234
+ ["remove background", "outpainting"],
235
  title="Utilities that require GPU",
236
  )
237
 
requirements.txt CHANGED
@@ -11,3 +11,4 @@ scikit-image
11
  kornia
12
  transformers
13
  huggingface_hub
 
 
11
  kornia
12
  transformers
13
  huggingface_hub
14
+ diffusers