Mix committed on
Commit 5596f16
1 parent: 04d9294

Add application

Files changed (1)
app.py +836 -0
app.py ADDED
import os
from enum import IntEnum
from pathlib import Path
from tempfile import mktemp
from typing import IO, Dict, Type

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gradio import Interface, inputs, outputs

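# Note: `gradio.inputs` / `gradio.outputs` are the legacy pre-3.x Gradio
# component namespaces; this app targets that older API.
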
DEVICE = "cpu"

WEIGHTS_PATH = Path(__file__).parent / "weights"

# Discover every *.pth file under ./weights, keyed by filename without extension.
AVAILABLE_WEIGHTS = {
    basename: path
    for basename, ext in (
        os.path.splitext(filename) for filename in os.listdir(WEIGHTS_PATH)
    )
    if (path := WEIGHTS_PATH / (basename + ext)).is_file() and ext.endswith("pth")
}

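# The "weights" directory is expected to hold files whose names start with
# up2x/up3x/up4x — the prefix convention main() below matches against — e.g.
# (hypothetical name) weights/up2x-latest-no-denoise.pth.
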
class ScaleMode(IntEnum):
    up2x = 2
    up3x = 3
    up4x = 4


class TileMode(IntEnum):
    full = 0
    half = 1
    quarter = 2
    ninth = 3
    sixteenth = 4

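# TileMode members are passed through as the integer `tile_mode` argument of
# UpCunet*.forward(): 0 processes the whole padded image at once, while 1-4
# split it into halves/quarters/ninths/sixteenths to cap peak memory.
# Illustrative call, assuming `model` is an UpCunet2x instance:
#
#     out = model(tensor, TileMode.half)  # same as model(tensor, 1)
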
class SEBlock(nn.Module):
    # Squeeze-and-Excitation: gate each channel by a learned function of its
    # global mean.
    def __init__(self, in_channels, reduction=8, bias=False):
        super(SEBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, in_channels // reduction, 1, 1, 0, bias=bias
        )
        self.conv2 = nn.Conv2d(
            in_channels // reduction, in_channels, 1, 1, 0, bias=bias
        )

    def forward(self, x):
        if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
            x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
        else:
            x0 = torch.mean(x, dim=(2, 3), keepdim=True)
        x0 = self.conv1(x0)
        x0 = F.relu(x0, inplace=True)
        x0 = self.conv2(x0)
        x0 = torch.sigmoid(x0)
        x = torch.mul(x, x0)
        return x

    # Same as forward(), but takes a precomputed channel mean `x0` so the
    # mean can be aggregated across tiles before excitation.
    def forward_mean(self, x, x0):
        x0 = self.conv1(x0)
        x0 = F.relu(x0, inplace=True)
        x0 = self.conv2(x0)
        x0 = torch.sigmoid(x0)
        x = torch.mul(x, x0)
        return x

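# Shape sketch for SEBlock above (illustrative, with the default reduction=8
# and 64 channels): x of shape (n, 64, h, w) is squeezed to a (n, 64, 1, 1)
# global mean, reduced by conv1 to (n, 8, 1, 1), restored by conv2 to
# (n, 64, 1, 1), and the sigmoid gate rescales each channel of x.
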
class UNetConv(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, se):
        super(UNetConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        )
        if se:
            self.seblock = SEBlock(out_channels, reduction=8, bias=True)
        else:
            self.seblock = None

    def forward(self, x):
        z = self.conv(x)
        if self.seblock is not None:
            z = self.seblock(z)
        return z

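# Both 3x3 convolutions in UNetConv are unpadded, so each block trims 4
# pixels from height and width; the negative F.pad() crops in the U-Nets
# below (e.g. F.pad(x1, (-4, -4, -4, -4))) shrink the skip connections to
# match.
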
class UNet1(nn.Module):
    def __init__(self, in_channels, out_channels, deconv):
        super(UNet1, self).__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)

        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)

        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z

    # First half of forward(), stopping before conv2's SE block so the SE
    # mean can be averaged over all tiles first.
    def forward_a(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    # Second half of forward(), resumed after the externally applied SE.
    def forward_b(self, x1, x2):
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)

        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z

class UNet1x3(nn.Module):
    # Variant of UNet1 whose deconvolution bottom upsamples by 3x instead of 2x.
    def __init__(self, in_channels, out_channels, deconv):
        super(UNet1x3, self).__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)

        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 5, 3, 2)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)

        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z

    def forward_a(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x1, x2):
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)

        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z

class UNet2(nn.Module):
    def __init__(self, in_channels, out_channels, deconv):
        super(UNet2, self).__init__()

        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)

        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)

        x3 = self.conv2_down(x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x3 = self.conv3(x3)
        x3 = self.conv3_up(x3)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)

        x2 = F.pad(x2, (-4, -4, -4, -4))
        x4 = self.conv4(x2 + x3)
        x4 = self.conv4_up(x4)
        x4 = F.leaky_relu(x4, 0.1, inplace=True)

        x1 = F.pad(x1, (-16, -16, -16, -16))
        x5 = self.conv5(x1 + x4)
        x5 = F.leaky_relu(x5, 0.1, inplace=True)

        z = self.conv_bottom(x5)
        return z

    def forward_a(self, x):  # conv2/3/4 each end with an SE block
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x2):  # conv2/3/4 each end with an SE block
        x3 = self.conv2_down(x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x3 = self.conv3.conv(x3)
        return x3

    def forward_c(self, x2, x3):  # conv2/3/4 each end with an SE block
        x3 = self.conv3_up(x3)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)

        x2 = F.pad(x2, (-4, -4, -4, -4))
        x4 = self.conv4.conv(x2 + x3)
        return x4

    def forward_d(self, x1, x4):  # conv2/3/4 each end with an SE block
        x4 = self.conv4_up(x4)
        x4 = F.leaky_relu(x4, 0.1, inplace=True)

        x1 = F.pad(x1, (-16, -16, -16, -16))
        x5 = self.conv5(x1 + x4)
        x5 = F.leaky_relu(x5, 0.1, inplace=True)

        z = self.conv_bottom(x5)
        return z

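# forward_a/b/c/d split UNet2 at each SE-block boundary. The UpCunet wrappers
# below run one stage over every tile, average the per-tile channel means,
# apply SEBlock.forward_mean with that global mean, then continue — so tiled
# inference reproduces the SE statistics of an untiled pass.
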
class UpCunet2x(nn.Module):  # seamless tiling, lossless end to end
    def __init__(self, in_channels=3, out_channels=3):
        super(UpCunet2x, self).__init__()
        self.unet1 = UNet1(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    def forward(self, x, tile_mode):  # ~1.7 GB
        n, c, h0, w0 = x.shape
        if tile_mode == 0:  # no tiling
            ph = ((h0 - 1) // 2 + 1) * 2
            pw = ((w0 - 1) // 2 + 1) * 2
            x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")  # pad so the size is divisible by 2
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            if w0 != pw or h0 != ph:
                x = x[:, :, : h0 * 2, : w0 * 2]
            return x
        elif tile_mode == 1:  # halve the longer side
            if w0 >= h0:
                crop_size_w = ((w0 - 1) // 4 * 4 + 4) // 2  # divisible by 2 after halving, so divisible by 4 first
                crop_size_h = (h0 - 1) // 2 * 2 + 2  # divisible by 2
            else:
                crop_size_h = ((h0 - 1) // 4 * 4 + 4) // 2  # divisible by 2 after halving, so divisible by 4 first
                crop_size_w = (w0 - 1) // 2 * 2 + 2  # divisible by 2
            crop_size = (crop_size_h, crop_size_w)  # ~6.6 GB
        elif tile_mode == 2:  # halve both h and w
            crop_size = (
                ((h0 - 1) // 4 * 4 + 4) // 2,
                ((w0 - 1) // 4 * 4 + 4) // 2,
            )  # ~5.6 GB
        elif tile_mode == 3:  # one third of h and w
            crop_size = (
                ((h0 - 1) // 6 * 6 + 6) // 3,
                ((w0 - 1) // 6 * 6 + 6) // 3,
            )  # ~4.2 GB
        elif tile_mode == 4:  # one quarter of h and w
            crop_size = (
                ((h0 - 1) // 8 * 8 + 8) // 4,
                ((w0 - 1) // 8 * 8 + 8) // 4,
            )  # ~3.7 GB
        ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
        pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
        n, c, h, w = x.shape
        se_mean0 = torch.zeros((n, 64, 1, 1)).to(x.device)
        if "Half" in x.type():
            se_mean0 = se_mean0.half()
        n_patch = 0
        tmp_dict = {}
        opt_res_dict = {}
        # Pass 1: run unet1.forward_a on every tile and accumulate its SE mean.
        for i in range(0, h - 36, crop_size[0]):
            tmp_dict[i] = {}
            for j in range(0, w - 36, crop_size[1]):
                x_crop = x[:, :, i : i + crop_size[0] + 36, j : j + crop_size[1] + 36]
                n, c1, h1, w1 = x_crop.shape
                tmp0, x_crop = self.unet1.forward_a(x_crop)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        x_crop.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(x_crop, dim=(2, 3), keepdim=True)
                se_mean0 += tmp_se_mean
                n_patch += 1
                tmp_dict[i][j] = (tmp0, x_crop)
        se_mean0 /= n_patch
        se_mean1 = torch.zeros((n, 128, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean1 = se_mean1.half()
        # Pass 2: apply the global SE mean, finish unet1, start unet2.
        for i in range(0, h - 36, crop_size[0]):
            for j in range(0, w - 36, crop_size[1]):
                tmp0, x_crop = tmp_dict[i][j]
                x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_mean0)
                opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
                tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x2.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x2, dim=(2, 3), keepdim=True)
                se_mean1 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
        se_mean1 /= n_patch
        se_mean0 = torch.zeros((n, 128, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean0 = se_mean0.half()
        # Pass 3: unet2 stage b.
        for i in range(0, h - 36, crop_size[0]):
            for j in range(0, w - 36, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
                tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
                tmp_x3 = self.unet2.forward_b(tmp_x2)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x3.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x3, dim=(2, 3), keepdim=True)
                se_mean0 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_mean0 /= n_patch
        se_mean1 = torch.zeros((n, 64, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean1 = se_mean1.half()
        # Pass 4: unet2 stage c.
        for i in range(0, h - 36, crop_size[0]):
            for j in range(0, w - 36, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
                tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
                tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x4.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x4, dim=(2, 3), keepdim=True)
                se_mean1 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
        se_mean1 /= n_patch
        # Pass 5: unet2 stage d plus the unet1 residual.
        for i in range(0, h - 36, crop_size[0]):
            opt_res_dict[i] = {}
            for j in range(0, w - 36, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
                tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
                x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
                x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
                x_crop = torch.add(x0, x1)  # x0 is the final output of unet2
                opt_res_dict[i][j] = x_crop
        del tmp_dict
        torch.cuda.empty_cache()
        res = torch.zeros((n, c, h * 2 - 72, w * 2 - 72)).to(x.device)
        if "Half" in x.type():
            res = res.half()
        for i in range(0, h - 36, crop_size[0]):
            for j in range(0, w - 36, crop_size[1]):
                res[
                    :, :, i * 2 : i * 2 + h1 * 2 - 72, j * 2 : j * 2 + w1 * 2 - 72
                ] = opt_res_dict[i][j]
        del opt_res_dict
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 2, : w0 * 2]
        return res

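# Tiling arithmetic sketch for UpCunet2x above (derived from the code): with
# an 18-pixel reflect border on each side and a 36-pixel receptive-field
# trim, an input tile of size s + 36 produces exactly an s-sized patch at 2x,
# so the assembled canvas is (h * 2 - 72, w * 2 - 72) and is finally cropped
# back to (h0 * 2, w0 * 2).
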
class UpCunet3x(nn.Module):  # seamless tiling, lossless end to end
    def __init__(self, in_channels=3, out_channels=3):
        super(UpCunet3x, self).__init__()
        self.unet1 = UNet1x3(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    def forward(self, x, tile_mode):  # ~1.7 GB
        n, c, h0, w0 = x.shape
        if tile_mode == 0:  # no tiling
            ph = ((h0 - 1) // 4 + 1) * 4
            pw = ((w0 - 1) // 4 + 1) * 4
            x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")  # pad so the size is divisible by 4
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            if w0 != pw or h0 != ph:
                x = x[:, :, : h0 * 3, : w0 * 3]
            return x
        elif tile_mode == 1:  # halve the longer side
            if w0 >= h0:
                crop_size_w = ((w0 - 1) // 8 * 8 + 8) // 2  # divisible by 4 after halving, so divisible by 8 first
                crop_size_h = (h0 - 1) // 4 * 4 + 4  # divisible by 4
            else:
                crop_size_h = ((h0 - 1) // 8 * 8 + 8) // 2  # divisible by 4 after halving, so divisible by 8 first
                crop_size_w = (w0 - 1) // 4 * 4 + 4  # divisible by 4
            crop_size = (crop_size_h, crop_size_w)  # ~6.6 GB
        elif tile_mode == 2:  # halve both h and w
            crop_size = (
                ((h0 - 1) // 8 * 8 + 8) // 2,
                ((w0 - 1) // 8 * 8 + 8) // 2,
            )  # ~5.6 GB
        elif tile_mode == 3:  # one third of h and w
            crop_size = (
                ((h0 - 1) // 12 * 12 + 12) // 3,
                ((w0 - 1) // 12 * 12 + 12) // 3,
            )  # ~4.2 GB
        elif tile_mode == 4:  # one quarter of h and w
            crop_size = (
                ((h0 - 1) // 16 * 16 + 16) // 4,
                ((w0 - 1) // 16 * 16 + 16) // 4,
            )  # ~3.7 GB
        ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
        pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
        x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
        n, c, h, w = x.shape
        se_mean0 = torch.zeros((n, 64, 1, 1)).to(x.device)
        if "Half" in x.type():
            se_mean0 = se_mean0.half()
        n_patch = 0
        tmp_dict = {}
        opt_res_dict = {}
        for i in range(0, h - 28, crop_size[0]):
            tmp_dict[i] = {}
            for j in range(0, w - 28, crop_size[1]):
                x_crop = x[:, :, i : i + crop_size[0] + 28, j : j + crop_size[1] + 28]
                n, c1, h1, w1 = x_crop.shape
                tmp0, x_crop = self.unet1.forward_a(x_crop)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        x_crop.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(x_crop, dim=(2, 3), keepdim=True)
                se_mean0 += tmp_se_mean
                n_patch += 1
                tmp_dict[i][j] = (tmp0, x_crop)
        se_mean0 /= n_patch
        se_mean1 = torch.zeros((n, 128, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean1 = se_mean1.half()
        for i in range(0, h - 28, crop_size[0]):
            for j in range(0, w - 28, crop_size[1]):
                tmp0, x_crop = tmp_dict[i][j]
                x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_mean0)
                opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
                tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x2.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x2, dim=(2, 3), keepdim=True)
                se_mean1 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
        se_mean1 /= n_patch
        se_mean0 = torch.zeros((n, 128, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean0 = se_mean0.half()
        for i in range(0, h - 28, crop_size[0]):
            for j in range(0, w - 28, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
                tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
                tmp_x3 = self.unet2.forward_b(tmp_x2)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x3.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x3, dim=(2, 3), keepdim=True)
                se_mean0 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_mean0 /= n_patch
        se_mean1 = torch.zeros((n, 64, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean1 = se_mean1.half()
        for i in range(0, h - 28, crop_size[0]):
            for j in range(0, w - 28, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
                tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
                tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x4.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x4, dim=(2, 3), keepdim=True)
                se_mean1 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
        se_mean1 /= n_patch
        for i in range(0, h - 28, crop_size[0]):
            opt_res_dict[i] = {}
            for j in range(0, w - 28, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
                tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
                x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
                x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
                x_crop = torch.add(x0, x1)  # x0 is the final output of unet2
                opt_res_dict[i][j] = x_crop
        del tmp_dict
        torch.cuda.empty_cache()
        res = torch.zeros((n, c, h * 3 - 84, w * 3 - 84)).to(x.device)
        if "Half" in x.type():
            res = res.half()
        for i in range(0, h - 28, crop_size[0]):
            for j in range(0, w - 28, crop_size[1]):
                res[
                    :, :, i * 3 : i * 3 + h1 * 3 - 84, j * 3 : j * 3 + w1 * 3 - 84
                ] = opt_res_dict[i][j]
        del opt_res_dict
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 3, : w0 * 3]
        return res

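# UpCunet3x above follows the same five-pass schedule as UpCunet2x, with a
# 14-pixel border and a 28-pixel trim: a tile of size s + 28 yields an
# s-sized patch at 3x, giving a (h * 3 - 84, w * 3 - 84) canvas before the
# final crop.
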
class UpCunet4x(nn.Module):  # seamless tiling, lossless end to end
    def __init__(self, in_channels=3, out_channels=3):
        super(UpCunet4x, self).__init__()
        self.unet1 = UNet1(in_channels, 64, deconv=True)
        self.unet2 = UNet2(64, 64, deconv=False)
        self.ps = nn.PixelShuffle(2)
        self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)

    def forward(self, x, tile_mode):
        n, c, h0, w0 = x.shape
        x00 = x
        if tile_mode == 0:  # no tiling
            ph = ((h0 - 1) // 2 + 1) * 2
            pw = ((w0 - 1) // 2 + 1) * 2
            x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")  # pad so the size is divisible by 2
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            x = self.conv_final(x)
            x = F.pad(x, (-1, -1, -1, -1))
            x = self.ps(x)
            if w0 != pw or h0 != ph:
                x = x[:, :, : h0 * 4, : w0 * 4]
            x += F.interpolate(x00, scale_factor=4, mode="nearest")
            return x
        elif tile_mode == 1:  # halve the longer side
            if w0 >= h0:
                crop_size_w = ((w0 - 1) // 4 * 4 + 4) // 2  # divisible by 2 after halving, so divisible by 4 first
                crop_size_h = (h0 - 1) // 2 * 2 + 2  # divisible by 2
            else:
                crop_size_h = ((h0 - 1) // 4 * 4 + 4) // 2  # divisible by 2 after halving, so divisible by 4 first
                crop_size_w = (w0 - 1) // 2 * 2 + 2  # divisible by 2
            crop_size = (crop_size_h, crop_size_w)  # ~6.6 GB
        elif tile_mode == 2:  # halve both h and w
            crop_size = (
                ((h0 - 1) // 4 * 4 + 4) // 2,
                ((w0 - 1) // 4 * 4 + 4) // 2,
            )  # ~5.6 GB
        elif tile_mode == 3:  # one third of h and w
            crop_size = (
                ((h0 - 1) // 6 * 6 + 6) // 3,
                ((w0 - 1) // 6 * 6 + 6) // 3,
            )  # ~4.1 GB
        elif tile_mode == 4:  # one quarter of h and w
            crop_size = (
                ((h0 - 1) // 8 * 8 + 8) // 4,
                ((w0 - 1) // 8 * 8 + 8) // 4,
            )  # ~3.7 GB
        ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
        pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
        x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
        n, c, h, w = x.shape
        se_mean0 = torch.zeros((n, 64, 1, 1)).to(x.device)
        if "Half" in x.type():
            se_mean0 = se_mean0.half()
        n_patch = 0
        tmp_dict = {}
        opt_res_dict = {}
        for i in range(0, h - 38, crop_size[0]):
            tmp_dict[i] = {}
            for j in range(0, w - 38, crop_size[1]):
                x_crop = x[:, :, i : i + crop_size[0] + 38, j : j + crop_size[1] + 38]
                n, c1, h1, w1 = x_crop.shape
                tmp0, x_crop = self.unet1.forward_a(x_crop)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        x_crop.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(x_crop, dim=(2, 3), keepdim=True)
                se_mean0 += tmp_se_mean
                n_patch += 1
                tmp_dict[i][j] = (tmp0, x_crop)
        se_mean0 /= n_patch
        se_mean1 = torch.zeros((n, 128, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean1 = se_mean1.half()
        for i in range(0, h - 38, crop_size[0]):
            for j in range(0, w - 38, crop_size[1]):
                tmp0, x_crop = tmp_dict[i][j]
                x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_mean0)
                opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
                tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x2.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x2, dim=(2, 3), keepdim=True)
                se_mean1 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
        se_mean1 /= n_patch
        se_mean0 = torch.zeros((n, 128, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean0 = se_mean0.half()
        for i in range(0, h - 38, crop_size[0]):
            for j in range(0, w - 38, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
                tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
                tmp_x3 = self.unet2.forward_b(tmp_x2)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x3.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x3, dim=(2, 3), keepdim=True)
                se_mean0 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_mean0 /= n_patch
        se_mean1 = torch.zeros((n, 64, 1, 1)).to(x.device)  # channel sizes: 64/128/128/64
        if "Half" in x.type():
            se_mean1 = se_mean1.half()
        for i in range(0, h - 38, crop_size[0]):
            for j in range(0, w - 38, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
                tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
                tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
                if "Half" in x.type():  # torch.HalfTensor/torch.cuda.HalfTensor
                    tmp_se_mean = torch.mean(
                        tmp_x4.float(), dim=(2, 3), keepdim=True
                    ).half()
                else:
                    tmp_se_mean = torch.mean(tmp_x4, dim=(2, 3), keepdim=True)
                se_mean1 += tmp_se_mean
                tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
        se_mean1 /= n_patch
        for i in range(0, h - 38, crop_size[0]):
            opt_res_dict[i] = {}
            for j in range(0, w - 38, crop_size[1]):
                opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
                tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
                x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
                x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
                x_crop = torch.add(x0, x1)  # x0 is the final output of unet2
                x_crop = self.conv_final(x_crop)
                x_crop = F.pad(x_crop, (-1, -1, -1, -1))
                x_crop = self.ps(x_crop)
                opt_res_dict[i][j] = x_crop
        del tmp_dict
        torch.cuda.empty_cache()
        res = torch.zeros((n, c, h * 4 - 152, w * 4 - 152)).to(x.device)
        if "Half" in x.type():
            res = res.half()
        for i in range(0, h - 38, crop_size[0]):
            for j in range(0, w - 38, crop_size[1]):
                res[
                    :, :, i * 4 : i * 4 + h1 * 4 - 152, j * 4 : j * 4 + w1 * 4 - 152
                ] = opt_res_dict[i][j]
        del opt_res_dict
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 4, : w0 * 4]
        res += F.interpolate(x00, scale_factor=4, mode="nearest")
        return res

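# Unlike the 2x/3x variants, UpCunet4x reaches 4x in two stages: unet1's
# deconv gives 2x, then conv_final maps 64 -> 12 channels and PixelShuffle(2)
# turns them into 3 channels at another 2x; a nearest-neighbour 4x upscale of
# the original input is added as a global residual.
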
# Registry of every nn.Module subclass defined above, keyed by class name.
# Only the "UpCunet2x"/"UpCunet3x"/"UpCunet4x" entries are looked up below.
models: Dict[str, Type[nn.Module]] = {
    obj.__name__: obj
    for obj in globals().values()
    if isinstance(obj, type) and issubclass(obj, nn.Module)
}

class RealWaifuUpScaler:
    def __init__(self, scale: int, weight_path: str, half: bool, device: str):
        weight = torch.load(weight_path, map_location=device)
        self.model = models[f"UpCunet{scale}x"]()

        if half:
            self.model = self.model.half().to(device)
        else:
            self.model = self.model.to(device)

        self.model.load_state_dict(weight, strict=True)
        self.model.eval()

        self.half = half
        self.device = device

    def np2tensor(self, np_frame):
        # HWC uint8 [0, 255] -> NCHW float [0, 1]
        if not self.half:
            return (
                torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
                .unsqueeze(0)
                .to(self.device)
                .float()
                / 255
            )
        else:
            return (
                torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
                .unsqueeze(0)
                .to(self.device)
                .half()
                / 255
            )

    def tensor2np(self, tensor):
        # NCHW float [0, 1] -> HWC uint8 [0, 255]
        if not self.half:
            return np.transpose(
                (tensor.data.squeeze() * 255.0)
                .round()
                .clamp_(0, 255)
                .byte()
                .cpu()
                .numpy(),
                (1, 2, 0),
            )
        else:
            return np.transpose(
                (tensor.data.squeeze().float() * 255.0)
                .round()
                .clamp_(0, 255)
                .byte()
                .cpu()
                .numpy(),
                (1, 2, 0),
            )

    def __call__(self, frame, tile_mode):
        with torch.no_grad():
            tensor = self.np2tensor(frame)
            result = self.tensor2np(self.model(tensor, tile_mode))
        return result

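# Standalone usage sketch for the class above (illustrative; "up2x-demo.pth"
# and the image paths are hypothetical):
#
#     upscaler = RealWaifuUpScaler(2, "weights/up2x-demo.pth", half=False, device="cpu")
#     rgb = cv2.imread("in.png")[:, :, ::-1]   # OpenCV loads BGR; the model expects RGB
#     out = upscaler(rgb, TileMode.full)
#     cv2.imwrite("out.png", out[:, :, ::-1])  # back to BGR before writing
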
input_image = inputs.File(label="Input image")
half_precision = inputs.Checkbox(
    label="Half precision (does NOT work on CPU)", default=False
)
model_weight = inputs.Dropdown(sorted(AVAILABLE_WEIGHTS), label="Choose model weight")
tile_mode = inputs.Radio([mode.name for mode in TileMode], label="Output tile mode")

output_image = outputs.Image(label="Output image preview")
output_file = outputs.File(label="Output image file")

def main(file: IO[bytes], half: bool, weight: str, tile: str):
    # The weight filename's up2x/up3x/up4x prefix selects the model scale.
    scale = next(mode.value for mode in ScaleMode if weight.startswith(mode.name))
    upscaler = RealWaifuUpScaler(
        scale, weight_path=str(AVAILABLE_WEIGHTS[weight]), half=half, device=DEVICE
    )

    frame = cv2.imread(file.name)
    result = upscaler(frame[:, :, [2, 1, 0]], TileMode[tile])  # BGR -> RGB

    _, ext = os.path.splitext(file.name)
    tempfile = mktemp(suffix=ext)
    cv2.imwrite(tempfile, result[:, :, ::-1])  # convert RGB back to BGR for OpenCV
    return result, tempfile

interface = Interface(
    main,
    inputs=[input_image, half_precision, model_weight, tile_mode],
    outputs=[output_image, output_file],
)
interface.launch()