osmunphotography commited on
Commit
86afe4d
·
verified ·
1 Parent(s): 64d1965

Upload inpaint_worker 3.py

Browse files
Files changed (1) hide show
  1. inpaint_worker 3.py +264 -0
inpaint_worker 3.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from PIL import Image, ImageFilter
5
+ from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
6
+ from modules.upscaler import perform_upscale
7
+ import cv2
8
+
9
+
10
+ inpaint_head_model = None
11
+
12
+
13
+ class InpaintHead(torch.nn.Module):
14
+ def __init__(self, *args, **kwargs):
15
+ super().__init__(*args, **kwargs)
16
+ self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
17
+
18
+ def __call__(self, x):
19
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
20
+ return torch.nn.functional.conv2d(input=x, weight=self.head)
21
+
22
+
23
+ current_task = None
24
+
25
+
26
+ def box_blur(x, k):
27
+ x = Image.fromarray(x)
28
+ x = x.filter(ImageFilter.BoxBlur(k))
29
+ return np.array(x)
30
+
31
+
32
+ def max_filter_opencv(x, ksize=3):
33
+ # Use OpenCV maximum filter
34
+ # Make sure the input type is int16
35
+ return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))
36
+
37
+
38
+ def morphological_open(x):
39
+ # Convert array to int16 type via threshold operation
40
+ x_int16 = np.zeros_like(x, dtype=np.int16)
41
+ x_int16[x > 127] = 256
42
+
43
+ for i in range(32):
44
+ # Use int16 type to avoid overflow
45
+ maxed = max_filter_opencv(x_int16, ksize=3) - 8
46
+ x_int16 = np.maximum(maxed, x_int16)
47
+
48
+ # Clip negative values to 0 and convert back to uint8 type
49
+ x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
50
+ return x_uint8
51
+
52
+
53
+ def up255(x, t=0):
54
+ y = np.zeros_like(x).astype(np.uint8)
55
+ y[x > t] = 255
56
+ return y
57
+
58
+
59
+ def imsave(x, path):
60
+ x = Image.fromarray(x)
61
+ x.save(path)
62
+
63
+
64
+ def regulate_abcd(x, a, b, c, d):
65
+ H, W = x.shape[:2]
66
+ if a < 0:
67
+ a = 0
68
+ if a > H:
69
+ a = H
70
+ if b < 0:
71
+ b = 0
72
+ if b > H:
73
+ b = H
74
+ if c < 0:
75
+ c = 0
76
+ if c > W:
77
+ c = W
78
+ if d < 0:
79
+ d = 0
80
+ if d > W:
81
+ d = W
82
+ return int(a), int(b), int(c), int(d)
83
+
84
+
85
+ def compute_initial_abcd(x):
86
+ indices = np.where(x)
87
+ a = np.min(indices[0])
88
+ b = np.max(indices[0])
89
+ c = np.min(indices[1])
90
+ d = np.max(indices[1])
91
+ abp = (b + a) // 2
92
+ abm = (b - a) // 2
93
+ cdp = (d + c) // 2
94
+ cdm = (d - c) // 2
95
+ l = int(max(abm, cdm) * 1.15)
96
+ a = abp - l
97
+ b = abp + l + 1
98
+ c = cdp - l
99
+ d = cdp + l + 1
100
+ a, b, c, d = regulate_abcd(x, a, b, c, d)
101
+ return a, b, c, d
102
+
103
+
104
+ def solve_abcd(x, a, b, c, d, k):
105
+ k = float(k)
106
+ assert 0.0 <= k <= 1.0
107
+
108
+ H, W = x.shape[:2]
109
+ if k == 1.0:
110
+ return 0, H, 0, W
111
+ while True:
112
+ if b - a >= H * k and d - c >= W * k:
113
+ break
114
+
115
+ add_h = (b - a) < (d - c)
116
+ add_w = not add_h
117
+
118
+ if b - a == H:
119
+ add_w = True
120
+
121
+ if d - c == W:
122
+ add_h = True
123
+
124
+ if add_h:
125
+ a -= 1
126
+ b += 1
127
+
128
+ if add_w:
129
+ c -= 1
130
+ d += 1
131
+
132
+ a, b, c, d = regulate_abcd(x, a, b, c, d)
133
+ return a, b, c, d
134
+
135
+
136
+ def fooocus_fill(image, mask):
137
+ current_image = image.copy()
138
+ raw_image = image.copy()
139
+ area = np.where(mask < 127)
140
+ store = raw_image[area]
141
+
142
+ for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
143
+ for _ in range(repeats):
144
+ current_image = box_blur(current_image, k)
145
+ current_image[area] = store
146
+
147
+ return current_image
148
+
149
+
150
+ class InpaintWorker:
151
+ def __init__(self, image, mask, use_fill=True, k=0.618):
152
+ a, b, c, d = compute_initial_abcd(mask > 0)
153
+ a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)
154
+
155
+ # interested area
156
+ self.interested_area = (a, b, c, d)
157
+ self.interested_mask = mask[a:b, c:d]
158
+ self.interested_image = image[a:b, c:d]
159
+
160
+ # super resolution
161
+ if get_image_shape_ceil(self.interested_image) < 1024:
162
+ self.interested_image = perform_upscale(self.interested_image)
163
+
164
+ # resize to make images ready for diffusion
165
+ self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
166
+ self.interested_fill = self.interested_image.copy()
167
+ H, W, C = self.interested_image.shape
168
+
169
+ # process mask
170
+ self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
171
+
172
+ # compute filling
173
+ if use_fill:
174
+ self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
175
+
176
+ # soft pixels
177
+ self.mask = morphological_open(mask)
178
+ self.image = image
179
+
180
+ # ending
181
+ self.latent = None
182
+ self.latent_after_swap = None
183
+ self.swapped = False
184
+ self.latent_mask = None
185
+ self.inpaint_head_feature = None
186
+ return
187
+
188
+ def load_latent(self, latent_fill, latent_mask, latent_swap=None):
189
+ self.latent = latent_fill
190
+ self.latent_mask = latent_mask
191
+ self.latent_after_swap = latent_swap
192
+ return
193
+
194
+ def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
195
+ global inpaint_head_model
196
+
197
+ if inpaint_head_model is None:
198
+ inpaint_head_model = InpaintHead()
199
+ sd = torch.load(inpaint_head_model_path, map_location='cpu')
200
+ inpaint_head_model.load_state_dict(sd)
201
+
202
+ feed = torch.cat([
203
+ inpaint_latent_mask,
204
+ model.model.process_latent_in(inpaint_latent)
205
+ ], dim=1)
206
+
207
+ inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
208
+ inpaint_head_feature = inpaint_head_model(feed)
209
+
210
+ def input_block_patch(h, transformer_options):
211
+ if transformer_options["block"][1] == 0:
212
+ h = h + inpaint_head_feature.to(h)
213
+ return h
214
+
215
+ m = model.clone()
216
+ m.set_model_input_block_patch(input_block_patch)
217
+ return m
218
+
219
+ def swap(self):
220
+ if self.swapped:
221
+ return
222
+
223
+ if self.latent is None:
224
+ return
225
+
226
+ if self.latent_after_swap is None:
227
+ return
228
+
229
+ self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
230
+ self.swapped = True
231
+ return
232
+
233
+ def unswap(self):
234
+ if not self.swapped:
235
+ return
236
+
237
+ if self.latent is None:
238
+ return
239
+
240
+ if self.latent_after_swap is None:
241
+ return
242
+
243
+ self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
244
+ self.swapped = False
245
+ return
246
+
247
+ def color_correction(self, img):
248
+ fg = img.astype(np.float32)
249
+ bg = self.image.copy().astype(np.float32)
250
+ w = self.mask[:, :, None].astype(np.float32) / 255.0
251
+ y = fg * w + bg * (1 - w)
252
+ return y.clip(0, 255).astype(np.uint8)
253
+
254
+ def post_process(self, img):
255
+ a, b, c, d = self.interested_area
256
+ content = resample_image(img, d - c, b - a)
257
+ result = self.image.copy()
258
+ result[a:b, c:d] = content
259
+ result = self.color_correction(result)
260
+ return result
261
+
262
+ def visualize_mask_processing(self):
263
+ return [self.interested_fill, self.interested_mask, self.interested_image]
264
+