manhkhanhUIT commited on
Commit
7fab858
·
1 Parent(s): a8b38c0
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CODE_OF_CONDUCT.md +9 -0
  2. Dockerfile +43 -0
  3. Face_Detection/align_warp_back_multiple_dlib.py +437 -0
  4. Face_Detection/align_warp_back_multiple_dlib_HR.py +437 -0
  5. Face_Detection/detect_all_dlib.py +184 -0
  6. Face_Detection/detect_all_dlib_HR.py +184 -0
  7. Face_Enhancement/data/__init__.py +22 -0
  8. Face_Enhancement/data/base_dataset.py +125 -0
  9. Face_Enhancement/data/custom_dataset.py +56 -0
  10. Face_Enhancement/data/face_dataset.py +102 -0
  11. Face_Enhancement/data/image_folder.py +101 -0
  12. Face_Enhancement/data/pix2pix_dataset.py +108 -0
  13. Face_Enhancement/models/__init__.py +44 -0
  14. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE +21 -0
  15. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md +118 -0
  16. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py +14 -0
  17. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py +412 -0
  18. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py +74 -0
  19. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py +137 -0
  20. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py +94 -0
  21. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py +29 -0
  22. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py +56 -0
  23. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py +62 -0
  24. Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py +114 -0
  25. Face_Enhancement/models/networks/__init__.py +58 -0
  26. Face_Enhancement/models/networks/architecture.py +173 -0
  27. Face_Enhancement/models/networks/base_network.py +58 -0
  28. Face_Enhancement/models/networks/encoder.py +53 -0
  29. Face_Enhancement/models/networks/generator.py +233 -0
  30. Face_Enhancement/models/networks/normalization.py +100 -0
  31. Face_Enhancement/models/networks/sync_batchnorm/__init__.py +14 -0
  32. Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py +412 -0
  33. Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py +74 -0
  34. Face_Enhancement/models/networks/sync_batchnorm/comm.py +137 -0
  35. Face_Enhancement/models/networks/sync_batchnorm/replicate.py +94 -0
  36. Face_Enhancement/models/networks/sync_batchnorm/unittest.py +29 -0
  37. Face_Enhancement/models/pix2pix_model.py +246 -0
  38. Face_Enhancement/options/__init__.py +2 -0
  39. Face_Enhancement/options/base_options.py +294 -0
  40. Face_Enhancement/options/test_options.py +26 -0
  41. Face_Enhancement/requirements.txt +9 -0
  42. Face_Enhancement/test_face.py +45 -0
  43. Face_Enhancement/util/__init__.py +2 -0
  44. Face_Enhancement/util/iter_counter.py +74 -0
  45. Face_Enhancement/util/util.py +210 -0
  46. Face_Enhancement/util/visualizer.py +134 -0
  47. GUI.py +217 -0
  48. Global/data/Create_Bigfile.py +63 -0
  49. Global/data/Load_Bigfile.py +42 -0
  50. Global/data/__init__.py +0 -0
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [[email protected]](mailto:[email protected]) with questions or concerns
Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.1-base-ubuntu20.04
2
+
3
+ RUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev -yq
4
+ ADD . /app
5
+ WORKDIR /app
6
+ RUN cd Face_Enhancement/models/networks/ &&\
7
+ git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
8
+ cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\
9
+ cd ../../../
10
+
11
+ RUN cd Global/detection_models &&\
12
+ git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
13
+ cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\
14
+ cd ../../
15
+
16
+ RUN cd Face_Detection/ &&\
17
+ wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 &&\
18
+ bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 &&\
19
+ cd ../
20
+
21
+ RUN cd Face_Enhancement/ &&\
22
+ wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip &&\
23
+ unzip checkpoints.zip &&\
24
+ cd ../ &&\
25
+ cd Global/ &&\
26
+ wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip &&\
27
+ unzip checkpoints.zip &&\
28
+ rm -f checkpoints.zip &&\
29
+ cd ../
30
+
31
+ RUN pip3 install numpy
32
+
33
+ RUN pip3 install dlib
34
+
35
+ RUN pip3 install -r requirements.txt
36
+
37
+ RUN git clone https://github.com/NVlabs/SPADE.git
38
+
39
+ RUN cd SPADE/ && pip3 install -r requirements.txt
40
+
41
+ RUN cd ..
42
+
43
+ CMD ["python3", "run.py"]
Face_Detection/align_warp_back_multiple_dlib.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import numpy as np
6
+ import skimage.io as io
7
+
8
+ # from face_sdk import FaceDetection
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.patches import Rectangle
11
+ from skimage.transform import SimilarityTransform
12
+ from skimage.transform import warp
13
+ from PIL import Image, ImageFilter
14
+ import torch.nn.functional as F
15
+ import torchvision as tv
16
+ import torchvision.utils as vutils
17
+ import time
18
+ import cv2
19
+ import os
20
+ from skimage import img_as_ubyte
21
+ import json
22
+ import argparse
23
+ import dlib
24
+
25
+
26
+ def calculate_cdf(histogram):
27
+ """
28
+ This method calculates the cumulative distribution function
29
+ :param array histogram: The values of the histogram
30
+ :return: normalized_cdf: The normalized cumulative distribution function
31
+ :rtype: array
32
+ """
33
+ # Get the cumulative sum of the elements
34
+ cdf = histogram.cumsum()
35
+
36
+ # Normalize the cdf
37
+ normalized_cdf = cdf / float(cdf.max())
38
+
39
+ return normalized_cdf
40
+
41
+
42
+ def calculate_lookup(src_cdf, ref_cdf):
43
+ """
44
+ This method creates the lookup table
45
+ :param array src_cdf: The cdf for the source image
46
+ :param array ref_cdf: The cdf for the reference image
47
+ :return: lookup_table: The lookup table
48
+ :rtype: array
49
+ """
50
+ lookup_table = np.zeros(256)
51
+ lookup_val = 0
52
+ for src_pixel_val in range(len(src_cdf)):
53
+ lookup_val
54
+ for ref_pixel_val in range(len(ref_cdf)):
55
+ if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
56
+ lookup_val = ref_pixel_val
57
+ break
58
+ lookup_table[src_pixel_val] = lookup_val
59
+ return lookup_table
60
+
61
+
62
+ def match_histograms(src_image, ref_image):
63
+ """
64
+ This method matches the source image histogram to the
65
+ reference signal
66
+ :param image src_image: The original source image
67
+ :param image ref_image: The reference image
68
+ :return: image_after_matching
69
+ :rtype: image (array)
70
+ """
71
+ # Split the images into the different color channels
72
+ # b means blue, g means green and r means red
73
+ src_b, src_g, src_r = cv2.split(src_image)
74
+ ref_b, ref_g, ref_r = cv2.split(ref_image)
75
+
76
+ # Compute the b, g, and r histograms separately
77
+ # The flatten() Numpy method returns a copy of the array c
78
+ # collapsed into one dimension.
79
+ src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
80
+ src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
81
+ src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
82
+ ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
83
+ ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
84
+ ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
85
+
86
+ # Compute the normalized cdf for the source and reference image
87
+ src_cdf_blue = calculate_cdf(src_hist_blue)
88
+ src_cdf_green = calculate_cdf(src_hist_green)
89
+ src_cdf_red = calculate_cdf(src_hist_red)
90
+ ref_cdf_blue = calculate_cdf(ref_hist_blue)
91
+ ref_cdf_green = calculate_cdf(ref_hist_green)
92
+ ref_cdf_red = calculate_cdf(ref_hist_red)
93
+
94
+ # Make a separate lookup table for each color
95
+ blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
96
+ green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
97
+ red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
98
+
99
+ # Use the lookup function to transform the colors of the original
100
+ # source image
101
+ blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
102
+ green_after_transform = cv2.LUT(src_g, green_lookup_table)
103
+ red_after_transform = cv2.LUT(src_r, red_lookup_table)
104
+
105
+ # Put the image back together
106
+ image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
107
+ image_after_matching = cv2.convertScaleAbs(image_after_matching)
108
+
109
+ return image_after_matching
110
+
111
+
112
+ def _standard_face_pts():
113
+ pts = (
114
+ np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
115
+ - 1.0
116
+ )
117
+
118
+ return np.reshape(pts, (5, 2))
119
+
120
+
121
+ def _origin_face_pts():
122
+ pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
123
+
124
+ return np.reshape(pts, (5, 2))
125
+
126
+
127
+ def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
128
+
129
+ std_pts = _standard_face_pts() # [-1,1]
130
+ target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
131
+
132
+ # print(target_pts)
133
+
134
+ h, w, c = img.shape
135
+ if normalize == True:
136
+ landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
137
+ landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
138
+
139
+ # print(landmark)
140
+
141
+ affine = SimilarityTransform()
142
+
143
+ affine.estimate(target_pts, landmark)
144
+
145
+ return affine
146
+
147
+
148
+ def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
149
+
150
+ std_pts = _standard_face_pts() # [-1,1]
151
+ target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
152
+
153
+ # print(target_pts)
154
+
155
+ h, w, c = img.shape
156
+ if normalize == True:
157
+ landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
158
+ landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
159
+
160
+ # print(landmark)
161
+
162
+ affine = SimilarityTransform()
163
+
164
+ affine.estimate(landmark, target_pts)
165
+
166
+ return affine
167
+
168
+
169
+ def show_detection(image, box, landmark):
170
+ plt.imshow(image)
171
+ print(box[2] - box[0])
172
+ plt.gca().add_patch(
173
+ Rectangle(
174
+ (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
175
+ )
176
+ )
177
+ plt.scatter(landmark[0][0], landmark[0][1])
178
+ plt.scatter(landmark[1][0], landmark[1][1])
179
+ plt.scatter(landmark[2][0], landmark[2][1])
180
+ plt.scatter(landmark[3][0], landmark[3][1])
181
+ plt.scatter(landmark[4][0], landmark[4][1])
182
+ plt.show()
183
+
184
+
185
+ def affine2theta(affine, input_w, input_h, target_w, target_h):
186
+ # param = np.linalg.inv(affine)
187
+ param = affine
188
+ theta = np.zeros([2, 3])
189
+ theta[0, 0] = param[0, 0] * input_h / target_h
190
+ theta[0, 1] = param[0, 1] * input_w / target_h
191
+ theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
192
+ theta[1, 0] = param[1, 0] * input_h / target_w
193
+ theta[1, 1] = param[1, 1] * input_w / target_w
194
+ theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
195
+ return theta
196
+
197
+
198
+ def blur_blending(im1, im2, mask):
199
+
200
+ mask *= 255.0
201
+
202
+ kernel = np.ones((10, 10), np.uint8)
203
+ mask = cv2.erode(mask, kernel, iterations=1)
204
+
205
+ mask = Image.fromarray(mask.astype("uint8")).convert("L")
206
+ im1 = Image.fromarray(im1.astype("uint8"))
207
+ im2 = Image.fromarray(im2.astype("uint8"))
208
+
209
+ mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
210
+ im = Image.composite(im1, im2, mask)
211
+
212
+ im = Image.composite(im, im2, mask_blur)
213
+
214
+ return np.array(im) / 255.0
215
+
216
+
217
+ def blur_blending_cv2(im1, im2, mask):
218
+
219
+ mask *= 255.0
220
+
221
+ kernel = np.ones((9, 9), np.uint8)
222
+ mask = cv2.erode(mask, kernel, iterations=3)
223
+
224
+ mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
225
+ mask_blur /= 255.0
226
+
227
+ im = im1 * mask_blur + (1 - mask_blur) * im2
228
+
229
+ im /= 255.0
230
+ im = np.clip(im, 0.0, 1.0)
231
+
232
+ return im
233
+
234
+
235
+ # def Poisson_blending(im1,im2,mask):
236
+
237
+
238
+ # Image.composite(
239
+ def Poisson_blending(im1, im2, mask):
240
+
241
+ # mask=1-mask
242
+ mask *= 255
243
+ kernel = np.ones((10, 10), np.uint8)
244
+ mask = cv2.erode(mask, kernel, iterations=1)
245
+ mask /= 255
246
+ mask = 1 - mask
247
+ mask *= 255
248
+
249
+ mask = mask[:, :, 0]
250
+ width, height, channels = im1.shape
251
+ center = (int(height / 2), int(width / 2))
252
+ result = cv2.seamlessClone(
253
+ im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
254
+ )
255
+
256
+ return result / 255.0
257
+
258
+
259
+ def Poisson_B(im1, im2, mask, center):
260
+
261
+ mask *= 255
262
+
263
+ result = cv2.seamlessClone(
264
+ im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
265
+ )
266
+
267
+ return result / 255
268
+
269
+
270
+ def seamless_clone(old_face, new_face, raw_mask):
271
+
272
+ height, width, _ = old_face.shape
273
+ height = height // 2
274
+ width = width // 2
275
+
276
+ y_indices, x_indices, _ = np.nonzero(raw_mask)
277
+ y_crop = slice(np.min(y_indices), np.max(y_indices))
278
+ x_crop = slice(np.min(x_indices), np.max(x_indices))
279
+ y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
280
+ x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
281
+
282
+ insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
283
+ insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
284
+ insertion_mask[insertion_mask != 0] = 255
285
+ prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
286
+ "uint8"
287
+ )
288
+ # if np.sum(insertion_mask) == 0:
289
+ n_mask = insertion_mask[1:-1, 1:-1, :]
290
+ n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
291
+ print(n_mask.shape)
292
+ x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
293
+ if w < 4 or h < 4:
294
+ blended = prior
295
+ else:
296
+ blended = cv2.seamlessClone(
297
+ insertion, # pylint: disable=no-member
298
+ prior,
299
+ insertion_mask,
300
+ (x_center, y_center),
301
+ cv2.NORMAL_CLONE,
302
+ ) # pylint: disable=no-member
303
+
304
+ blended = blended[height:-height, width:-width]
305
+
306
+ return blended.astype("float32") / 255.0
307
+
308
+
309
+ def get_landmark(face_landmarks, id):
310
+ part = face_landmarks.part(id)
311
+ x = part.x
312
+ y = part.y
313
+
314
+ return (x, y)
315
+
316
+
317
+ def search(face_landmarks):
318
+
319
+ x1, y1 = get_landmark(face_landmarks, 36)
320
+ x2, y2 = get_landmark(face_landmarks, 39)
321
+ x3, y3 = get_landmark(face_landmarks, 42)
322
+ x4, y4 = get_landmark(face_landmarks, 45)
323
+
324
+ x_nose, y_nose = get_landmark(face_landmarks, 30)
325
+
326
+ x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
327
+ x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
328
+
329
+ x_left_eye = int((x1 + x2) / 2)
330
+ y_left_eye = int((y1 + y2) / 2)
331
+ x_right_eye = int((x3 + x4) / 2)
332
+ y_right_eye = int((y3 + y4) / 2)
333
+
334
+ results = np.array(
335
+ [
336
+ [x_left_eye, y_left_eye],
337
+ [x_right_eye, y_right_eye],
338
+ [x_nose, y_nose],
339
+ [x_left_mouth, y_left_mouth],
340
+ [x_right_mouth, y_right_mouth],
341
+ ]
342
+ )
343
+
344
+ return results
345
+
346
+
347
+ if __name__ == "__main__":
348
+
349
+ parser = argparse.ArgumentParser()
350
+ parser.add_argument("--origin_url", type=str, default="./", help="origin images")
351
+ parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
352
+ parser.add_argument("--save_url", type=str, default="./save")
353
+ opts = parser.parse_args()
354
+
355
+ origin_url = opts.origin_url
356
+ replace_url = opts.replace_url
357
+ save_url = opts.save_url
358
+
359
+ if not os.path.exists(save_url):
360
+ os.makedirs(save_url)
361
+
362
+ face_detector = dlib.get_frontal_face_detector()
363
+ landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
364
+
365
+ count = 0
366
+
367
+ for x in os.listdir(origin_url):
368
+ img_url = os.path.join(origin_url, x)
369
+ pil_img = Image.open(img_url).convert("RGB")
370
+
371
+ origin_width, origin_height = pil_img.size
372
+ image = np.array(pil_img)
373
+
374
+ start = time.time()
375
+ faces = face_detector(image)
376
+ done = time.time()
377
+
378
+ if len(faces) == 0:
379
+ print("Warning: There is no face in %s" % (x))
380
+ continue
381
+
382
+ blended = image
383
+ for face_id in range(len(faces)):
384
+
385
+ current_face = faces[face_id]
386
+ face_landmarks = landmark_locator(image, current_face)
387
+ current_fl = search(face_landmarks)
388
+
389
+ forward_mask = np.ones_like(image).astype("uint8")
390
+ affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
391
+ aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True)
392
+ forward_mask = warp(
393
+ forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True
394
+ )
395
+
396
+ affine_inverse = affine.inverse
397
+ cur_face = aligned_face
398
+ if replace_url != "":
399
+
400
+ face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
401
+ cur_url = os.path.join(replace_url, face_name)
402
+ restored_face = Image.open(cur_url).convert("RGB")
403
+ restored_face = np.array(restored_face)
404
+ cur_face = restored_face
405
+
406
+ ## Histogram Color matching
407
+ A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
408
+ B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
409
+ B = match_histograms(B, A)
410
+ cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)
411
+
412
+ warped_back = warp(
413
+ cur_face,
414
+ affine_inverse,
415
+ output_shape=(origin_height, origin_width, 3),
416
+ order=3,
417
+ preserve_range=True,
418
+ )
419
+
420
+ backward_mask = warp(
421
+ forward_mask,
422
+ affine_inverse,
423
+ output_shape=(origin_height, origin_width, 3),
424
+ order=0,
425
+ preserve_range=True,
426
+ ) ## Nearest neighbour
427
+
428
+ blended = blur_blending_cv2(warped_back, blended, backward_mask)
429
+ blended *= 255.0
430
+
431
+ io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))
432
+
433
+ count += 1
434
+
435
+ if count % 1000 == 0:
436
+ print("%d have finished ..." % (count))
437
+
Face_Detection/align_warp_back_multiple_dlib_HR.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import numpy as np
6
+ import skimage.io as io
7
+
8
+ # from face_sdk import FaceDetection
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.patches import Rectangle
11
+ from skimage.transform import SimilarityTransform
12
+ from skimage.transform import warp
13
+ from PIL import Image, ImageFilter
14
+ import torch.nn.functional as F
15
+ import torchvision as tv
16
+ import torchvision.utils as vutils
17
+ import time
18
+ import cv2
19
+ import os
20
+ from skimage import img_as_ubyte
21
+ import json
22
+ import argparse
23
+ import dlib
24
+
25
+
26
+ def calculate_cdf(histogram):
27
+ """
28
+ This method calculates the cumulative distribution function
29
+ :param array histogram: The values of the histogram
30
+ :return: normalized_cdf: The normalized cumulative distribution function
31
+ :rtype: array
32
+ """
33
+ # Get the cumulative sum of the elements
34
+ cdf = histogram.cumsum()
35
+
36
+ # Normalize the cdf
37
+ normalized_cdf = cdf / float(cdf.max())
38
+
39
+ return normalized_cdf
40
+
41
+
42
+ def calculate_lookup(src_cdf, ref_cdf):
43
+ """
44
+ This method creates the lookup table
45
+ :param array src_cdf: The cdf for the source image
46
+ :param array ref_cdf: The cdf for the reference image
47
+ :return: lookup_table: The lookup table
48
+ :rtype: array
49
+ """
50
+ lookup_table = np.zeros(256)
51
+ lookup_val = 0
52
+ for src_pixel_val in range(len(src_cdf)):
53
+ lookup_val
54
+ for ref_pixel_val in range(len(ref_cdf)):
55
+ if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
56
+ lookup_val = ref_pixel_val
57
+ break
58
+ lookup_table[src_pixel_val] = lookup_val
59
+ return lookup_table
60
+
61
+
62
+ def match_histograms(src_image, ref_image):
63
+ """
64
+ This method matches the source image histogram to the
65
+ reference signal
66
+ :param image src_image: The original source image
67
+ :param image ref_image: The reference image
68
+ :return: image_after_matching
69
+ :rtype: image (array)
70
+ """
71
+ # Split the images into the different color channels
72
+ # b means blue, g means green and r means red
73
+ src_b, src_g, src_r = cv2.split(src_image)
74
+ ref_b, ref_g, ref_r = cv2.split(ref_image)
75
+
76
+ # Compute the b, g, and r histograms separately
77
+ # The flatten() Numpy method returns a copy of the array c
78
+ # collapsed into one dimension.
79
+ src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
80
+ src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
81
+ src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
82
+ ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
83
+ ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
84
+ ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
85
+
86
+ # Compute the normalized cdf for the source and reference image
87
+ src_cdf_blue = calculate_cdf(src_hist_blue)
88
+ src_cdf_green = calculate_cdf(src_hist_green)
89
+ src_cdf_red = calculate_cdf(src_hist_red)
90
+ ref_cdf_blue = calculate_cdf(ref_hist_blue)
91
+ ref_cdf_green = calculate_cdf(ref_hist_green)
92
+ ref_cdf_red = calculate_cdf(ref_hist_red)
93
+
94
+ # Make a separate lookup table for each color
95
+ blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
96
+ green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
97
+ red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
98
+
99
+ # Use the lookup function to transform the colors of the original
100
+ # source image
101
+ blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
102
+ green_after_transform = cv2.LUT(src_g, green_lookup_table)
103
+ red_after_transform = cv2.LUT(src_r, red_lookup_table)
104
+
105
+ # Put the image back together
106
+ image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
107
+ image_after_matching = cv2.convertScaleAbs(image_after_matching)
108
+
109
+ return image_after_matching
110
+
111
+
112
+ def _standard_face_pts():
113
+ pts = (
114
+ np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
115
+ - 1.0
116
+ )
117
+
118
+ return np.reshape(pts, (5, 2))
119
+
120
+
121
+ def _origin_face_pts():
122
+ pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
123
+
124
+ return np.reshape(pts, (5, 2))
125
+
126
+
127
+ def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
128
+
129
+ std_pts = _standard_face_pts() # [-1,1]
130
+ target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
131
+
132
+ # print(target_pts)
133
+
134
+ h, w, c = img.shape
135
+ if normalize == True:
136
+ landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
137
+ landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
138
+
139
+ # print(landmark)
140
+
141
+ affine = SimilarityTransform()
142
+
143
+ affine.estimate(target_pts, landmark)
144
+
145
+ return affine
146
+
147
+
148
+ def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
149
+
150
+ std_pts = _standard_face_pts() # [-1,1]
151
+ target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
152
+
153
+ # print(target_pts)
154
+
155
+ h, w, c = img.shape
156
+ if normalize == True:
157
+ landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
158
+ landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
159
+
160
+ # print(landmark)
161
+
162
+ affine = SimilarityTransform()
163
+
164
+ affine.estimate(landmark, target_pts)
165
+
166
+ return affine
167
+
168
+
169
+ def show_detection(image, box, landmark):
170
+ plt.imshow(image)
171
+ print(box[2] - box[0])
172
+ plt.gca().add_patch(
173
+ Rectangle(
174
+ (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
175
+ )
176
+ )
177
+ plt.scatter(landmark[0][0], landmark[0][1])
178
+ plt.scatter(landmark[1][0], landmark[1][1])
179
+ plt.scatter(landmark[2][0], landmark[2][1])
180
+ plt.scatter(landmark[3][0], landmark[3][1])
181
+ plt.scatter(landmark[4][0], landmark[4][1])
182
+ plt.show()
183
+
184
+
185
+ def affine2theta(affine, input_w, input_h, target_w, target_h):
186
+ # param = np.linalg.inv(affine)
187
+ param = affine
188
+ theta = np.zeros([2, 3])
189
+ theta[0, 0] = param[0, 0] * input_h / target_h
190
+ theta[0, 1] = param[0, 1] * input_w / target_h
191
+ theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
192
+ theta[1, 0] = param[1, 0] * input_h / target_w
193
+ theta[1, 1] = param[1, 1] * input_w / target_w
194
+ theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
195
+ return theta
196
+
197
+
198
+ def blur_blending(im1, im2, mask):
199
+
200
+ mask *= 255.0
201
+
202
+ kernel = np.ones((10, 10), np.uint8)
203
+ mask = cv2.erode(mask, kernel, iterations=1)
204
+
205
+ mask = Image.fromarray(mask.astype("uint8")).convert("L")
206
+ im1 = Image.fromarray(im1.astype("uint8"))
207
+ im2 = Image.fromarray(im2.astype("uint8"))
208
+
209
+ mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
210
+ im = Image.composite(im1, im2, mask)
211
+
212
+ im = Image.composite(im, im2, mask_blur)
213
+
214
+ return np.array(im) / 255.0
215
+
216
+
217
+ def blur_blending_cv2(im1, im2, mask):
218
+
219
+ mask *= 255.0
220
+
221
+ kernel = np.ones((9, 9), np.uint8)
222
+ mask = cv2.erode(mask, kernel, iterations=3)
223
+
224
+ mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
225
+ mask_blur /= 255.0
226
+
227
+ im = im1 * mask_blur + (1 - mask_blur) * im2
228
+
229
+ im /= 255.0
230
+ im = np.clip(im, 0.0, 1.0)
231
+
232
+ return im
233
+
234
+
235
+ # def Poisson_blending(im1,im2,mask):
236
+
237
+
238
+ # Image.composite(
239
+ def Poisson_blending(im1, im2, mask):
240
+
241
+ # mask=1-mask
242
+ mask *= 255
243
+ kernel = np.ones((10, 10), np.uint8)
244
+ mask = cv2.erode(mask, kernel, iterations=1)
245
+ mask /= 255
246
+ mask = 1 - mask
247
+ mask *= 255
248
+
249
+ mask = mask[:, :, 0]
250
+ width, height, channels = im1.shape
251
+ center = (int(height / 2), int(width / 2))
252
+ result = cv2.seamlessClone(
253
+ im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
254
+ )
255
+
256
+ return result / 255.0
257
+
258
+
259
+ def Poisson_B(im1, im2, mask, center):
260
+
261
+ mask *= 255
262
+
263
+ result = cv2.seamlessClone(
264
+ im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
265
+ )
266
+
267
+ return result / 255
268
+
269
+
270
+ def seamless_clone(old_face, new_face, raw_mask):
271
+
272
+ height, width, _ = old_face.shape
273
+ height = height // 2
274
+ width = width // 2
275
+
276
+ y_indices, x_indices, _ = np.nonzero(raw_mask)
277
+ y_crop = slice(np.min(y_indices), np.max(y_indices))
278
+ x_crop = slice(np.min(x_indices), np.max(x_indices))
279
+ y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
280
+ x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
281
+
282
+ insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
283
+ insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
284
+ insertion_mask[insertion_mask != 0] = 255
285
+ prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
286
+ "uint8"
287
+ )
288
+ # if np.sum(insertion_mask) == 0:
289
+ n_mask = insertion_mask[1:-1, 1:-1, :]
290
+ n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
291
+ print(n_mask.shape)
292
+ x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
293
+ if w < 4 or h < 4:
294
+ blended = prior
295
+ else:
296
+ blended = cv2.seamlessClone(
297
+ insertion, # pylint: disable=no-member
298
+ prior,
299
+ insertion_mask,
300
+ (x_center, y_center),
301
+ cv2.NORMAL_CLONE,
302
+ ) # pylint: disable=no-member
303
+
304
+ blended = blended[height:-height, width:-width]
305
+
306
+ return blended.astype("float32") / 255.0
307
+
308
+
309
+ def get_landmark(face_landmarks, id):
310
+ part = face_landmarks.part(id)
311
+ x = part.x
312
+ y = part.y
313
+
314
+ return (x, y)
315
+
316
+
317
+ def search(face_landmarks):
318
+
319
+ x1, y1 = get_landmark(face_landmarks, 36)
320
+ x2, y2 = get_landmark(face_landmarks, 39)
321
+ x3, y3 = get_landmark(face_landmarks, 42)
322
+ x4, y4 = get_landmark(face_landmarks, 45)
323
+
324
+ x_nose, y_nose = get_landmark(face_landmarks, 30)
325
+
326
+ x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
327
+ x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
328
+
329
+ x_left_eye = int((x1 + x2) / 2)
330
+ y_left_eye = int((y1 + y2) / 2)
331
+ x_right_eye = int((x3 + x4) / 2)
332
+ y_right_eye = int((y3 + y4) / 2)
333
+
334
+ results = np.array(
335
+ [
336
+ [x_left_eye, y_left_eye],
337
+ [x_right_eye, y_right_eye],
338
+ [x_nose, y_nose],
339
+ [x_left_mouth, y_left_mouth],
340
+ [x_right_mouth, y_right_mouth],
341
+ ]
342
+ )
343
+
344
+ return results
345
+
346
+
347
+ if __name__ == "__main__":
348
+
349
+ parser = argparse.ArgumentParser()
350
+ parser.add_argument("--origin_url", type=str, default="./", help="origin images")
351
+ parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
352
+ parser.add_argument("--save_url", type=str, default="./save")
353
+ opts = parser.parse_args()
354
+
355
+ origin_url = opts.origin_url
356
+ replace_url = opts.replace_url
357
+ save_url = opts.save_url
358
+
359
+ if not os.path.exists(save_url):
360
+ os.makedirs(save_url)
361
+
362
+ face_detector = dlib.get_frontal_face_detector()
363
+ landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
364
+
365
+ count = 0
366
+
367
+ for x in os.listdir(origin_url):
368
+ img_url = os.path.join(origin_url, x)
369
+ pil_img = Image.open(img_url).convert("RGB")
370
+
371
+ origin_width, origin_height = pil_img.size
372
+ image = np.array(pil_img)
373
+
374
+ start = time.time()
375
+ faces = face_detector(image)
376
+ done = time.time()
377
+
378
+ if len(faces) == 0:
379
+ print("Warning: There is no face in %s" % (x))
380
+ continue
381
+
382
+ blended = image
383
+ for face_id in range(len(faces)):
384
+
385
+ current_face = faces[face_id]
386
+ face_landmarks = landmark_locator(image, current_face)
387
+ current_fl = search(face_landmarks)
388
+
389
+ forward_mask = np.ones_like(image).astype("uint8")
390
+ affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
391
+ aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True)
392
+ forward_mask = warp(
393
+ forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True
394
+ )
395
+
396
+ affine_inverse = affine.inverse
397
+ cur_face = aligned_face
398
+ if replace_url != "":
399
+
400
+ face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
401
+ cur_url = os.path.join(replace_url, face_name)
402
+ restored_face = Image.open(cur_url).convert("RGB")
403
+ restored_face = np.array(restored_face)
404
+ cur_face = restored_face
405
+
406
+ ## Histogram Color matching
407
+ A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
408
+ B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
409
+ B = match_histograms(B, A)
410
+ cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)
411
+
412
+ warped_back = warp(
413
+ cur_face,
414
+ affine_inverse,
415
+ output_shape=(origin_height, origin_width, 3),
416
+ order=3,
417
+ preserve_range=True,
418
+ )
419
+
420
+ backward_mask = warp(
421
+ forward_mask,
422
+ affine_inverse,
423
+ output_shape=(origin_height, origin_width, 3),
424
+ order=0,
425
+ preserve_range=True,
426
+ ) ## Nearest neighbour
427
+
428
+ blended = blur_blending_cv2(warped_back, blended, backward_mask)
429
+ blended *= 255.0
430
+
431
+ io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))
432
+
433
+ count += 1
434
+
435
+ if count % 1000 == 0:
436
+ print("%d have finished ..." % (count))
437
+
Face_Detection/detect_all_dlib.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import numpy as np
6
+ import skimage.io as io
7
+
8
+ # from FaceSDK.face_sdk import FaceDetection
9
+ # from face_sdk import FaceDetection
10
+ import matplotlib.pyplot as plt
11
+ from matplotlib.patches import Rectangle
12
+ from skimage.transform import SimilarityTransform
13
+ from skimage.transform import warp
14
+ from PIL import Image
15
+ import torch.nn.functional as F
16
+ import torchvision as tv
17
+ import torchvision.utils as vutils
18
+ import time
19
+ import cv2
20
+ import os
21
+ from skimage import img_as_ubyte
22
+ import json
23
+ import argparse
24
+ import dlib
25
+
26
+
27
+ def _standard_face_pts():
28
+ pts = (
29
+ np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
30
+ - 1.0
31
+ )
32
+
33
+ return np.reshape(pts, (5, 2))
34
+
35
+
36
+ def _origin_face_pts():
37
+ pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
38
+
39
+ return np.reshape(pts, (5, 2))
40
+
41
+
42
+ def get_landmark(face_landmarks, id):
43
+ part = face_landmarks.part(id)
44
+ x = part.x
45
+ y = part.y
46
+
47
+ return (x, y)
48
+
49
+
50
+ def search(face_landmarks):
51
+
52
+ x1, y1 = get_landmark(face_landmarks, 36)
53
+ x2, y2 = get_landmark(face_landmarks, 39)
54
+ x3, y3 = get_landmark(face_landmarks, 42)
55
+ x4, y4 = get_landmark(face_landmarks, 45)
56
+
57
+ x_nose, y_nose = get_landmark(face_landmarks, 30)
58
+
59
+ x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
60
+ x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
61
+
62
+ x_left_eye = int((x1 + x2) / 2)
63
+ y_left_eye = int((y1 + y2) / 2)
64
+ x_right_eye = int((x3 + x4) / 2)
65
+ y_right_eye = int((y3 + y4) / 2)
66
+
67
+ results = np.array(
68
+ [
69
+ [x_left_eye, y_left_eye],
70
+ [x_right_eye, y_right_eye],
71
+ [x_nose, y_nose],
72
+ [x_left_mouth, y_left_mouth],
73
+ [x_right_mouth, y_right_mouth],
74
+ ]
75
+ )
76
+
77
+ return results
78
+
79
+
80
+ def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
81
+
82
+ std_pts = _standard_face_pts() # [-1,1]
83
+ target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
84
+
85
+ # print(target_pts)
86
+
87
+ h, w, c = img.shape
88
+ if normalize == True:
89
+ landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
90
+ landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
91
+
92
+ # print(landmark)
93
+
94
+ affine = SimilarityTransform()
95
+
96
+ affine.estimate(target_pts, landmark)
97
+
98
+ return affine.params
99
+
100
+
101
+ def show_detection(image, box, landmark):
102
+ plt.imshow(image)
103
+ print(box[2] - box[0])
104
+ plt.gca().add_patch(
105
+ Rectangle(
106
+ (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
107
+ )
108
+ )
109
+ plt.scatter(landmark[0][0], landmark[0][1])
110
+ plt.scatter(landmark[1][0], landmark[1][1])
111
+ plt.scatter(landmark[2][0], landmark[2][1])
112
+ plt.scatter(landmark[3][0], landmark[3][1])
113
+ plt.scatter(landmark[4][0], landmark[4][1])
114
+ plt.show()
115
+
116
+
117
+ def affine2theta(affine, input_w, input_h, target_w, target_h):
118
+ # param = np.linalg.inv(affine)
119
+ param = affine
120
+ theta = np.zeros([2, 3])
121
+ theta[0, 0] = param[0, 0] * input_h / target_h
122
+ theta[0, 1] = param[0, 1] * input_w / target_h
123
+ theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
124
+ theta[1, 0] = param[1, 0] * input_h / target_w
125
+ theta[1, 1] = param[1, 1] * input_w / target_w
126
+ theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
127
+ return theta
128
+
129
+
130
+ if __name__ == "__main__":
131
+
132
+ parser = argparse.ArgumentParser()
133
+ parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
134
+ parser.add_argument(
135
+ "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
136
+ )
137
+ opts = parser.parse_args()
138
+
139
+ url = opts.url
140
+ save_url = opts.save_url
141
+
142
+ ### If the origin url is None, then we don't need to reid the origin image
143
+
144
+ os.makedirs(url, exist_ok=True)
145
+ os.makedirs(save_url, exist_ok=True)
146
+
147
+ face_detector = dlib.get_frontal_face_detector()
148
+ landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
149
+
150
+ count = 0
151
+
152
+ map_id = {}
153
+ for x in os.listdir(url):
154
+ img_url = os.path.join(url, x)
155
+ pil_img = Image.open(img_url).convert("RGB")
156
+
157
+ image = np.array(pil_img)
158
+
159
+ start = time.time()
160
+ faces = face_detector(image)
161
+ done = time.time()
162
+
163
+ if len(faces) == 0:
164
+ print("Warning: There is no face in %s" % (x))
165
+ continue
166
+
167
+ print(len(faces))
168
+
169
+ if len(faces) > 0:
170
+ for face_id in range(len(faces)):
171
+ current_face = faces[face_id]
172
+ face_landmarks = landmark_locator(image, current_face)
173
+ current_fl = search(face_landmarks)
174
+
175
+ affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
176
+ aligned_face = warp(image, affine, output_shape=(256, 256, 3))
177
+ img_name = x[:-4] + "_" + str(face_id + 1)
178
+ io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))
179
+
180
+ count += 1
181
+
182
+ if count % 1000 == 0:
183
+ print("%d have finished ..." % (count))
184
+
Face_Detection/detect_all_dlib_HR.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import numpy as np
6
+ import skimage.io as io
7
+
8
+ # from FaceSDK.face_sdk import FaceDetection
9
+ # from face_sdk import FaceDetection
10
+ import matplotlib.pyplot as plt
11
+ from matplotlib.patches import Rectangle
12
+ from skimage.transform import SimilarityTransform
13
+ from skimage.transform import warp
14
+ from PIL import Image
15
+ import torch.nn.functional as F
16
+ import torchvision as tv
17
+ import torchvision.utils as vutils
18
+ import time
19
+ import cv2
20
+ import os
21
+ from skimage import img_as_ubyte
22
+ import json
23
+ import argparse
24
+ import dlib
25
+
26
+
27
+ def _standard_face_pts():
28
+ pts = (
29
+ np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
30
+ - 1.0
31
+ )
32
+
33
+ return np.reshape(pts, (5, 2))
34
+
35
+
36
+ def _origin_face_pts():
37
+ pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
38
+
39
+ return np.reshape(pts, (5, 2))
40
+
41
+
42
+ def get_landmark(face_landmarks, id):
43
+ part = face_landmarks.part(id)
44
+ x = part.x
45
+ y = part.y
46
+
47
+ return (x, y)
48
+
49
+
50
+ def search(face_landmarks):
51
+
52
+ x1, y1 = get_landmark(face_landmarks, 36)
53
+ x2, y2 = get_landmark(face_landmarks, 39)
54
+ x3, y3 = get_landmark(face_landmarks, 42)
55
+ x4, y4 = get_landmark(face_landmarks, 45)
56
+
57
+ x_nose, y_nose = get_landmark(face_landmarks, 30)
58
+
59
+ x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
60
+ x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
61
+
62
+ x_left_eye = int((x1 + x2) / 2)
63
+ y_left_eye = int((y1 + y2) / 2)
64
+ x_right_eye = int((x3 + x4) / 2)
65
+ y_right_eye = int((y3 + y4) / 2)
66
+
67
+ results = np.array(
68
+ [
69
+ [x_left_eye, y_left_eye],
70
+ [x_right_eye, y_right_eye],
71
+ [x_nose, y_nose],
72
+ [x_left_mouth, y_left_mouth],
73
+ [x_right_mouth, y_right_mouth],
74
+ ]
75
+ )
76
+
77
+ return results
78
+
79
+
80
+ def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
81
+
82
+ std_pts = _standard_face_pts() # [-1,1]
83
+ target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
84
+
85
+ # print(target_pts)
86
+
87
+ h, w, c = img.shape
88
+ if normalize == True:
89
+ landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
90
+ landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
91
+
92
+ # print(landmark)
93
+
94
+ affine = SimilarityTransform()
95
+
96
+ affine.estimate(target_pts, landmark)
97
+
98
+ return affine.params
99
+
100
+
101
+ def show_detection(image, box, landmark):
102
+ plt.imshow(image)
103
+ print(box[2] - box[0])
104
+ plt.gca().add_patch(
105
+ Rectangle(
106
+ (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
107
+ )
108
+ )
109
+ plt.scatter(landmark[0][0], landmark[0][1])
110
+ plt.scatter(landmark[1][0], landmark[1][1])
111
+ plt.scatter(landmark[2][0], landmark[2][1])
112
+ plt.scatter(landmark[3][0], landmark[3][1])
113
+ plt.scatter(landmark[4][0], landmark[4][1])
114
+ plt.show()
115
+
116
+
117
+ def affine2theta(affine, input_w, input_h, target_w, target_h):
118
+ # param = np.linalg.inv(affine)
119
+ param = affine
120
+ theta = np.zeros([2, 3])
121
+ theta[0, 0] = param[0, 0] * input_h / target_h
122
+ theta[0, 1] = param[0, 1] * input_w / target_h
123
+ theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
124
+ theta[1, 0] = param[1, 0] * input_h / target_w
125
+ theta[1, 1] = param[1, 1] * input_w / target_w
126
+ theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
127
+ return theta
128
+
129
+
130
+ if __name__ == "__main__":
131
+
132
+ parser = argparse.ArgumentParser()
133
+ parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
134
+ parser.add_argument(
135
+ "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
136
+ )
137
+ opts = parser.parse_args()
138
+
139
+ url = opts.url
140
+ save_url = opts.save_url
141
+
142
+ ### If the origin url is None, then we don't need to reid the origin image
143
+
144
+ os.makedirs(url, exist_ok=True)
145
+ os.makedirs(save_url, exist_ok=True)
146
+
147
+ face_detector = dlib.get_frontal_face_detector()
148
+ landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
149
+
150
+ count = 0
151
+
152
+ map_id = {}
153
+ for x in os.listdir(url):
154
+ img_url = os.path.join(url, x)
155
+ pil_img = Image.open(img_url).convert("RGB")
156
+
157
+ image = np.array(pil_img)
158
+
159
+ start = time.time()
160
+ faces = face_detector(image)
161
+ done = time.time()
162
+
163
+ if len(faces) == 0:
164
+ print("Warning: There is no face in %s" % (x))
165
+ continue
166
+
167
+ print(len(faces))
168
+
169
+ if len(faces) > 0:
170
+ for face_id in range(len(faces)):
171
+ current_face = faces[face_id]
172
+ face_landmarks = landmark_locator(image, current_face)
173
+ current_fl = search(face_landmarks)
174
+
175
+ affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
176
+ aligned_face = warp(image, affine, output_shape=(512, 512, 3))
177
+ img_name = x[:-4] + "_" + str(face_id + 1)
178
+ io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))
179
+
180
+ count += 1
181
+
182
+ if count % 1000 == 0:
183
+ print("%d have finished ..." % (count))
184
+
Face_Enhancement/data/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import importlib
5
+ import torch.utils.data
6
+ from data.base_dataset import BaseDataset
7
+ from data.face_dataset import FaceTestDataset
8
+
9
+
10
+ def create_dataloader(opt):
11
+
12
+ instance = FaceTestDataset()
13
+ instance.initialize(opt)
14
+ print("dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
15
+ dataloader = torch.utils.data.DataLoader(
16
+ instance,
17
+ batch_size=opt.batchSize,
18
+ shuffle=not opt.serial_batches,
19
+ num_workers=int(opt.nThreads),
20
+ drop_last=opt.isTrain,
21
+ )
22
+ return dataloader
Face_Enhancement/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch.utils.data as data
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import numpy as np
8
+ import random
9
+
10
+
11
+ class BaseDataset(data.Dataset):
12
+ def __init__(self):
13
+ super(BaseDataset, self).__init__()
14
+
15
+ @staticmethod
16
+ def modify_commandline_options(parser, is_train):
17
+ return parser
18
+
19
+ def initialize(self, opt):
20
+ pass
21
+
22
+
23
+ def get_params(opt, size):
24
+ w, h = size
25
+ new_h = h
26
+ new_w = w
27
+ if opt.preprocess_mode == "resize_and_crop":
28
+ new_h = new_w = opt.load_size
29
+ elif opt.preprocess_mode == "scale_width_and_crop":
30
+ new_w = opt.load_size
31
+ new_h = opt.load_size * h // w
32
+ elif opt.preprocess_mode == "scale_shortside_and_crop":
33
+ ss, ls = min(w, h), max(w, h) # shortside and longside
34
+ width_is_shorter = w == ss
35
+ ls = int(opt.load_size * ls / ss)
36
+ new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
37
+
38
+ x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
39
+ y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
40
+
41
+ flip = random.random() > 0.5
42
+ return {"crop_pos": (x, y), "flip": flip}
43
+
44
+
45
+ def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
46
+ transform_list = []
47
+ if "resize" in opt.preprocess_mode:
48
+ osize = [opt.load_size, opt.load_size]
49
+ transform_list.append(transforms.Resize(osize, interpolation=method))
50
+ elif "scale_width" in opt.preprocess_mode:
51
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
52
+ elif "scale_shortside" in opt.preprocess_mode:
53
+ transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
54
+
55
+ if "crop" in opt.preprocess_mode:
56
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))
57
+
58
+ if opt.preprocess_mode == "none":
59
+ base = 32
60
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
61
+
62
+ if opt.preprocess_mode == "fixed":
63
+ w = opt.crop_size
64
+ h = round(opt.crop_size / opt.aspect_ratio)
65
+ transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
66
+
67
+ if opt.isTrain and not opt.no_flip:
68
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))
69
+
70
+ if toTensor:
71
+ transform_list += [transforms.ToTensor()]
72
+
73
+ if normalize:
74
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
75
+ return transforms.Compose(transform_list)
76
+
77
+
78
+ def normalize():
79
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
80
+
81
+
82
+ def __resize(img, w, h, method=Image.BICUBIC):
83
+ return img.resize((w, h), method)
84
+
85
+
86
+ def __make_power_2(img, base, method=Image.BICUBIC):
87
+ ow, oh = img.size
88
+ h = int(round(oh / base) * base)
89
+ w = int(round(ow / base) * base)
90
+ if (h == oh) and (w == ow):
91
+ return img
92
+ return img.resize((w, h), method)
93
+
94
+
95
+ def __scale_width(img, target_width, method=Image.BICUBIC):
96
+ ow, oh = img.size
97
+ if ow == target_width:
98
+ return img
99
+ w = target_width
100
+ h = int(target_width * oh / ow)
101
+ return img.resize((w, h), method)
102
+
103
+
104
+ def __scale_shortside(img, target_width, method=Image.BICUBIC):
105
+ ow, oh = img.size
106
+ ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
107
+ width_is_shorter = ow == ss
108
+ if ss == target_width:
109
+ return img
110
+ ls = int(target_width * ls / ss)
111
+ nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
112
+ return img.resize((nw, nh), method)
113
+
114
+
115
+ def __crop(img, pos, size):
116
+ ow, oh = img.size
117
+ x1, y1 = pos
118
+ tw = th = size
119
+ return img.crop((x1, y1, x1 + tw, y1 + th))
120
+
121
+
122
+ def __flip(img, flip):
123
+ if flip:
124
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
125
+ return img
Face_Enhancement/data/custom_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from data.pix2pix_dataset import Pix2pixDataset
5
+ from data.image_folder import make_dataset
6
+
7
+
8
+ class CustomDataset(Pix2pixDataset):
9
+ """ Dataset that loads images from directories
10
+ Use option --label_dir, --image_dir, --instance_dir to specify the directories.
11
+ The images in the directories are sorted in alphabetical order and paired in order.
12
+ """
13
+
14
+ @staticmethod
15
+ def modify_commandline_options(parser, is_train):
16
+ parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
17
+ parser.set_defaults(preprocess_mode="resize_and_crop")
18
+ load_size = 286 if is_train else 256
19
+ parser.set_defaults(load_size=load_size)
20
+ parser.set_defaults(crop_size=256)
21
+ parser.set_defaults(display_winsize=256)
22
+ parser.set_defaults(label_nc=13)
23
+ parser.set_defaults(contain_dontcare_label=False)
24
+
25
+ parser.add_argument(
26
+ "--label_dir", type=str, required=True, help="path to the directory that contains label images"
27
+ )
28
+ parser.add_argument(
29
+ "--image_dir", type=str, required=True, help="path to the directory that contains photo images"
30
+ )
31
+ parser.add_argument(
32
+ "--instance_dir",
33
+ type=str,
34
+ default="",
35
+ help="path to the directory that contains instance maps. Leave black if not exists",
36
+ )
37
+ return parser
38
+
39
+ def get_paths(self, opt):
40
+ label_dir = opt.label_dir
41
+ label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
42
+
43
+ image_dir = opt.image_dir
44
+ image_paths = make_dataset(image_dir, recursive=False, read_cache=True)
45
+
46
+ if len(opt.instance_dir) > 0:
47
+ instance_dir = opt.instance_dir
48
+ instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
49
+ else:
50
+ instance_paths = []
51
+
52
+ assert len(label_paths) == len(
53
+ image_paths
54
+ ), "The #images in %s and %s do not match. Is there something wrong?"
55
+
56
+ return label_paths, image_paths, instance_paths
Face_Enhancement/data/face_dataset.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from data.base_dataset import BaseDataset, get_params, get_transform
5
+ from PIL import Image
6
+ import util.util as util
7
+ import os
8
+ import torch
9
+
10
+
11
+ class FaceTestDataset(BaseDataset):
12
+ @staticmethod
13
+ def modify_commandline_options(parser, is_train):
14
+ parser.add_argument(
15
+ "--no_pairing_check",
16
+ action="store_true",
17
+ help="If specified, skip sanity check of correct label-image file pairing",
18
+ )
19
+ # parser.set_defaults(contain_dontcare_label=False)
20
+ # parser.set_defaults(no_instance=True)
21
+ return parser
22
+
23
+ def initialize(self, opt):
24
+ self.opt = opt
25
+
26
+ image_path = os.path.join(opt.dataroot, opt.old_face_folder)
27
+ label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)
28
+
29
+ image_list = os.listdir(image_path)
30
+ image_list = sorted(image_list)
31
+ # image_list=image_list[:opt.max_dataset_size]
32
+
33
+ self.label_paths = label_path ## Just the root dir
34
+ self.image_paths = image_list ## All the image name
35
+
36
+ self.parts = [
37
+ "skin",
38
+ "hair",
39
+ "l_brow",
40
+ "r_brow",
41
+ "l_eye",
42
+ "r_eye",
43
+ "eye_g",
44
+ "l_ear",
45
+ "r_ear",
46
+ "ear_r",
47
+ "nose",
48
+ "mouth",
49
+ "u_lip",
50
+ "l_lip",
51
+ "neck",
52
+ "neck_l",
53
+ "cloth",
54
+ "hat",
55
+ ]
56
+
57
+ size = len(self.image_paths)
58
+ self.dataset_size = size
59
+
60
+ def __getitem__(self, index):
61
+
62
+ params = get_params(self.opt, (-1, -1))
63
+ image_name = self.image_paths[index]
64
+ image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)
65
+ image = Image.open(image_path)
66
+ image = image.convert("RGB")
67
+
68
+ transform_image = get_transform(self.opt, params)
69
+ image_tensor = transform_image(image)
70
+
71
+ img_name = image_name[:-4]
72
+ transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
73
+ full_label = []
74
+
75
+ cnt = 0
76
+
77
+ for each_part in self.parts:
78
+ part_name = img_name + "_" + each_part + ".png"
79
+ part_url = os.path.join(self.label_paths, part_name)
80
+
81
+ if os.path.exists(part_url):
82
+ label = Image.open(part_url).convert("RGB")
83
+ label_tensor = transform_label(label) ## 3 channels and pixel [0,1]
84
+ full_label.append(label_tensor[0])
85
+ else:
86
+ current_part = torch.zeros((self.opt.load_size, self.opt.load_size))
87
+ full_label.append(current_part)
88
+ cnt += 1
89
+
90
+ full_label_tensor = torch.stack(full_label, 0)
91
+
92
+ input_dict = {
93
+ "label": full_label_tensor,
94
+ "image": image_tensor,
95
+ "path": image_path,
96
+ }
97
+
98
+ return input_dict
99
+
100
+ def __len__(self):
101
+ return self.dataset_size
102
+
Face_Enhancement/data/image_folder.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch.utils.data as data
5
+ from PIL import Image
6
+ import os
7
+
8
+ IMG_EXTENSIONS = [
9
+ ".jpg",
10
+ ".JPG",
11
+ ".jpeg",
12
+ ".JPEG",
13
+ ".png",
14
+ ".PNG",
15
+ ".ppm",
16
+ ".PPM",
17
+ ".bmp",
18
+ ".BMP",
19
+ ".tiff",
20
+ ".webp",
21
+ ]
22
+
23
+
24
+ def is_image_file(filename):
25
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
26
+
27
+
28
+ def make_dataset_rec(dir, images):
29
+ assert os.path.isdir(dir), "%s is not a valid directory" % dir
30
+
31
+ for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
32
+ for fname in fnames:
33
+ if is_image_file(fname):
34
+ path = os.path.join(root, fname)
35
+ images.append(path)
36
+
37
+
38
+ def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
39
+ images = []
40
+
41
+ if read_cache:
42
+ possible_filelist = os.path.join(dir, "files.list")
43
+ if os.path.isfile(possible_filelist):
44
+ with open(possible_filelist, "r") as f:
45
+ images = f.read().splitlines()
46
+ return images
47
+
48
+ if recursive:
49
+ make_dataset_rec(dir, images)
50
+ else:
51
+ assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir
52
+
53
+ for root, dnames, fnames in sorted(os.walk(dir)):
54
+ for fname in fnames:
55
+ if is_image_file(fname):
56
+ path = os.path.join(root, fname)
57
+ images.append(path)
58
+
59
+ if write_cache:
60
+ filelist_cache = os.path.join(dir, "files.list")
61
+ with open(filelist_cache, "w") as f:
62
+ for path in images:
63
+ f.write("%s\n" % path)
64
+ print("wrote filelist cache at %s" % filelist_cache)
65
+
66
+ return images
67
+
68
+
69
+ def default_loader(path):
70
+ return Image.open(path).convert("RGB")
71
+
72
+
73
+ class ImageFolder(data.Dataset):
74
+ def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
75
+ imgs = make_dataset(root)
76
+ if len(imgs) == 0:
77
+ raise (
78
+ RuntimeError(
79
+ "Found 0 images in: " + root + "\n"
80
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
81
+ )
82
+ )
83
+
84
+ self.root = root
85
+ self.imgs = imgs
86
+ self.transform = transform
87
+ self.return_paths = return_paths
88
+ self.loader = loader
89
+
90
+ def __getitem__(self, index):
91
+ path = self.imgs[index]
92
+ img = self.loader(path)
93
+ if self.transform is not None:
94
+ img = self.transform(img)
95
+ if self.return_paths:
96
+ return img, path
97
+ else:
98
+ return img
99
+
100
+ def __len__(self):
101
+ return len(self.imgs)
Face_Enhancement/data/pix2pix_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from data.base_dataset import BaseDataset, get_params, get_transform
5
+ from PIL import Image
6
+ import util.util as util
7
+ import os
8
+
9
+
10
+ class Pix2pixDataset(BaseDataset):
11
+ @staticmethod
12
+ def modify_commandline_options(parser, is_train):
13
+ parser.add_argument(
14
+ "--no_pairing_check",
15
+ action="store_true",
16
+ help="If specified, skip sanity check of correct label-image file pairing",
17
+ )
18
+ return parser
19
+
20
+ def initialize(self, opt):
21
+ self.opt = opt
22
+
23
+ label_paths, image_paths, instance_paths = self.get_paths(opt)
24
+
25
+ util.natural_sort(label_paths)
26
+ util.natural_sort(image_paths)
27
+ if not opt.no_instance:
28
+ util.natural_sort(instance_paths)
29
+
30
+ label_paths = label_paths[: opt.max_dataset_size]
31
+ image_paths = image_paths[: opt.max_dataset_size]
32
+ instance_paths = instance_paths[: opt.max_dataset_size]
33
+
34
+ if not opt.no_pairing_check:
35
+ for path1, path2 in zip(label_paths, image_paths):
36
+ assert self.paths_match(path1, path2), (
37
+ "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
38
+ % (path1, path2)
39
+ )
40
+
41
+ self.label_paths = label_paths
42
+ self.image_paths = image_paths
43
+ self.instance_paths = instance_paths
44
+
45
+ size = len(self.label_paths)
46
+ self.dataset_size = size
47
+
48
+ def get_paths(self, opt):
49
+ label_paths = []
50
+ image_paths = []
51
+ instance_paths = []
52
+ assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
53
+ return label_paths, image_paths, instance_paths
54
+
55
+ def paths_match(self, path1, path2):
56
+ filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
57
+ filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
58
+ return filename1_without_ext == filename2_without_ext
59
+
60
+ def __getitem__(self, index):
61
+ # Label Image
62
+ label_path = self.label_paths[index]
63
+ label = Image.open(label_path)
64
+ params = get_params(self.opt, label.size)
65
+ transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
66
+ label_tensor = transform_label(label) * 255.0
67
+ label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc
68
+
69
+ # input image (real images)
70
+ image_path = self.image_paths[index]
71
+ assert self.paths_match(
72
+ label_path, image_path
73
+ ), "The label_path %s and image_path %s don't match." % (label_path, image_path)
74
+ image = Image.open(image_path)
75
+ image = image.convert("RGB")
76
+
77
+ transform_image = get_transform(self.opt, params)
78
+ image_tensor = transform_image(image)
79
+
80
+ # if using instance maps
81
+ if self.opt.no_instance:
82
+ instance_tensor = 0
83
+ else:
84
+ instance_path = self.instance_paths[index]
85
+ instance = Image.open(instance_path)
86
+ if instance.mode == "L":
87
+ instance_tensor = transform_label(instance) * 255
88
+ instance_tensor = instance_tensor.long()
89
+ else:
90
+ instance_tensor = transform_label(instance)
91
+
92
+ input_dict = {
93
+ "label": label_tensor,
94
+ "instance": instance_tensor,
95
+ "image": image_tensor,
96
+ "path": image_path,
97
+ }
98
+
99
+ # Give subclasses a chance to modify the final output
100
+ self.postprocess(input_dict)
101
+
102
+ return input_dict
103
+
104
+ def postprocess(self, input_dict):
105
+ return input_dict
106
+
107
+ def __len__(self):
108
+ return self.dataset_size
Face_Enhancement/models/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import importlib
5
+ import torch
6
+
7
+
8
+ def find_model_using_name(model_name):
9
+ # Given the option --model [modelname],
10
+ # the file "models/modelname_model.py"
11
+ # will be imported.
12
+ model_filename = "models." + model_name + "_model"
13
+ modellib = importlib.import_module(model_filename)
14
+
15
+ # In the file, the class called ModelNameModel() will
16
+ # be instantiated. It has to be a subclass of torch.nn.Module,
17
+ # and it is case-insensitive.
18
+ model = None
19
+ target_model_name = model_name.replace("_", "") + "model"
20
+ for name, cls in modellib.__dict__.items():
21
+ if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module):
22
+ model = cls
23
+
24
+ if model is None:
25
+ print(
26
+ "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
27
+ % (model_filename, target_model_name)
28
+ )
29
+ exit(0)
30
+
31
+ return model
32
+
33
+
34
+ def get_option_setter(model_name):
35
+ model_class = find_model_using_name(model_name)
36
+ return model_class.modify_commandline_options
37
+
38
+
39
+ def create_model(opt):
40
+ model = find_model_using_name(opt.model)
41
+ instance = model(opt)
42
+ print("model [%s] was created" % (type(instance).__name__))
43
+
44
+ return instance
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2018 Jiayuan MAO
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Synchronized-BatchNorm-PyTorch
2
+
3
+ **IMPORTANT: Please read the "Implementation details and highlights" section before use.**
4
+
5
+ Synchronized Batch Normalization implementation in PyTorch.
6
+
7
+ This module differs from the built-in PyTorch BatchNorm as the mean and
8
+ standard-deviation are reduced across all devices during training.
9
+
10
+ For example, when one uses `nn.DataParallel` to wrap the network during
11
+ training, PyTorch's implementation normalize the tensor on each device using
12
+ the statistics only on that device, which accelerated the computation and
13
+ is also easy to implement, but the statistics might be inaccurate.
14
+ Instead, in this synchronized version, the statistics will be computed
15
+ over all training samples distributed on multiple devices.
16
+
17
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
18
+ as the built-in PyTorch implementation.
19
+
20
+ This module is currently only a prototype version for research usages. As mentioned below,
21
+ it has its limitations and may even suffer from some design problems. If you have any
22
+ questions or suggestions, please feel free to
23
+ [open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or
24
+ [submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues).
25
+
26
+ ## Why Synchronized BatchNorm?
27
+
28
+ Although the typical implementation of BatchNorm working on multiple devices (GPUs)
29
+ is fast (with no communication overhead), it inevitably reduces the size of batch size,
30
+ which potentially degenerates the performance. This is not a significant issue in some
31
+ standard vision tasks such as ImageNet classification (as the batch size per device
32
+ is usually large enough to obtain good statistics). However, it will hurt the performance
33
+ in some tasks that the batch size is usually very small (e.g., 1 per GPU).
34
+
35
+ For example, the importance of synchronized batch normalization in object detection has been recently proved with a
36
+ an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240).
37
+
38
+ ## Usage
39
+
40
+ To use the Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight
41
+ difference with typical usage of the `nn.DataParallel`.
42
+
43
+ Use it with a provided, customized data parallel wrapper:
44
+
45
+ ```python
46
+ from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback
47
+
48
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
49
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
50
+ ```
51
+
52
+ Or, if you are using a customized data parallel module, you can use this library as a monkey patching.
53
+
54
+ ```python
55
+ from torch.nn import DataParallel # or your customized DataParallel module
56
+ from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback
57
+
58
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
59
+ sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
60
+ patch_replication_callback(sync_bn) # monkey-patching
61
+ ```
62
+
63
+ You can use `convert_model` to convert your model to use Synchronized BatchNorm easily.
64
+
65
+ ```python
66
+ import torch.nn as nn
67
+ from torchvision import models
68
+ from sync_batchnorm import convert_model
69
+ # m is a standard pytorch model
70
+ m = models.resnet18(True)
71
+ m = nn.DataParallel(m)
72
+ # after convert, m is using SyncBN
73
+ m = convert_model(m)
74
+ ```
75
+
76
+ See also `tests/test_sync_batchnorm.py` for numeric result comparison.
77
+
78
+ ## Implementation details and highlights
79
+
80
+ If you are interested in how batch statistics are reduced and broadcasted among multiple devices, please take a look
81
+ at the code with detailed comments. Here we only emphasize some highlights of the implementation:
82
+
83
+ - This implementation is in pure-python. No C++ extra extension libs.
84
+ - Easy to use as demonstrated above.
85
+ - It uses unbiased variance to update the moving average, and use `sqrt(max(var, eps))` instead of `sqrt(var + eps)`.
86
+ - The implementation requires that each module on different devices should invoke the `batchnorm` for exactly SAME
87
+ amount of times in each forward pass. For example, you can not only call `batchnorm` on GPU0 but not on GPU1. The `#i
88
+ (i = 1, 2, 3, ...)` calls of the `batchnorm` on each device will be viewed as a whole and the statistics will be reduced.
89
+ This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although sounds complicated, this
90
+ will usually not be the issue for most of the models.
91
+
92
+ ## Known issues
93
+
94
+ #### Runtime error on backward pass.
95
+
96
+ Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger an `RuntimeError` with messages like:
97
+
98
+ ```
99
+ Assertion `pos >= 0 && pos < buffer.size()` failed.
100
+ ```
101
+
102
+ This has already been solved in the newest PyTorch repo, which, unfortunately, has not been pushed to the official and anaconda binary release. Thus, you are required to build the PyTorch package from the source according to the
103
+ instructions [here](https://github.com/pytorch/pytorch#from-source).
104
+
105
+ #### Numeric error.
106
+
107
+ Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less
108
+ numerically stable compared to the original PyTorch implementation. Detailed analysis can be found in
109
+ `tests/test_sync_batchnorm.py`.
110
+
111
+ ## Authors and License:
112
+
113
+ Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz).
114
+
115
+ **Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant).
116
+
117
+ Distributed under **MIT License** (See LICENSE)
118
+
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import set_sbn_eps_mode
12
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
13
+ from .batchnorm import patch_sync_batchnorm, convert_model
14
+ from .replicate import DataParallelWithCallback, patch_replication_callback
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+ import contextlib
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from torch.nn.modules.batchnorm import _BatchNorm
18
+
19
+ try:
20
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
21
+ except ImportError:
22
+ ReduceAddCoalesced = Broadcast = None
23
+
24
+ try:
25
+ from jactorch.parallel.comm import SyncMaster
26
+ from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
27
+ except ImportError:
28
+ from .comm import SyncMaster
29
+ from .replicate import DataParallelWithCallback
30
+
31
+ __all__ = [
32
+ 'set_sbn_eps_mode',
33
+ 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
34
+ 'patch_sync_batchnorm', 'convert_model'
35
+ ]
36
+
37
+
38
+ SBN_EPS_MODE = 'clamp'
39
+
40
+
41
+ def set_sbn_eps_mode(mode):
42
+ global SBN_EPS_MODE
43
+ assert mode in ('clamp', 'plus')
44
+ SBN_EPS_MODE = mode
45
+
46
+
47
+ def _sum_ft(tensor):
48
+ """sum over the first and last dimention"""
49
+ return tensor.sum(dim=0).sum(dim=-1)
50
+
51
+
52
+ def _unsqueeze_ft(tensor):
53
+ """add new dimensions at the front and the tail"""
54
+ return tensor.unsqueeze(0).unsqueeze(-1)
55
+
56
+
57
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
58
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
59
+
60
+
61
+ class _SynchronizedBatchNorm(_BatchNorm):
62
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
63
+ assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
64
+
65
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
66
+ track_running_stats=track_running_stats)
67
+
68
+ if not self.track_running_stats:
69
+ import warnings
70
+ warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
71
+
72
+ self._sync_master = SyncMaster(self._data_parallel_master)
73
+
74
+ self._is_parallel = False
75
+ self._parallel_id = None
76
+ self._slave_pipe = None
77
+
78
+ def forward(self, input):
79
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
80
+ if not (self._is_parallel and self.training):
81
+ return F.batch_norm(
82
+ input, self.running_mean, self.running_var, self.weight, self.bias,
83
+ self.training, self.momentum, self.eps)
84
+
85
+ # Resize the input to (B, C, -1).
86
+ input_shape = input.size()
87
+ assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
88
+ input = input.view(input.size(0), self.num_features, -1)
89
+
90
+ # Compute the sum and square-sum.
91
+ sum_size = input.size(0) * input.size(2)
92
+ input_sum = _sum_ft(input)
93
+ input_ssum = _sum_ft(input ** 2)
94
+
95
+ # Reduce-and-broadcast the statistics.
96
+ if self._parallel_id == 0:
97
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
98
+ else:
99
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
100
+
101
+ # Compute the output.
102
+ if self.affine:
103
+ # MJY:: Fuse the multiplication for speed.
104
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
105
+ else:
106
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
107
+
108
+ # Reshape it.
109
+ return output.view(input_shape)
110
+
111
+ def __data_parallel_replicate__(self, ctx, copy_id):
112
+ self._is_parallel = True
113
+ self._parallel_id = copy_id
114
+
115
+ # parallel_id == 0 means master device.
116
+ if self._parallel_id == 0:
117
+ ctx.sync_master = self._sync_master
118
+ else:
119
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
120
+
121
+ def _data_parallel_master(self, intermediates):
122
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
123
+
124
+ # Always using same "device order" makes the ReduceAdd operation faster.
125
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
126
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
127
+
128
+ to_reduce = [i[1][:2] for i in intermediates]
129
+ to_reduce = [j for i in to_reduce for j in i] # flatten
130
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
131
+
132
+ sum_size = sum([i[1].sum_size for i in intermediates])
133
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
134
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
135
+
136
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
137
+
138
+ outputs = []
139
+ for i, rec in enumerate(intermediates):
140
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
141
+
142
+ return outputs
143
+
144
+ def _compute_mean_std(self, sum_, ssum, size):
145
+ """Compute the mean and standard-deviation with sum and square-sum. This method
146
+ also maintains the moving average on the master device."""
147
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
148
+ mean = sum_ / size
149
+ sumvar = ssum - sum_ * mean
150
+ unbias_var = sumvar / (size - 1)
151
+ bias_var = sumvar / size
152
+
153
+ if hasattr(torch, 'no_grad'):
154
+ with torch.no_grad():
155
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
156
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
157
+ else:
158
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
159
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
160
+
161
+ if SBN_EPS_MODE == 'clamp':
162
+ return mean, bias_var.clamp(self.eps) ** -0.5
163
+ elif SBN_EPS_MODE == 'plus':
164
+ return mean, (bias_var + self.eps) ** -0.5
165
+ else:
166
+ raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
167
+
168
+
169
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
170
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
171
+ mini-batch.
172
+
173
+ .. math::
174
+
175
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
176
+
177
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
178
+ standard-deviation are reduced across all devices during training.
179
+
180
+ For example, when one uses `nn.DataParallel` to wrap the network during
181
+ training, PyTorch's implementation normalize the tensor on each device using
182
+ the statistics only on that device, which accelerated the computation and
183
+ is also easy to implement, but the statistics might be inaccurate.
184
+ Instead, in this synchronized version, the statistics will be computed
185
+ over all training samples distributed on multiple devices.
186
+
187
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
188
+ as the built-in PyTorch implementation.
189
+
190
+ The mean and standard-deviation are calculated per-dimension over
191
+ the mini-batches and gamma and beta are learnable parameter vectors
192
+ of size C (where C is the input size).
193
+
194
+ During training, this layer keeps a running estimate of its computed mean
195
+ and variance. The running sum is kept with a default momentum of 0.1.
196
+
197
+ During evaluation, this running mean/variance is used for normalization.
198
+
199
+ Because the BatchNorm is done over the `C` dimension, computing statistics
200
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
201
+
202
+ Args:
203
+ num_features: num_features from an expected input of size
204
+ `batch_size x num_features [x width]`
205
+ eps: a value added to the denominator for numerical stability.
206
+ Default: 1e-5
207
+ momentum: the value used for the running_mean and running_var
208
+ computation. Default: 0.1
209
+ affine: a boolean value that when set to ``True``, gives the layer learnable
210
+ affine parameters. Default: ``True``
211
+
212
+ Shape::
213
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
214
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
215
+
216
+ Examples:
217
+ >>> # With Learnable Parameters
218
+ >>> m = SynchronizedBatchNorm1d(100)
219
+ >>> # Without Learnable Parameters
220
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
221
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
222
+ >>> output = m(input)
223
+ """
224
+
225
+ def _check_input_dim(self, input):
226
+ if input.dim() != 2 and input.dim() != 3:
227
+ raise ValueError('expected 2D or 3D input (got {}D input)'
228
+ .format(input.dim()))
229
+
230
+
231
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
232
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
233
+ of 3d inputs
234
+
235
+ .. math::
236
+
237
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
238
+
239
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
240
+ standard-deviation are reduced across all devices during training.
241
+
242
+ For example, when one uses `nn.DataParallel` to wrap the network during
243
+ training, PyTorch's implementation normalize the tensor on each device using
244
+ the statistics only on that device, which accelerated the computation and
245
+ is also easy to implement, but the statistics might be inaccurate.
246
+ Instead, in this synchronized version, the statistics will be computed
247
+ over all training samples distributed on multiple devices.
248
+
249
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
250
+ as the built-in PyTorch implementation.
251
+
252
+ The mean and standard-deviation are calculated per-dimension over
253
+ the mini-batches and gamma and beta are learnable parameter vectors
254
+ of size C (where C is the input size).
255
+
256
+ During training, this layer keeps a running estimate of its computed mean
257
+ and variance. The running sum is kept with a default momentum of 0.1.
258
+
259
+ During evaluation, this running mean/variance is used for normalization.
260
+
261
+ Because the BatchNorm is done over the `C` dimension, computing statistics
262
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
263
+
264
+ Args:
265
+ num_features: num_features from an expected input of
266
+ size batch_size x num_features x height x width
267
+ eps: a value added to the denominator for numerical stability.
268
+ Default: 1e-5
269
+ momentum: the value used for the running_mean and running_var
270
+ computation. Default: 0.1
271
+ affine: a boolean value that when set to ``True``, gives the layer learnable
272
+ affine parameters. Default: ``True``
273
+
274
+ Shape::
275
+ - Input: :math:`(N, C, H, W)`
276
+ - Output: :math:`(N, C, H, W)` (same shape as input)
277
+
278
+ Examples:
279
+ >>> # With Learnable Parameters
280
+ >>> m = SynchronizedBatchNorm2d(100)
281
+ >>> # Without Learnable Parameters
282
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
283
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
284
+ >>> output = m(input)
285
+ """
286
+
287
+ def _check_input_dim(self, input):
288
+ if input.dim() != 4:
289
+ raise ValueError('expected 4D input (got {}D input)'
290
+ .format(input.dim()))
291
+
292
+
293
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
294
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
295
+ of 4d inputs
296
+
297
+ .. math::
298
+
299
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
300
+
301
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
302
+ standard-deviation are reduced across all devices during training.
303
+
304
+ For example, when one uses `nn.DataParallel` to wrap the network during
305
+ training, PyTorch's implementation normalize the tensor on each device using
306
+ the statistics only on that device, which accelerated the computation and
307
+ is also easy to implement, but the statistics might be inaccurate.
308
+ Instead, in this synchronized version, the statistics will be computed
309
+ over all training samples distributed on multiple devices.
310
+
311
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
312
+ as the built-in PyTorch implementation.
313
+
314
+ The mean and standard-deviation are calculated per-dimension over
315
+ the mini-batches and gamma and beta are learnable parameter vectors
316
+ of size C (where C is the input size).
317
+
318
+ During training, this layer keeps a running estimate of its computed mean
319
+ and variance. The running sum is kept with a default momentum of 0.1.
320
+
321
+ During evaluation, this running mean/variance is used for normalization.
322
+
323
+ Because the BatchNorm is done over the `C` dimension, computing statistics
324
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
325
+ or Spatio-temporal BatchNorm
326
+
327
+ Args:
328
+ num_features: num_features from an expected input of
329
+ size batch_size x num_features x depth x height x width
330
+ eps: a value added to the denominator for numerical stability.
331
+ Default: 1e-5
332
+ momentum: the value used for the running_mean and running_var
333
+ computation. Default: 0.1
334
+ affine: a boolean value that when set to ``True``, gives the layer learnable
335
+ affine parameters. Default: ``True``
336
+
337
+ Shape::
338
+ - Input: :math:`(N, C, D, H, W)`
339
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
340
+
341
+ Examples:
342
+ >>> # With Learnable Parameters
343
+ >>> m = SynchronizedBatchNorm3d(100)
344
+ >>> # Without Learnable Parameters
345
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
346
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
347
+ >>> output = m(input)
348
+ """
349
+
350
+ def _check_input_dim(self, input):
351
+ if input.dim() != 5:
352
+ raise ValueError('expected 5D input (got {}D input)'
353
+ .format(input.dim()))
354
+
355
+
356
+ @contextlib.contextmanager
357
+ def patch_sync_batchnorm():
358
+ import torch.nn as nn
359
+
360
+ backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
361
+
362
+ nn.BatchNorm1d = SynchronizedBatchNorm1d
363
+ nn.BatchNorm2d = SynchronizedBatchNorm2d
364
+ nn.BatchNorm3d = SynchronizedBatchNorm3d
365
+
366
+ yield
367
+
368
+ nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
369
+
370
+
371
+ def convert_model(module):
372
+ """Traverse the input module and its child recursively
373
+ and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
374
+ to SynchronizedBatchNorm*N*d
375
+
376
+ Args:
377
+ module: the input module needs to be convert to SyncBN model
378
+
379
+ Examples:
380
+ >>> import torch.nn as nn
381
+ >>> import torchvision
382
+ >>> # m is a standard pytorch model
383
+ >>> m = torchvision.models.resnet18(True)
384
+ >>> m = nn.DataParallel(m)
385
+ >>> # after convert, m is using SyncBN
386
+ >>> m = convert_model(m)
387
+ """
388
+ if isinstance(module, torch.nn.DataParallel):
389
+ mod = module.module
390
+ mod = convert_model(mod)
391
+ mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
392
+ return mod
393
+
394
+ mod = module
395
+ for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
396
+ torch.nn.modules.batchnorm.BatchNorm2d,
397
+ torch.nn.modules.batchnorm.BatchNorm3d],
398
+ [SynchronizedBatchNorm1d,
399
+ SynchronizedBatchNorm2d,
400
+ SynchronizedBatchNorm3d]):
401
+ if isinstance(module, pth_module):
402
+ mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
403
+ mod.running_mean = module.running_mean
404
+ mod.running_var = module.running_var
405
+ if module.affine:
406
+ mod.weight.data = module.weight.data.clone().detach()
407
+ mod.bias.data = module.bias.data.clone().detach()
408
+
409
+ for name, child in module.named_children():
410
+ mod.add_module(name, convert_model(child))
411
+
412
+ return mod
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : batchnorm_reimpl.py
4
+ # Author : acgtyrant
5
+ # Date : 11/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.init as init
14
+
15
+ __all__ = ['BatchNorm2dReimpl']
16
+
17
+
18
+ class BatchNorm2dReimpl(nn.Module):
19
+ """
20
+ A re-implementation of batch normalization, used for testing the numerical
21
+ stability.
22
+
23
+ Author: acgtyrant
24
+ See also:
25
+ https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26
+ """
27
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
28
+ super().__init__()
29
+
30
+ self.num_features = num_features
31
+ self.eps = eps
32
+ self.momentum = momentum
33
+ self.weight = nn.Parameter(torch.empty(num_features))
34
+ self.bias = nn.Parameter(torch.empty(num_features))
35
+ self.register_buffer('running_mean', torch.zeros(num_features))
36
+ self.register_buffer('running_var', torch.ones(num_features))
37
+ self.reset_parameters()
38
+
39
+ def reset_running_stats(self):
40
+ self.running_mean.zero_()
41
+ self.running_var.fill_(1)
42
+
43
+ def reset_parameters(self):
44
+ self.reset_running_stats()
45
+ init.uniform_(self.weight)
46
+ init.zeros_(self.bias)
47
+
48
+ def forward(self, input_):
49
+ batchsize, channels, height, width = input_.size()
50
+ numel = batchsize * height * width
51
+ input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52
+ sum_ = input_.sum(1)
53
+ sum_of_square = input_.pow(2).sum(1)
54
+ mean = sum_ / numel
55
+ sumvar = sum_of_square - sum_ * mean
56
+
57
+ self.running_mean = (
58
+ (1 - self.momentum) * self.running_mean
59
+ + self.momentum * mean.detach()
60
+ )
61
+ unbias_var = sumvar / (numel - 1)
62
+ self.running_var = (
63
+ (1 - self.momentum) * self.running_var
64
+ + self.momentum * unbias_var.detach()
65
+ )
66
+
67
+ bias_var = sumvar / numel
68
+ inv_std = 1 / (bias_var + self.eps).pow(0.5)
69
+ output = (
70
+ (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71
+ self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72
+
73
+ return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74
+
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result has\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
60
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine to message passed
64
+ back to each slave devices.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register an slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages were first collected from each devices (including the master device), and then
106
+ an callback will be invoked to compute the message to be sent back to each devices
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master want to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+ import torch
13
+
14
+
15
+ class TorchTestCase(unittest.TestCase):
16
+ def assertTensorClose(self, x, y):
17
+ adiff = float((x - y).abs().max())
18
+ if (y == 0).all():
19
+ rdiff = 'NaN'
20
+ else:
21
+ rdiff = float((adiff / y).abs().max())
22
+
23
+ message = (
24
+ 'Tensor close check failed\n'
25
+ 'adiff={}\n'
26
+ 'rdiff={}\n'
27
+ ).format(adiff, rdiff)
28
+ self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
29
+
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_numeric_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm.unittest import TorchTestCase
16
+
17
+
18
+ def handy_var(a, unbias=True):
19
+ n = a.size(0)
20
+ asum = a.sum(dim=0)
21
+ as_sum = (a ** 2).sum(dim=0) # a square sum
22
+ sumvar = as_sum - asum * asum / n
23
+ if unbias:
24
+ return sumvar / (n - 1)
25
+ else:
26
+ return sumvar / n
27
+
28
+
29
+ class NumericTestCase(TorchTestCase):
30
+ def testNumericBatchNorm(self):
31
+ a = torch.rand(16, 10)
32
+ bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)
33
+ bn.train()
34
+
35
+ a_var1 = Variable(a, requires_grad=True)
36
+ b_var1 = bn(a_var1)
37
+ loss1 = b_var1.sum()
38
+ loss1.backward()
39
+
40
+ a_var2 = Variable(a, requires_grad=True)
41
+ a_mean2 = a_var2.mean(dim=0, keepdim=True)
42
+ a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43
+ # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44
+ b_var2 = (a_var2 - a_mean2) / a_std2
45
+ loss2 = b_var2.sum()
46
+ loss2.backward()
47
+
48
+ self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49
+ self.assertTensorClose(bn.running_var, handy_var(a))
50
+ self.assertTensorClose(a_var1.data, a_var2.data)
51
+ self.assertTensorClose(b_var1.data, b_var2.data)
52
+ self.assertTensorClose(a_var1.grad, a_var2.grad)
53
+
54
+
55
+ if __name__ == '__main__':
56
+ unittest.main()
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test_numeric_batchnorm_v2.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 11/01/2018
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ """
11
+ Test the numerical implementation of batch normalization.
12
+
13
+ Author: acgtyrant.
14
+ See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
15
+ """
16
+
17
+ import unittest
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+
23
+ from sync_batchnorm.unittest import TorchTestCase
24
+ from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl
25
+
26
+
27
+ class NumericTestCasev2(TorchTestCase):
28
+ def testNumericBatchNorm(self):
29
+ CHANNELS = 16
30
+ batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1)
31
+ optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01)
32
+
33
+ batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1)
34
+ batchnorm2.weight.data.copy_(batchnorm1.weight.data)
35
+ batchnorm2.bias.data.copy_(batchnorm1.bias.data)
36
+ optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01)
37
+
38
+ for _ in range(100):
39
+ input_ = torch.rand(16, CHANNELS, 16, 16)
40
+
41
+ input1 = input_.clone().requires_grad_(True)
42
+ output1 = batchnorm1(input1)
43
+ output1.sum().backward()
44
+ optimizer1.step()
45
+
46
+ input2 = input_.clone().requires_grad_(True)
47
+ output2 = batchnorm2(input2)
48
+ output2.sum().backward()
49
+ optimizer2.step()
50
+
51
+ self.assertTensorClose(input1, input2)
52
+ self.assertTensorClose(output1, output2)
53
+ self.assertTensorClose(input1.grad, input2.grad)
54
+ self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad)
55
+ self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad)
56
+ self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean)
57
+ self.assertTensorClose(batchnorm2.running_mean, batchnorm2.running_mean)
58
+
59
+
60
+ if __name__ == '__main__':
61
+ unittest.main()
62
+
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_sync_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm import set_sbn_eps_mode
16
+ from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
17
+ from sync_batchnorm.unittest import TorchTestCase
18
+
19
+ set_sbn_eps_mode('plus')
20
+
21
+
22
+ def handy_var(a, unbias=True):
23
+ n = a.size(0)
24
+ asum = a.sum(dim=0)
25
+ as_sum = (a ** 2).sum(dim=0) # a square sum
26
+ sumvar = as_sum - asum * asum / n
27
+ if unbias:
28
+ return sumvar / (n - 1)
29
+ else:
30
+ return sumvar / n
31
+
32
+
33
+ def _find_bn(module):
34
+ for m in module.modules():
35
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
36
+ return m
37
+
38
+
39
+ class SyncTestCase(TorchTestCase):
40
+ def _syncParameters(self, bn1, bn2):
41
+ bn1.reset_parameters()
42
+ bn2.reset_parameters()
43
+ if bn1.affine and bn2.affine:
44
+ bn2.weight.data.copy_(bn1.weight.data)
45
+ bn2.bias.data.copy_(bn1.bias.data)
46
+
47
+ def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
48
+ """Check the forward and backward for the customized batch normalization."""
49
+ bn1.train(mode=is_train)
50
+ bn2.train(mode=is_train)
51
+
52
+ if cuda:
53
+ input = input.cuda()
54
+
55
+ self._syncParameters(_find_bn(bn1), _find_bn(bn2))
56
+
57
+ input1 = Variable(input, requires_grad=True)
58
+ output1 = bn1(input1)
59
+ output1.sum().backward()
60
+ input2 = Variable(input, requires_grad=True)
61
+ output2 = bn2(input2)
62
+ output2.sum().backward()
63
+
64
+ self.assertTensorClose(input1.data, input2.data)
65
+ self.assertTensorClose(output1.data, output2.data)
66
+ self.assertTensorClose(input1.grad, input2.grad)
67
+ self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
68
+ self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
69
+
70
+ def testSyncBatchNormNormalTrain(self):
71
+ bn = nn.BatchNorm1d(10)
72
+ sync_bn = SynchronizedBatchNorm1d(10)
73
+
74
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
75
+
76
+ def testSyncBatchNormNormalEval(self):
77
+ bn = nn.BatchNorm1d(10)
78
+ sync_bn = SynchronizedBatchNorm1d(10)
79
+
80
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
81
+
82
+ def testSyncBatchNormSyncTrain(self):
83
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
84
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
85
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
86
+
87
+ bn.cuda()
88
+ sync_bn.cuda()
89
+
90
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
91
+
92
+ def testSyncBatchNormSyncEval(self):
93
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
94
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
95
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
96
+
97
+ bn.cuda()
98
+ sync_bn.cuda()
99
+
100
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
101
+
102
+ def testSyncBatchNorm2DSyncTrain(self):
103
+ bn = nn.BatchNorm2d(10)
104
+ sync_bn = SynchronizedBatchNorm2d(10)
105
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
106
+
107
+ bn.cuda()
108
+ sync_bn.cuda()
109
+
110
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
111
+
112
+
113
+ if __name__ == '__main__':
114
+ unittest.main()
Face_Enhancement/models/networks/__init__.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ from models.networks.base_network import BaseNetwork
6
+ from models.networks.generator import *
7
+ from models.networks.encoder import *
8
+ import util.util as util
9
+
10
+
11
+ def find_network_using_name(target_network_name, filename):
12
+ target_class_name = target_network_name + filename
13
+ module_name = "models.networks." + filename
14
+ network = util.find_class_in_module(target_class_name, module_name)
15
+
16
+ assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network
17
+
18
+ return network
19
+
20
+
21
+ def modify_commandline_options(parser, is_train):
22
+ opt, _ = parser.parse_known_args()
23
+
24
+ netG_cls = find_network_using_name(opt.netG, "generator")
25
+ parser = netG_cls.modify_commandline_options(parser, is_train)
26
+ if is_train:
27
+ netD_cls = find_network_using_name(opt.netD, "discriminator")
28
+ parser = netD_cls.modify_commandline_options(parser, is_train)
29
+ netE_cls = find_network_using_name("conv", "encoder")
30
+ parser = netE_cls.modify_commandline_options(parser, is_train)
31
+
32
+ return parser
33
+
34
+
35
+ def create_network(cls, opt):
36
+ net = cls(opt)
37
+ net.print_network()
38
+ if len(opt.gpu_ids) > 0:
39
+ assert torch.cuda.is_available()
40
+ net.cuda()
41
+ net.init_weights(opt.init_type, opt.init_variance)
42
+ return net
43
+
44
+
45
+ def define_G(opt):
46
+ netG_cls = find_network_using_name(opt.netG, "generator")
47
+ return create_network(netG_cls, opt)
48
+
49
+
50
+ def define_D(opt):
51
+ netD_cls = find_network_using_name(opt.netD, "discriminator")
52
+ return create_network(netD_cls, opt)
53
+
54
+
55
+ def define_E(opt):
56
+ # there exists only one encoder type
57
+ netE_cls = find_network_using_name("conv", "encoder")
58
+ return create_network(netE_cls, opt)
Face_Enhancement/models/networks/architecture.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision
8
+ import torch.nn.utils.spectral_norm as spectral_norm
9
+ from models.networks.normalization import SPADE
10
+
11
+
12
+ # ResNet block that uses SPADE.
13
+ # It differs from the ResNet block of pix2pixHD in that
14
+ # it takes in the segmentation map as input, learns the skip connection if necessary,
15
+ # and applies normalization first and then convolution.
16
+ # This architecture seemed like a standard architecture for unconditional or
17
+ # class-conditional GAN architecture using residual block.
18
+ # The code was inspired from https://github.com/LMescheder/GAN_stability.
19
+ class SPADEResnetBlock(nn.Module):
20
+ def __init__(self, fin, fout, opt):
21
+ super().__init__()
22
+ # Attributes
23
+ self.learned_shortcut = fin != fout
24
+ fmiddle = min(fin, fout)
25
+
26
+ self.opt = opt
27
+ # create conv layers
28
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
29
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
30
+ if self.learned_shortcut:
31
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
32
+
33
+ # apply spectral norm if specified
34
+ if "spectral" in opt.norm_G:
35
+ self.conv_0 = spectral_norm(self.conv_0)
36
+ self.conv_1 = spectral_norm(self.conv_1)
37
+ if self.learned_shortcut:
38
+ self.conv_s = spectral_norm(self.conv_s)
39
+
40
+ # define normalization layers
41
+ spade_config_str = opt.norm_G.replace("spectral", "")
42
+ self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
43
+ self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
44
+ if self.learned_shortcut:
45
+ self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
46
+
47
+ # note the resnet block with SPADE also takes in |seg|,
48
+ # the semantic segmentation map as input
49
+ def forward(self, x, seg, degraded_image):
50
+ x_s = self.shortcut(x, seg, degraded_image)
51
+
52
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
53
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))
54
+
55
+ out = x_s + dx
56
+
57
+ return out
58
+
59
+ def shortcut(self, x, seg, degraded_image):
60
+ if self.learned_shortcut:
61
+ x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
62
+ else:
63
+ x_s = x
64
+ return x_s
65
+
66
+ def actvn(self, x):
67
+ return F.leaky_relu(x, 2e-1)
68
+
69
+
70
+ # ResNet block used in pix2pixHD
71
+ # We keep the same architecture as pix2pixHD.
72
+ class ResnetBlock(nn.Module):
73
+ def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
74
+ super().__init__()
75
+
76
+ pw = (kernel_size - 1) // 2
77
+ self.conv_block = nn.Sequential(
78
+ nn.ReflectionPad2d(pw),
79
+ norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
80
+ activation,
81
+ nn.ReflectionPad2d(pw),
82
+ norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
83
+ )
84
+
85
+ def forward(self, x):
86
+ y = self.conv_block(x)
87
+ out = x + y
88
+ return out
89
+
90
+
91
+ # VGG architecter, used for the perceptual loss using a pretrained VGG network
92
+ class VGG19(torch.nn.Module):
93
+ def __init__(self, requires_grad=False):
94
+ super().__init__()
95
+ vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
96
+ self.slice1 = torch.nn.Sequential()
97
+ self.slice2 = torch.nn.Sequential()
98
+ self.slice3 = torch.nn.Sequential()
99
+ self.slice4 = torch.nn.Sequential()
100
+ self.slice5 = torch.nn.Sequential()
101
+ for x in range(2):
102
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
103
+ for x in range(2, 7):
104
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
105
+ for x in range(7, 12):
106
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
107
+ for x in range(12, 21):
108
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(21, 30):
110
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
111
+ if not requires_grad:
112
+ for param in self.parameters():
113
+ param.requires_grad = False
114
+
115
+ def forward(self, X):
116
+ h_relu1 = self.slice1(X)
117
+ h_relu2 = self.slice2(h_relu1)
118
+ h_relu3 = self.slice3(h_relu2)
119
+ h_relu4 = self.slice4(h_relu3)
120
+ h_relu5 = self.slice5(h_relu4)
121
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
122
+ return out
123
+
124
+
125
+ class SPADEResnetBlock_non_spade(nn.Module):
126
+ def __init__(self, fin, fout, opt):
127
+ super().__init__()
128
+ # Attributes
129
+ self.learned_shortcut = fin != fout
130
+ fmiddle = min(fin, fout)
131
+
132
+ self.opt = opt
133
+ # create conv layers
134
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
135
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
136
+ if self.learned_shortcut:
137
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
138
+
139
+ # apply spectral norm if specified
140
+ if "spectral" in opt.norm_G:
141
+ self.conv_0 = spectral_norm(self.conv_0)
142
+ self.conv_1 = spectral_norm(self.conv_1)
143
+ if self.learned_shortcut:
144
+ self.conv_s = spectral_norm(self.conv_s)
145
+
146
+ # define normalization layers
147
+ spade_config_str = opt.norm_G.replace("spectral", "")
148
+ self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
149
+ self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
150
+ if self.learned_shortcut:
151
+ self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
152
+
153
+ # note the resnet block with SPADE also takes in |seg|,
154
+ # the semantic segmentation map as input
155
+ def forward(self, x, seg, degraded_image):
156
+ x_s = self.shortcut(x, seg, degraded_image)
157
+
158
+ dx = self.conv_0(self.actvn(x))
159
+ dx = self.conv_1(self.actvn(dx))
160
+
161
+ out = x_s + dx
162
+
163
+ return out
164
+
165
+ def shortcut(self, x, seg, degraded_image):
166
+ if self.learned_shortcut:
167
+ x_s = self.conv_s(x)
168
+ else:
169
+ x_s = x
170
+ return x_s
171
+
172
+ def actvn(self, x):
173
+ return F.leaky_relu(x, 2e-1)
Face_Enhancement/models/networks/base_network.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import init
6
+
7
+
8
+ class BaseNetwork(nn.Module):
9
+ def __init__(self):
10
+ super(BaseNetwork, self).__init__()
11
+
12
+ @staticmethod
13
+ def modify_commandline_options(parser, is_train):
14
+ return parser
15
+
16
+ def print_network(self):
17
+ if isinstance(self, list):
18
+ self = self[0]
19
+ num_params = 0
20
+ for param in self.parameters():
21
+ num_params += param.numel()
22
+ print(
23
+ "Network [%s] was created. Total number of parameters: %.1f million. "
24
+ "To see the architecture, do print(network)." % (type(self).__name__, num_params / 1000000)
25
+ )
26
+
27
+ def init_weights(self, init_type="normal", gain=0.02):
28
+ def init_func(m):
29
+ classname = m.__class__.__name__
30
+ if classname.find("BatchNorm2d") != -1:
31
+ if hasattr(m, "weight") and m.weight is not None:
32
+ init.normal_(m.weight.data, 1.0, gain)
33
+ if hasattr(m, "bias") and m.bias is not None:
34
+ init.constant_(m.bias.data, 0.0)
35
+ elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
36
+ if init_type == "normal":
37
+ init.normal_(m.weight.data, 0.0, gain)
38
+ elif init_type == "xavier":
39
+ init.xavier_normal_(m.weight.data, gain=gain)
40
+ elif init_type == "xavier_uniform":
41
+ init.xavier_uniform_(m.weight.data, gain=1.0)
42
+ elif init_type == "kaiming":
43
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
44
+ elif init_type == "orthogonal":
45
+ init.orthogonal_(m.weight.data, gain=gain)
46
+ elif init_type == "none": # uses pytorch's default init method
47
+ m.reset_parameters()
48
+ else:
49
+ raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
50
+ if hasattr(m, "bias") and m.bias is not None:
51
+ init.constant_(m.bias.data, 0.0)
52
+
53
+ self.apply(init_func)
54
+
55
+ # propagate to children
56
+ for m in self.children():
57
+ if hasattr(m, "init_weights"):
58
+ m.init_weights(init_type, gain)
Face_Enhancement/models/networks/encoder.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from models.networks.base_network import BaseNetwork
8
+ from models.networks.normalization import get_nonspade_norm_layer
9
+
10
+
11
+ class ConvEncoder(BaseNetwork):
12
+ """ Same architecture as the image discriminator """
13
+
14
+ def __init__(self, opt):
15
+ super().__init__()
16
+
17
+ kw = 3
18
+ pw = int(np.ceil((kw - 1.0) / 2))
19
+ ndf = opt.ngf
20
+ norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
21
+ self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
22
+ self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
23
+ self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
24
+ self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
25
+ self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
26
+ if opt.crop_size >= 256:
27
+ self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
28
+
29
+ self.so = s0 = 4
30
+ self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
31
+ self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)
32
+
33
+ self.actvn = nn.LeakyReLU(0.2, False)
34
+ self.opt = opt
35
+
36
+ def forward(self, x):
37
+ if x.size(2) != 256 or x.size(3) != 256:
38
+ x = F.interpolate(x, size=(256, 256), mode="bilinear")
39
+
40
+ x = self.layer1(x)
41
+ x = self.layer2(self.actvn(x))
42
+ x = self.layer3(self.actvn(x))
43
+ x = self.layer4(self.actvn(x))
44
+ x = self.layer5(self.actvn(x))
45
+ if self.opt.crop_size >= 256:
46
+ x = self.layer6(self.actvn(x))
47
+ x = self.actvn(x)
48
+
49
+ x = x.view(x.size(0), -1)
50
+ mu = self.fc_mu(x)
51
+ logvar = self.fc_var(x)
52
+
53
+ return mu, logvar
Face_Enhancement/models/networks/generator.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from models.networks.base_network import BaseNetwork
8
+ from models.networks.normalization import get_nonspade_norm_layer
9
+ from models.networks.architecture import ResnetBlock as ResnetBlock
10
+ from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
11
+ from models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade
12
+
13
+
14
+ class SPADEGenerator(BaseNetwork):
15
+ @staticmethod
16
+ def modify_commandline_options(parser, is_train):
17
+ parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
18
+ parser.add_argument(
19
+ "--num_upsampling_layers",
20
+ choices=("normal", "more", "most"),
21
+ default="normal",
22
+ help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
23
+ )
24
+
25
+ return parser
26
+
27
+ def __init__(self, opt):
28
+ super().__init__()
29
+ self.opt = opt
30
+ nf = opt.ngf
31
+
32
+ self.sw, self.sh = self.compute_latent_vector_size(opt)
33
+
34
+ print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))
35
+
36
+ if opt.use_vae:
37
+ # In case of VAE, we will sample from random z vector
38
+ self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
39
+ else:
40
+ # Otherwise, we make the network deterministic by starting with
41
+ # downsampled segmentation map instead of random z
42
+ if self.opt.no_parsing_map:
43
+ self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
44
+ else:
45
+ self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
46
+
47
+ if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
48
+ self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
49
+ else:
50
+ self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
51
+
52
+ if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
53
+ self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
54
+ self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
55
+
56
+ else:
57
+ self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
58
+ self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
59
+
60
+ if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
61
+ self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
62
+ else:
63
+ self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)
64
+
65
+ if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
66
+ self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
67
+ else:
68
+ self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)
69
+
70
+ if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
71
+ self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
72
+ else:
73
+ self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)
74
+
75
+ if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
76
+ self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
77
+ else:
78
+ self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)
79
+
80
+ final_nc = nf
81
+
82
+ if opt.num_upsampling_layers == "most":
83
+ self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
84
+ final_nc = nf // 2
85
+
86
+ self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
87
+
88
+ self.up = nn.Upsample(scale_factor=2)
89
+
90
+ def compute_latent_vector_size(self, opt):
91
+ if opt.num_upsampling_layers == "normal":
92
+ num_up_layers = 5
93
+ elif opt.num_upsampling_layers == "more":
94
+ num_up_layers = 6
95
+ elif opt.num_upsampling_layers == "most":
96
+ num_up_layers = 7
97
+ else:
98
+ raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)
99
+
100
+ sw = opt.load_size // (2 ** num_up_layers)
101
+ sh = round(sw / opt.aspect_ratio)
102
+
103
+ return sw, sh
104
+
105
+ def forward(self, input, degraded_image, z=None):
106
+ seg = input
107
+
108
+ if self.opt.use_vae:
109
+ # we sample z from unit normal and reshape the tensor
110
+ if z is None:
111
+ z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())
112
+ x = self.fc(z)
113
+ x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
114
+ else:
115
+ # we downsample segmap and run convolution
116
+ if self.opt.no_parsing_map:
117
+ x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
118
+ else:
119
+ x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
120
+ x = self.fc(x)
121
+
122
+ x = self.head_0(x, seg, degraded_image)
123
+
124
+ x = self.up(x)
125
+ x = self.G_middle_0(x, seg, degraded_image)
126
+
127
+ if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
128
+ x = self.up(x)
129
+
130
+ x = self.G_middle_1(x, seg, degraded_image)
131
+
132
+ x = self.up(x)
133
+ x = self.up_0(x, seg, degraded_image)
134
+ x = self.up(x)
135
+ x = self.up_1(x, seg, degraded_image)
136
+ x = self.up(x)
137
+ x = self.up_2(x, seg, degraded_image)
138
+ x = self.up(x)
139
+ x = self.up_3(x, seg, degraded_image)
140
+
141
+ if self.opt.num_upsampling_layers == "most":
142
+ x = self.up(x)
143
+ x = self.up_4(x, seg, degraded_image)
144
+
145
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
146
+ x = F.tanh(x)
147
+
148
+ return x
149
+
150
+
151
+ class Pix2PixHDGenerator(BaseNetwork):
152
+ @staticmethod
153
+ def modify_commandline_options(parser, is_train):
154
+ parser.add_argument(
155
+ "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
156
+ )
157
+ parser.add_argument(
158
+ "--resnet_n_blocks",
159
+ type=int,
160
+ default=9,
161
+ help="number of residual blocks in the global generator network",
162
+ )
163
+ parser.add_argument(
164
+ "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
165
+ )
166
+ parser.add_argument(
167
+ "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
168
+ )
169
+ # parser.set_defaults(norm_G='instance')
170
+ return parser
171
+
172
+ def __init__(self, opt):
173
+ super().__init__()
174
+ input_nc = 3
175
+
176
+ # print("xxxxx")
177
+ # print(opt.norm_G)
178
+ norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
179
+ activation = nn.ReLU(False)
180
+
181
+ model = []
182
+
183
+ # initial conv
184
+ model += [
185
+ nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
186
+ norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
187
+ activation,
188
+ ]
189
+
190
+ # downsample
191
+ mult = 1
192
+ for i in range(opt.resnet_n_downsample):
193
+ model += [
194
+ norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
195
+ activation,
196
+ ]
197
+ mult *= 2
198
+
199
+ # resnet blocks
200
+ for i in range(opt.resnet_n_blocks):
201
+ model += [
202
+ ResnetBlock(
203
+ opt.ngf * mult,
204
+ norm_layer=norm_layer,
205
+ activation=activation,
206
+ kernel_size=opt.resnet_kernel_size,
207
+ )
208
+ ]
209
+
210
+ # upsample
211
+ for i in range(opt.resnet_n_downsample):
212
+ nc_in = int(opt.ngf * mult)
213
+ nc_out = int((opt.ngf * mult) / 2)
214
+ model += [
215
+ norm_layer(
216
+ nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
217
+ ),
218
+ activation,
219
+ ]
220
+ mult = mult // 2
221
+
222
+ # final output conv
223
+ model += [
224
+ nn.ReflectionPad2d(3),
225
+ nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
226
+ nn.Tanh(),
227
+ ]
228
+
229
+ self.model = nn.Sequential(*model)
230
+
231
+ def forward(self, input, degraded_image, z=None):
232
+ return self.model(degraded_image)
233
+
Face_Enhancement/models/networks/normalization.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import re
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
9
+ import torch.nn.utils.spectral_norm as spectral_norm
10
+
11
+
12
+ def get_nonspade_norm_layer(opt, norm_type="instance"):
13
+ # helper function to get # output channels of the previous layer
14
+ def get_out_channel(layer):
15
+ if hasattr(layer, "out_channels"):
16
+ return getattr(layer, "out_channels")
17
+ return layer.weight.size(0)
18
+
19
+ # this function will be returned
20
+ def add_norm_layer(layer):
21
+ nonlocal norm_type
22
+ if norm_type.startswith("spectral"):
23
+ layer = spectral_norm(layer)
24
+ subnorm_type = norm_type[len("spectral") :]
25
+
26
+ if subnorm_type == "none" or len(subnorm_type) == 0:
27
+ return layer
28
+
29
+ # remove bias in the previous layer, which is meaningless
30
+ # since it has no effect after normalization
31
+ if getattr(layer, "bias", None) is not None:
32
+ delattr(layer, "bias")
33
+ layer.register_parameter("bias", None)
34
+
35
+ if subnorm_type == "batch":
36
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
37
+ elif subnorm_type == "sync_batch":
38
+ norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
39
+ elif subnorm_type == "instance":
40
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
41
+ else:
42
+ raise ValueError("normalization layer %s is not recognized" % subnorm_type)
43
+
44
+ return nn.Sequential(layer, norm_layer)
45
+
46
+ return add_norm_layer
47
+
48
+
49
+ class SPADE(nn.Module):
50
+ def __init__(self, config_text, norm_nc, label_nc, opt):
51
+ super().__init__()
52
+
53
+ assert config_text.startswith("spade")
54
+ parsed = re.search("spade(\D+)(\d)x\d", config_text)
55
+ param_free_norm_type = str(parsed.group(1))
56
+ ks = int(parsed.group(2))
57
+ self.opt = opt
58
+ if param_free_norm_type == "instance":
59
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
60
+ elif param_free_norm_type == "syncbatch":
61
+ self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
62
+ elif param_free_norm_type == "batch":
63
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
64
+ else:
65
+ raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)
66
+
67
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
68
+ nhidden = 128
69
+
70
+ pw = ks // 2
71
+
72
+ if self.opt.no_parsing_map:
73
+ self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
74
+ else:
75
+ self.mlp_shared = nn.Sequential(
76
+ nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
77
+ )
78
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
79
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
80
+
81
+ def forward(self, x, segmap, degraded_image):
82
+
83
+ # Part 1. generate parameter-free normalized activations
84
+ normalized = self.param_free_norm(x)
85
+
86
+ # Part 2. produce scaling and bias conditioned on semantic map
87
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
88
+ degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")
89
+
90
+ if self.opt.no_parsing_map:
91
+ actv = self.mlp_shared(degraded_face)
92
+ else:
93
+ actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
94
+ gamma = self.mlp_gamma(actv)
95
+ beta = self.mlp_beta(actv)
96
+
97
+ # apply scale and bias
98
+ out = normalized * (1 + gamma) + beta
99
+
100
+ return out
Face_Enhancement/models/networks/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import set_sbn_eps_mode
12
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
13
+ from .batchnorm import patch_sync_batchnorm, convert_model
14
+ from .replicate import DataParallelWithCallback, patch_replication_callback
Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+ import contextlib
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from torch.nn.modules.batchnorm import _BatchNorm
18
+
19
+ try:
20
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
21
+ except ImportError:
22
+ ReduceAddCoalesced = Broadcast = None
23
+
24
+ try:
25
+ from jactorch.parallel.comm import SyncMaster
26
+ from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
27
+ except ImportError:
28
+ from .comm import SyncMaster
29
+ from .replicate import DataParallelWithCallback
30
+
31
+ __all__ = [
32
+ 'set_sbn_eps_mode',
33
+ 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
34
+ 'patch_sync_batchnorm', 'convert_model'
35
+ ]
36
+
37
+
38
+ SBN_EPS_MODE = 'clamp'
39
+
40
+
41
+ def set_sbn_eps_mode(mode):
42
+ global SBN_EPS_MODE
43
+ assert mode in ('clamp', 'plus')
44
+ SBN_EPS_MODE = mode
45
+
46
+
47
+ def _sum_ft(tensor):
48
+ """sum over the first and last dimention"""
49
+ return tensor.sum(dim=0).sum(dim=-1)
50
+
51
+
52
+ def _unsqueeze_ft(tensor):
53
+ """add new dimensions at the front and the tail"""
54
+ return tensor.unsqueeze(0).unsqueeze(-1)
55
+
56
+
57
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
58
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
59
+
60
+
61
+ class _SynchronizedBatchNorm(_BatchNorm):
62
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
63
+ assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
64
+
65
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
66
+ track_running_stats=track_running_stats)
67
+
68
+ if not self.track_running_stats:
69
+ import warnings
70
+ warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
71
+
72
+ self._sync_master = SyncMaster(self._data_parallel_master)
73
+
74
+ self._is_parallel = False
75
+ self._parallel_id = None
76
+ self._slave_pipe = None
77
+
78
+ def forward(self, input):
79
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
80
+ if not (self._is_parallel and self.training):
81
+ return F.batch_norm(
82
+ input, self.running_mean, self.running_var, self.weight, self.bias,
83
+ self.training, self.momentum, self.eps)
84
+
85
+ # Resize the input to (B, C, -1).
86
+ input_shape = input.size()
87
+ assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
88
+ input = input.view(input.size(0), self.num_features, -1)
89
+
90
+ # Compute the sum and square-sum.
91
+ sum_size = input.size(0) * input.size(2)
92
+ input_sum = _sum_ft(input)
93
+ input_ssum = _sum_ft(input ** 2)
94
+
95
+ # Reduce-and-broadcast the statistics.
96
+ if self._parallel_id == 0:
97
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
98
+ else:
99
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
100
+
101
+ # Compute the output.
102
+ if self.affine:
103
+ # MJY:: Fuse the multiplication for speed.
104
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
105
+ else:
106
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
107
+
108
+ # Reshape it.
109
+ return output.view(input_shape)
110
+
111
+ def __data_parallel_replicate__(self, ctx, copy_id):
112
+ self._is_parallel = True
113
+ self._parallel_id = copy_id
114
+
115
+ # parallel_id == 0 means master device.
116
+ if self._parallel_id == 0:
117
+ ctx.sync_master = self._sync_master
118
+ else:
119
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
120
+
121
+ def _data_parallel_master(self, intermediates):
122
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
123
+
124
+ # Always using same "device order" makes the ReduceAdd operation faster.
125
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
126
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
127
+
128
+ to_reduce = [i[1][:2] for i in intermediates]
129
+ to_reduce = [j for i in to_reduce for j in i] # flatten
130
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
131
+
132
+ sum_size = sum([i[1].sum_size for i in intermediates])
133
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
134
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
135
+
136
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
137
+
138
+ outputs = []
139
+ for i, rec in enumerate(intermediates):
140
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
141
+
142
+ return outputs
143
+
144
+ def _compute_mean_std(self, sum_, ssum, size):
145
+ """Compute the mean and standard-deviation with sum and square-sum. This method
146
+ also maintains the moving average on the master device."""
147
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
148
+ mean = sum_ / size
149
+ sumvar = ssum - sum_ * mean
150
+ unbias_var = sumvar / (size - 1)
151
+ bias_var = sumvar / size
152
+
153
+ if hasattr(torch, 'no_grad'):
154
+ with torch.no_grad():
155
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
156
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
157
+ else:
158
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
159
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
160
+
161
+ if SBN_EPS_MODE == 'clamp':
162
+ return mean, bias_var.clamp(self.eps) ** -0.5
163
+ elif SBN_EPS_MODE == 'plus':
164
+ return mean, (bias_var + self.eps) ** -0.5
165
+ else:
166
+ raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
167
+
168
+
169
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
170
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
171
+ mini-batch.
172
+
173
+ .. math::
174
+
175
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
176
+
177
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
178
+ standard-deviation are reduced across all devices during training.
179
+
180
+ For example, when one uses `nn.DataParallel` to wrap the network during
181
+ training, PyTorch's implementation normalize the tensor on each device using
182
+ the statistics only on that device, which accelerated the computation and
183
+ is also easy to implement, but the statistics might be inaccurate.
184
+ Instead, in this synchronized version, the statistics will be computed
185
+ over all training samples distributed on multiple devices.
186
+
187
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
188
+ as the built-in PyTorch implementation.
189
+
190
+ The mean and standard-deviation are calculated per-dimension over
191
+ the mini-batches and gamma and beta are learnable parameter vectors
192
+ of size C (where C is the input size).
193
+
194
+ During training, this layer keeps a running estimate of its computed mean
195
+ and variance. The running sum is kept with a default momentum of 0.1.
196
+
197
+ During evaluation, this running mean/variance is used for normalization.
198
+
199
+ Because the BatchNorm is done over the `C` dimension, computing statistics
200
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
201
+
202
+ Args:
203
+ num_features: num_features from an expected input of size
204
+ `batch_size x num_features [x width]`
205
+ eps: a value added to the denominator for numerical stability.
206
+ Default: 1e-5
207
+ momentum: the value used for the running_mean and running_var
208
+ computation. Default: 0.1
209
+ affine: a boolean value that when set to ``True``, gives the layer learnable
210
+ affine parameters. Default: ``True``
211
+
212
+ Shape::
213
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
214
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
215
+
216
+ Examples:
217
+ >>> # With Learnable Parameters
218
+ >>> m = SynchronizedBatchNorm1d(100)
219
+ >>> # Without Learnable Parameters
220
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
221
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
222
+ >>> output = m(input)
223
+ """
224
+
225
+ def _check_input_dim(self, input):
226
+ if input.dim() != 2 and input.dim() != 3:
227
+ raise ValueError('expected 2D or 3D input (got {}D input)'
228
+ .format(input.dim()))
229
+
230
+
231
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
232
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
233
+ of 3d inputs
234
+
235
+ .. math::
236
+
237
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
238
+
239
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
240
+ standard-deviation are reduced across all devices during training.
241
+
242
+ For example, when one uses `nn.DataParallel` to wrap the network during
243
+ training, PyTorch's implementation normalize the tensor on each device using
244
+ the statistics only on that device, which accelerated the computation and
245
+ is also easy to implement, but the statistics might be inaccurate.
246
+ Instead, in this synchronized version, the statistics will be computed
247
+ over all training samples distributed on multiple devices.
248
+
249
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
250
+ as the built-in PyTorch implementation.
251
+
252
+ The mean and standard-deviation are calculated per-dimension over
253
+ the mini-batches and gamma and beta are learnable parameter vectors
254
+ of size C (where C is the input size).
255
+
256
+ During training, this layer keeps a running estimate of its computed mean
257
+ and variance. The running sum is kept with a default momentum of 0.1.
258
+
259
+ During evaluation, this running mean/variance is used for normalization.
260
+
261
+ Because the BatchNorm is done over the `C` dimension, computing statistics
262
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
263
+
264
+ Args:
265
+ num_features: num_features from an expected input of
266
+ size batch_size x num_features x height x width
267
+ eps: a value added to the denominator for numerical stability.
268
+ Default: 1e-5
269
+ momentum: the value used for the running_mean and running_var
270
+ computation. Default: 0.1
271
+ affine: a boolean value that when set to ``True``, gives the layer learnable
272
+ affine parameters. Default: ``True``
273
+
274
+ Shape::
275
+ - Input: :math:`(N, C, H, W)`
276
+ - Output: :math:`(N, C, H, W)` (same shape as input)
277
+
278
+ Examples:
279
+ >>> # With Learnable Parameters
280
+ >>> m = SynchronizedBatchNorm2d(100)
281
+ >>> # Without Learnable Parameters
282
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
283
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
284
+ >>> output = m(input)
285
+ """
286
+
287
+ def _check_input_dim(self, input):
288
+ if input.dim() != 4:
289
+ raise ValueError('expected 4D input (got {}D input)'
290
+ .format(input.dim()))
291
+
292
+
293
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
294
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
295
+ of 4d inputs
296
+
297
+ .. math::
298
+
299
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
300
+
301
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
302
+ standard-deviation are reduced across all devices during training.
303
+
304
+ For example, when one uses `nn.DataParallel` to wrap the network during
305
+ training, PyTorch's implementation normalize the tensor on each device using
306
+ the statistics only on that device, which accelerated the computation and
307
+ is also easy to implement, but the statistics might be inaccurate.
308
+ Instead, in this synchronized version, the statistics will be computed
309
+ over all training samples distributed on multiple devices.
310
+
311
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
312
+ as the built-in PyTorch implementation.
313
+
314
+ The mean and standard-deviation are calculated per-dimension over
315
+ the mini-batches and gamma and beta are learnable parameter vectors
316
+ of size C (where C is the input size).
317
+
318
+ During training, this layer keeps a running estimate of its computed mean
319
+ and variance. The running sum is kept with a default momentum of 0.1.
320
+
321
+ During evaluation, this running mean/variance is used for normalization.
322
+
323
+ Because the BatchNorm is done over the `C` dimension, computing statistics
324
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
325
+ or Spatio-temporal BatchNorm
326
+
327
+ Args:
328
+ num_features: num_features from an expected input of
329
+ size batch_size x num_features x depth x height x width
330
+ eps: a value added to the denominator for numerical stability.
331
+ Default: 1e-5
332
+ momentum: the value used for the running_mean and running_var
333
+ computation. Default: 0.1
334
+ affine: a boolean value that when set to ``True``, gives the layer learnable
335
+ affine parameters. Default: ``True``
336
+
337
+ Shape::
338
+ - Input: :math:`(N, C, D, H, W)`
339
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
340
+
341
+ Examples:
342
+ >>> # With Learnable Parameters
343
+ >>> m = SynchronizedBatchNorm3d(100)
344
+ >>> # Without Learnable Parameters
345
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
346
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
347
+ >>> output = m(input)
348
+ """
349
+
350
+ def _check_input_dim(self, input):
351
+ if input.dim() != 5:
352
+ raise ValueError('expected 5D input (got {}D input)'
353
+ .format(input.dim()))
354
+
355
+
356
+ @contextlib.contextmanager
357
+ def patch_sync_batchnorm():
358
+ import torch.nn as nn
359
+
360
+ backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
361
+
362
+ nn.BatchNorm1d = SynchronizedBatchNorm1d
363
+ nn.BatchNorm2d = SynchronizedBatchNorm2d
364
+ nn.BatchNorm3d = SynchronizedBatchNorm3d
365
+
366
+ yield
367
+
368
+ nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
369
+
370
+
371
+ def convert_model(module):
372
+ """Traverse the input module and its child recursively
373
+ and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
374
+ to SynchronizedBatchNorm*N*d
375
+
376
+ Args:
377
+ module: the input module needs to be convert to SyncBN model
378
+
379
+ Examples:
380
+ >>> import torch.nn as nn
381
+ >>> import torchvision
382
+ >>> # m is a standard pytorch model
383
+ >>> m = torchvision.models.resnet18(True)
384
+ >>> m = nn.DataParallel(m)
385
+ >>> # after convert, m is using SyncBN
386
+ >>> m = convert_model(m)
387
+ """
388
+ if isinstance(module, torch.nn.DataParallel):
389
+ mod = module.module
390
+ mod = convert_model(mod)
391
+ mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
392
+ return mod
393
+
394
+ mod = module
395
+ for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
396
+ torch.nn.modules.batchnorm.BatchNorm2d,
397
+ torch.nn.modules.batchnorm.BatchNorm3d],
398
+ [SynchronizedBatchNorm1d,
399
+ SynchronizedBatchNorm2d,
400
+ SynchronizedBatchNorm3d]):
401
+ if isinstance(module, pth_module):
402
+ mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
403
+ mod.running_mean = module.running_mean
404
+ mod.running_var = module.running_var
405
+ if module.affine:
406
+ mod.weight.data = module.weight.data.clone().detach()
407
+ mod.bias.data = module.bias.data.clone().detach()
408
+
409
+ for name, child in module.named_children():
410
+ mod.add_module(name, convert_model(child))
411
+
412
+ return mod
Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : batchnorm_reimpl.py
4
+ # Author : acgtyrant
5
+ # Date : 11/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.init as init
14
+
15
+ __all__ = ['BatchNorm2dReimpl']
16
+
17
+
18
+ class BatchNorm2dReimpl(nn.Module):
19
+ """
20
+ A re-implementation of batch normalization, used for testing the numerical
21
+ stability.
22
+
23
+ Author: acgtyrant
24
+ See also:
25
+ https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26
+ """
27
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
28
+ super().__init__()
29
+
30
+ self.num_features = num_features
31
+ self.eps = eps
32
+ self.momentum = momentum
33
+ self.weight = nn.Parameter(torch.empty(num_features))
34
+ self.bias = nn.Parameter(torch.empty(num_features))
35
+ self.register_buffer('running_mean', torch.zeros(num_features))
36
+ self.register_buffer('running_var', torch.ones(num_features))
37
+ self.reset_parameters()
38
+
39
+ def reset_running_stats(self):
40
+ self.running_mean.zero_()
41
+ self.running_var.fill_(1)
42
+
43
+ def reset_parameters(self):
44
+ self.reset_running_stats()
45
+ init.uniform_(self.weight)
46
+ init.zeros_(self.bias)
47
+
48
+ def forward(self, input_):
49
+ batchsize, channels, height, width = input_.size()
50
+ numel = batchsize * height * width
51
+ input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52
+ sum_ = input_.sum(1)
53
+ sum_of_square = input_.pow(2).sum(1)
54
+ mean = sum_ / numel
55
+ sumvar = sum_of_square - sum_ * mean
56
+
57
+ self.running_mean = (
58
+ (1 - self.momentum) * self.running_mean
59
+ + self.momentum * mean.detach()
60
+ )
61
+ unbias_var = sumvar / (numel - 1)
62
+ self.running_var = (
63
+ (1 - self.momentum) * self.running_var
64
+ + self.momentum * unbias_var.detach()
65
+ )
66
+
67
+ bias_var = sumvar / numel
68
+ inv_std = 1 / (bias_var + self.eps).pow(0.5)
69
+ output = (
70
+ (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71
+ self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72
+
73
+ return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74
+
Face_Enhancement/models/networks/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result has\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
60
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine to message passed
64
+ back to each slave devices.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register an slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages were first collected from each devices (including the master device), and then
106
+ an callback will be invoked to compute the message to be sent back to each devices
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master want to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
Face_Enhancement/models/networks/sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
Face_Enhancement/models/networks/sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+ import torch
13
+
14
+
15
+ class TorchTestCase(unittest.TestCase):
16
+ def assertTensorClose(self, x, y):
17
+ adiff = float((x - y).abs().max())
18
+ if (y == 0).all():
19
+ rdiff = 'NaN'
20
+ else:
21
+ rdiff = float((adiff / y).abs().max())
22
+
23
+ message = (
24
+ 'Tensor close check failed\n'
25
+ 'adiff={}\n'
26
+ 'rdiff={}\n'
27
+ ).format(adiff, rdiff)
28
+ self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
29
+
Face_Enhancement/models/pix2pix_model.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch
5
+ import models.networks as networks
6
+ import util.util as util
7
+
8
+
9
+ class Pix2PixModel(torch.nn.Module):
10
+ @staticmethod
11
+ def modify_commandline_options(parser, is_train):
12
+ networks.modify_commandline_options(parser, is_train)
13
+ return parser
14
+
15
+ def __init__(self, opt):
16
+ super().__init__()
17
+ self.opt = opt
18
+ self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
19
+ self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor
20
+
21
+ self.netG, self.netD, self.netE = self.initialize_networks(opt)
22
+
23
+ # set loss functions
24
+ if opt.isTrain:
25
+ self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
26
+ self.criterionFeat = torch.nn.L1Loss()
27
+ if not opt.no_vgg_loss:
28
+ self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
29
+ if opt.use_vae:
30
+ self.KLDLoss = networks.KLDLoss()
31
+
32
+ # Entry point for all calls involving forward pass
33
+ # of deep networks. We used this approach since DataParallel module
34
+ # can't parallelize custom functions, we branch to different
35
+ # routines based on |mode|.
36
+ def forward(self, data, mode):
37
+ input_semantics, real_image, degraded_image = self.preprocess_input(data)
38
+
39
+ if mode == "generator":
40
+ g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
41
+ return g_loss, generated
42
+ elif mode == "discriminator":
43
+ d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
44
+ return d_loss
45
+ elif mode == "encode_only":
46
+ z, mu, logvar = self.encode_z(real_image)
47
+ return mu, logvar
48
+ elif mode == "inference":
49
+ with torch.no_grad():
50
+ fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
51
+ return fake_image
52
+ else:
53
+ raise ValueError("|mode| is invalid")
54
+
55
+ def create_optimizers(self, opt):
56
+ G_params = list(self.netG.parameters())
57
+ if opt.use_vae:
58
+ G_params += list(self.netE.parameters())
59
+ if opt.isTrain:
60
+ D_params = list(self.netD.parameters())
61
+
62
+ beta1, beta2 = opt.beta1, opt.beta2
63
+ if opt.no_TTUR:
64
+ G_lr, D_lr = opt.lr, opt.lr
65
+ else:
66
+ G_lr, D_lr = opt.lr / 2, opt.lr * 2
67
+
68
+ optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
69
+ optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
70
+
71
+ return optimizer_G, optimizer_D
72
+
73
+ def save(self, epoch):
74
+ util.save_network(self.netG, "G", epoch, self.opt)
75
+ util.save_network(self.netD, "D", epoch, self.opt)
76
+ if self.opt.use_vae:
77
+ util.save_network(self.netE, "E", epoch, self.opt)
78
+
79
+ ############################################################################
80
+ # Private helper methods
81
+ ############################################################################
82
+
83
+ def initialize_networks(self, opt):
84
+ netG = networks.define_G(opt)
85
+ netD = networks.define_D(opt) if opt.isTrain else None
86
+ netE = networks.define_E(opt) if opt.use_vae else None
87
+
88
+ if not opt.isTrain or opt.continue_train:
89
+ netG = util.load_network(netG, "G", opt.which_epoch, opt)
90
+ if opt.isTrain:
91
+ netD = util.load_network(netD, "D", opt.which_epoch, opt)
92
+ if opt.use_vae:
93
+ netE = util.load_network(netE, "E", opt.which_epoch, opt)
94
+
95
+ return netG, netD, netE
96
+
97
+ # preprocess the input, such as moving the tensors to GPUs and
98
+ # transforming the label map to one-hot encoding
99
+ # |data|: dictionary of the input data
100
+
101
+ def preprocess_input(self, data):
102
+ # move to GPU and change data types
103
+ # data['label'] = data['label'].long()
104
+
105
+ if not self.opt.isTrain:
106
+ if self.use_gpu():
107
+ data["label"] = data["label"].cuda()
108
+ data["image"] = data["image"].cuda()
109
+ return data["label"], data["image"], data["image"]
110
+
111
+ ## While testing, the input image is the degraded face
112
+ if self.use_gpu():
113
+ data["label"] = data["label"].cuda()
114
+ data["degraded_image"] = data["degraded_image"].cuda()
115
+ data["image"] = data["image"].cuda()
116
+
117
+ # # create one-hot label map
118
+ # label_map = data['label']
119
+ # bs, _, h, w = label_map.size()
120
+ # nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
121
+ # else self.opt.label_nc
122
+ # input_label = self.FloatTensor(bs, nc, h, w).zero_()
123
+ # input_semantics = input_label.scatter_(1, label_map, 1.0)
124
+
125
+ return data["label"], data["image"], data["degraded_image"]
126
+
127
+ def compute_generator_loss(self, input_semantics, degraded_image, real_image):
128
+ G_losses = {}
129
+
130
+ fake_image, KLD_loss = self.generate_fake(
131
+ input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
132
+ )
133
+
134
+ if self.opt.use_vae:
135
+ G_losses["KLD"] = KLD_loss
136
+
137
+ pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
138
+
139
+ G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)
140
+
141
+ if not self.opt.no_ganFeat_loss:
142
+ num_D = len(pred_fake)
143
+ GAN_Feat_loss = self.FloatTensor(1).fill_(0)
144
+ for i in range(num_D): # for each discriminator
145
+ # last output is the final prediction, so we exclude it
146
+ num_intermediate_outputs = len(pred_fake[i]) - 1
147
+ for j in range(num_intermediate_outputs): # for each layer output
148
+ unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
149
+ GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
150
+ G_losses["GAN_Feat"] = GAN_Feat_loss
151
+
152
+ if not self.opt.no_vgg_loss:
153
+ G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg
154
+
155
+ return G_losses, fake_image
156
+
157
+ def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
158
+ D_losses = {}
159
+ with torch.no_grad():
160
+ fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
161
+ fake_image = fake_image.detach()
162
+ fake_image.requires_grad_()
163
+
164
+ pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
165
+
166
+ D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
167
+ D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)
168
+
169
+ return D_losses
170
+
171
+ def encode_z(self, real_image):
172
+ mu, logvar = self.netE(real_image)
173
+ z = self.reparameterize(mu, logvar)
174
+ return z, mu, logvar
175
+
176
+ def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
177
+ z = None
178
+ KLD_loss = None
179
+ if self.opt.use_vae:
180
+ z, mu, logvar = self.encode_z(real_image)
181
+ if compute_kld_loss:
182
+ KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
183
+
184
+ fake_image = self.netG(input_semantics, degraded_image, z=z)
185
+
186
+ assert (
187
+ not compute_kld_loss
188
+ ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
189
+
190
+ return fake_image, KLD_loss
191
+
192
+ # Given fake and real image, return the prediction of discriminator
193
+ # for each fake and real image.
194
+
195
+ def discriminate(self, input_semantics, fake_image, real_image):
196
+
197
+ if self.opt.no_parsing_map:
198
+ fake_concat = fake_image
199
+ real_concat = real_image
200
+ else:
201
+ fake_concat = torch.cat([input_semantics, fake_image], dim=1)
202
+ real_concat = torch.cat([input_semantics, real_image], dim=1)
203
+
204
+ # In Batch Normalization, the fake and real images are
205
+ # recommended to be in the same batch to avoid disparate
206
+ # statistics in fake and real images.
207
+ # So both fake and real images are fed to D all at once.
208
+ fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
209
+
210
+ discriminator_out = self.netD(fake_and_real)
211
+
212
+ pred_fake, pred_real = self.divide_pred(discriminator_out)
213
+
214
+ return pred_fake, pred_real
215
+
216
+ # Take the prediction of fake and real images from the combined batch
217
+ def divide_pred(self, pred):
218
+ # the prediction contains the intermediate outputs of multiscale GAN,
219
+ # so it's usually a list
220
+ if type(pred) == list:
221
+ fake = []
222
+ real = []
223
+ for p in pred:
224
+ fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
225
+ real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
226
+ else:
227
+ fake = pred[: pred.size(0) // 2]
228
+ real = pred[pred.size(0) // 2 :]
229
+
230
+ return fake, real
231
+
232
+ def get_edges(self, t):
233
+ edge = self.ByteTensor(t.size()).zero_()
234
+ edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
235
+ edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
236
+ edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
237
+ edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
238
+ return edge.float()
239
+
240
+ def reparameterize(self, mu, logvar):
241
+ std = torch.exp(0.5 * logvar)
242
+ eps = torch.randn_like(std)
243
+ return eps.mul(std) + mu
244
+
245
+ def use_gpu(self):
246
+ return len(self.opt.gpu_ids) > 0
Face_Enhancement/options/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
Face_Enhancement/options/base_options.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import sys
5
+ import argparse
6
+ import os
7
+ from util import util
8
+ import torch
9
+ import models
10
+ import data
11
+ import pickle
12
+
13
+
14
+ class BaseOptions:
15
+ def __init__(self):
16
+ self.initialized = False
17
+
18
+ def initialize(self, parser):
19
+ # experiment specifics
20
+ parser.add_argument(
21
+ "--name",
22
+ type=str,
23
+ default="label2coco",
24
+ help="name of the experiment. It decides where to store samples and models",
25
+ )
26
+
27
+ parser.add_argument(
28
+ "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU"
29
+ )
30
+ parser.add_argument(
31
+ "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here"
32
+ )
33
+ parser.add_argument("--model", type=str, default="pix2pix", help="which model to use")
34
+ parser.add_argument(
35
+ "--norm_G",
36
+ type=str,
37
+ default="spectralinstance",
38
+ help="instance normalization or batch normalization",
39
+ )
40
+ parser.add_argument(
41
+ "--norm_D",
42
+ type=str,
43
+ default="spectralinstance",
44
+ help="instance normalization or batch normalization",
45
+ )
46
+ parser.add_argument(
47
+ "--norm_E",
48
+ type=str,
49
+ default="spectralinstance",
50
+ help="instance normalization or batch normalization",
51
+ )
52
+ parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc")
53
+
54
+ # input/output sizes
55
+ parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
56
+ parser.add_argument(
57
+ "--preprocess_mode",
58
+ type=str,
59
+ default="scale_width_and_crop",
60
+ help="scaling and cropping of images at load time.",
61
+ choices=(
62
+ "resize_and_crop",
63
+ "crop",
64
+ "scale_width",
65
+ "scale_width_and_crop",
66
+ "scale_shortside",
67
+ "scale_shortside_and_crop",
68
+ "fixed",
69
+ "none",
70
+ "resize",
71
+ ),
72
+ )
73
+ parser.add_argument(
74
+ "--load_size",
75
+ type=int,
76
+ default=1024,
77
+ help="Scale images to this size. The final image will be cropped to --crop_size.",
78
+ )
79
+ parser.add_argument(
80
+ "--crop_size",
81
+ type=int,
82
+ default=512,
83
+ help="Crop to the width of crop_size (after initially scaling the images to load_size.)",
84
+ )
85
+ parser.add_argument(
86
+ "--aspect_ratio",
87
+ type=float,
88
+ default=1.0,
89
+ help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio",
90
+ )
91
+ parser.add_argument(
92
+ "--label_nc",
93
+ type=int,
94
+ default=182,
95
+ help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.",
96
+ )
97
+ parser.add_argument(
98
+ "--contain_dontcare_label",
99
+ action="store_true",
100
+ help="if the label map contains dontcare label (dontcare=255)",
101
+ )
102
+ parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels")
103
+
104
+ # for setting inputs
105
+ parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
106
+ parser.add_argument("--dataset_mode", type=str, default="coco")
107
+ parser.add_argument(
108
+ "--serial_batches",
109
+ action="store_true",
110
+ help="if true, takes images in order to make batches, otherwise takes them randomly",
111
+ )
112
+ parser.add_argument(
113
+ "--no_flip",
114
+ action="store_true",
115
+ help="if specified, do not flip the images for data argumentation",
116
+ )
117
+ parser.add_argument("--nThreads", default=0, type=int, help="# threads for loading data")
118
+ parser.add_argument(
119
+ "--max_dataset_size",
120
+ type=int,
121
+ default=sys.maxsize,
122
+ help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
123
+ )
124
+ parser.add_argument(
125
+ "--load_from_opt_file",
126
+ action="store_true",
127
+ help="load the options from checkpoints and use that as default",
128
+ )
129
+ parser.add_argument(
130
+ "--cache_filelist_write",
131
+ action="store_true",
132
+ help="saves the current filelist into a text file, so that it loads faster",
133
+ )
134
+ parser.add_argument(
135
+ "--cache_filelist_read", action="store_true", help="reads from the file list cache"
136
+ )
137
+
138
+ # for displays
139
+ parser.add_argument("--display_winsize", type=int, default=400, help="display window size")
140
+
141
+ # for generator
142
+ parser.add_argument(
143
+ "--netG", type=str, default="spade", help="selects model to use for netG (pix2pixhd | spade)"
144
+ )
145
+ parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer")
146
+ parser.add_argument(
147
+ "--init_type",
148
+ type=str,
149
+ default="xavier",
150
+ help="network initialization [normal|xavier|kaiming|orthogonal]",
151
+ )
152
+ parser.add_argument(
153
+ "--init_variance", type=float, default=0.02, help="variance of the initialization distribution"
154
+ )
155
+ parser.add_argument("--z_dim", type=int, default=256, help="dimension of the latent z vector")
156
+ parser.add_argument(
157
+ "--no_parsing_map", action="store_true", help="During training, we do not use the parsing map"
158
+ )
159
+
160
+ # for instance-wise features
161
+ parser.add_argument(
162
+ "--no_instance", action="store_true", help="if specified, do *not* add instance map as input"
163
+ )
164
+ parser.add_argument(
165
+ "--nef", type=int, default=16, help="# of encoder filters in the first conv layer"
166
+ )
167
+ parser.add_argument("--use_vae", action="store_true", help="enable training with an image encoder.")
168
+ parser.add_argument(
169
+ "--tensorboard_log", action="store_true", help="use tensorboard to record the resutls"
170
+ )
171
+
172
+ # parser.add_argument('--img_dir',)
173
+ parser.add_argument(
174
+ "--old_face_folder", type=str, default="", help="The folder name of input old face"
175
+ )
176
+ parser.add_argument(
177
+ "--old_face_label_folder", type=str, default="", help="The folder name of input old face label"
178
+ )
179
+
180
+ parser.add_argument("--injection_layer", type=str, default="all", help="")
181
+
182
+ self.initialized = True
183
+ return parser
184
+
185
+ def gather_options(self):
186
+ # initialize parser with basic options
187
+ if not self.initialized:
188
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
189
+ parser = self.initialize(parser)
190
+
191
+ # get the basic options
192
+ opt, unknown = parser.parse_known_args()
193
+
194
+ # modify model-related parser options
195
+ model_name = opt.model
196
+ model_option_setter = models.get_option_setter(model_name)
197
+ parser = model_option_setter(parser, self.isTrain)
198
+
199
+ # modify dataset-related parser options
200
+ # dataset_mode = opt.dataset_mode
201
+ # dataset_option_setter = data.get_option_setter(dataset_mode)
202
+ # parser = dataset_option_setter(parser, self.isTrain)
203
+
204
+ opt, unknown = parser.parse_known_args()
205
+
206
+ # if there is opt_file, load it.
207
+ # The previous default options will be overwritten
208
+ if opt.load_from_opt_file:
209
+ parser = self.update_options_from_file(parser, opt)
210
+
211
+ opt = parser.parse_args()
212
+ self.parser = parser
213
+ return opt
214
+
215
+ def print_options(self, opt):
216
+ message = ""
217
+ message += "----------------- Options ---------------\n"
218
+ for k, v in sorted(vars(opt).items()):
219
+ comment = ""
220
+ default = self.parser.get_default(k)
221
+ if v != default:
222
+ comment = "\t[default: %s]" % str(default)
223
+ message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
224
+ message += "----------------- End -------------------"
225
+ # print(message)
226
+
227
+ def option_file_path(self, opt, makedir=False):
228
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
229
+ if makedir:
230
+ util.mkdirs(expr_dir)
231
+ file_name = os.path.join(expr_dir, "opt")
232
+ return file_name
233
+
234
+ def save_options(self, opt):
235
+ file_name = self.option_file_path(opt, makedir=True)
236
+ with open(file_name + ".txt", "wt") as opt_file:
237
+ for k, v in sorted(vars(opt).items()):
238
+ comment = ""
239
+ default = self.parser.get_default(k)
240
+ if v != default:
241
+ comment = "\t[default: %s]" % str(default)
242
+ opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment))
243
+
244
+ with open(file_name + ".pkl", "wb") as opt_file:
245
+ pickle.dump(opt, opt_file)
246
+
247
+ def update_options_from_file(self, parser, opt):
248
+ new_opt = self.load_options(opt)
249
+ for k, v in sorted(vars(opt).items()):
250
+ if hasattr(new_opt, k) and v != getattr(new_opt, k):
251
+ new_val = getattr(new_opt, k)
252
+ parser.set_defaults(**{k: new_val})
253
+ return parser
254
+
255
+ def load_options(self, opt):
256
+ file_name = self.option_file_path(opt, makedir=False)
257
+ new_opt = pickle.load(open(file_name + ".pkl", "rb"))
258
+ return new_opt
259
+
260
+ def parse(self, save=False):
261
+
262
+ opt = self.gather_options()
263
+ opt.isTrain = self.isTrain # train or test
264
+ opt.contain_dontcare_label = False
265
+
266
+ self.print_options(opt)
267
+ if opt.isTrain:
268
+ self.save_options(opt)
269
+
270
+ # Set semantic_nc based on the option.
271
+ # This will be convenient in many places
272
+ opt.semantic_nc = (
273
+ opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
274
+ )
275
+
276
+ # set gpu ids
277
+ str_ids = opt.gpu_ids.split(",")
278
+ opt.gpu_ids = []
279
+ for str_id in str_ids:
280
+ int_id = int(str_id)
281
+ if int_id >= 0:
282
+ opt.gpu_ids.append(int_id)
283
+
284
+ if len(opt.gpu_ids) > 0:
285
+ print("The main GPU is ")
286
+ print(opt.gpu_ids[0])
287
+ torch.cuda.set_device(opt.gpu_ids[0])
288
+
289
+ assert (
290
+ len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0
291
+ ), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids))
292
+
293
+ self.opt = opt
294
+ return self.opt
Face_Enhancement/options/test_options.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from .base_options import BaseOptions
5
+
6
+
7
+ class TestOptions(BaseOptions):
8
+ def initialize(self, parser):
9
+ BaseOptions.initialize(self, parser)
10
+ parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
11
+ parser.add_argument(
12
+ "--which_epoch",
13
+ type=str,
14
+ default="latest",
15
+ help="which epoch to load? set to latest to use latest cached model",
16
+ )
17
+ parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")
18
+
19
+ parser.set_defaults(
20
+ preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256
21
+ )
22
+ parser.set_defaults(serial_batches=True)
23
+ parser.set_defaults(no_flip=True)
24
+ parser.set_defaults(phase="test")
25
+ self.isTrain = False
26
+ return parser
Face_Enhancement/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.0.0
2
+ torchvision
3
+ dominate>=2.3.1
4
+ wandb
5
+ dill
6
+ scikit-image
7
+ tensorboardX
8
+ scipy
9
+ opencv-python
Face_Enhancement/test_face.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import os
5
+ from collections import OrderedDict
6
+
7
+ import data
8
+ from options.test_options import TestOptions
9
+ from models.pix2pix_model import Pix2PixModel
10
+ from util.visualizer import Visualizer
11
+ import torchvision.utils as vutils
12
+ import warnings
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+ opt = TestOptions().parse()
16
+
17
+ dataloader = data.create_dataloader(opt)
18
+
19
+ model = Pix2PixModel(opt)
20
+ model.eval()
21
+
22
+ visualizer = Visualizer(opt)
23
+
24
+
25
+ single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img")
26
+
27
+
28
+ if not os.path.exists(single_save_url):
29
+ os.makedirs(single_save_url)
30
+
31
+
32
+ for i, data_i in enumerate(dataloader):
33
+ if i * opt.batchSize >= opt.how_many:
34
+ break
35
+
36
+ generated = model(data_i, mode="inference")
37
+
38
+ img_path = data_i["path"]
39
+
40
+ for b in range(generated.shape[0]):
41
+ img_name = os.path.split(img_path[b])[-1]
42
+ save_img_url = os.path.join(single_save_url, img_name)
43
+
44
+ vutils.save_image((generated[b] + 1) / 2, save_img_url)
45
+
Face_Enhancement/util/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
Face_Enhancement/util/iter_counter.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import os
5
+ import time
6
+ import numpy as np
7
+
8
+
9
+ # Helper class that keeps track of training iterations
10
+ class IterationCounter:
11
+ def __init__(self, opt, dataset_size):
12
+ self.opt = opt
13
+ self.dataset_size = dataset_size
14
+
15
+ self.first_epoch = 1
16
+ self.total_epochs = opt.niter + opt.niter_decay
17
+ self.epoch_iter = 0 # iter number within each epoch
18
+ self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt")
19
+ if opt.isTrain and opt.continue_train:
20
+ try:
21
+ self.first_epoch, self.epoch_iter = np.loadtxt(
22
+ self.iter_record_path, delimiter=",", dtype=int
23
+ )
24
+ print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter))
25
+ except:
26
+ print(
27
+ "Could not load iteration record at %s. Starting from beginning." % self.iter_record_path
28
+ )
29
+
30
+ self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
31
+
32
+ # return the iterator of epochs for the training
33
+ def training_epochs(self):
34
+ return range(self.first_epoch, self.total_epochs + 1)
35
+
36
+ def record_epoch_start(self, epoch):
37
+ self.epoch_start_time = time.time()
38
+ self.epoch_iter = 0
39
+ self.last_iter_time = time.time()
40
+ self.current_epoch = epoch
41
+
42
+ def record_one_iteration(self):
43
+ current_time = time.time()
44
+
45
+ # the last remaining batch is dropped (see data/__init__.py),
46
+ # so we can assume batch size is always opt.batchSize
47
+ self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
48
+ self.last_iter_time = current_time
49
+ self.total_steps_so_far += self.opt.batchSize
50
+ self.epoch_iter += self.opt.batchSize
51
+
52
+ def record_epoch_end(self):
53
+ current_time = time.time()
54
+ self.time_per_epoch = current_time - self.epoch_start_time
55
+ print(
56
+ "End of epoch %d / %d \t Time Taken: %d sec"
57
+ % (self.current_epoch, self.total_epochs, self.time_per_epoch)
58
+ )
59
+ if self.current_epoch % self.opt.save_epoch_freq == 0:
60
+ np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d")
61
+ print("Saved current iteration count at %s." % self.iter_record_path)
62
+
63
+ def record_current_iter(self):
64
+ np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d")
65
+ print("Saved current iteration count at %s." % self.iter_record_path)
66
+
67
+ def needs_saving(self):
68
+ return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
69
+
70
+ def needs_printing(self):
71
+ return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
72
+
73
+ def needs_displaying(self):
74
+ return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
Face_Enhancement/util/util.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import re
5
+ import importlib
6
+ import torch
7
+ from argparse import Namespace
8
+ import numpy as np
9
+ from PIL import Image
10
+ import os
11
+ import argparse
12
+ import dill as pickle
13
+
14
+
15
+ def save_obj(obj, name):
16
+ with open(name, "wb") as f:
17
+ pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
18
+
19
+
20
+ def load_obj(name):
21
+ with open(name, "rb") as f:
22
+ return pickle.load(f)
23
+
24
+
25
+ def copyconf(default_opt, **kwargs):
26
+ conf = argparse.Namespace(**vars(default_opt))
27
+ for key in kwargs:
28
+ print(key, kwargs[key])
29
+ setattr(conf, key, kwargs[key])
30
+ return conf
31
+
32
+
33
+ # Converts a Tensor into a Numpy array
34
+ # |imtype|: the desired type of the converted numpy array
35
+ def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
36
+ if isinstance(image_tensor, list):
37
+ image_numpy = []
38
+ for i in range(len(image_tensor)):
39
+ image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
40
+ return image_numpy
41
+
42
+ if image_tensor.dim() == 4:
43
+ # transform each image in the batch
44
+ images_np = []
45
+ for b in range(image_tensor.size(0)):
46
+ one_image = image_tensor[b]
47
+ one_image_np = tensor2im(one_image)
48
+ images_np.append(one_image_np.reshape(1, *one_image_np.shape))
49
+ images_np = np.concatenate(images_np, axis=0)
50
+
51
+ return images_np
52
+
53
+ if image_tensor.dim() == 2:
54
+ image_tensor = image_tensor.unsqueeze(0)
55
+ image_numpy = image_tensor.detach().cpu().float().numpy()
56
+ if normalize:
57
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
58
+ else:
59
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
60
+ image_numpy = np.clip(image_numpy, 0, 255)
61
+ if image_numpy.shape[2] == 1:
62
+ image_numpy = image_numpy[:, :, 0]
63
+ return image_numpy.astype(imtype)
64
+
65
+
66
+ # Converts a one-hot tensor into a colorful label map
67
+ def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
68
+ if label_tensor.dim() == 4:
69
+ # transform each image in the batch
70
+ images_np = []
71
+ for b in range(label_tensor.size(0)):
72
+ one_image = label_tensor[b]
73
+ one_image_np = tensor2label(one_image, n_label, imtype)
74
+ images_np.append(one_image_np.reshape(1, *one_image_np.shape))
75
+ images_np = np.concatenate(images_np, axis=0)
76
+ # if tile:
77
+ # images_tiled = tile_images(images_np)
78
+ # return images_tiled
79
+ # else:
80
+ # images_np = images_np[0]
81
+ # return images_np
82
+ return images_np
83
+
84
+ if label_tensor.dim() == 1:
85
+ return np.zeros((64, 64, 3), dtype=np.uint8)
86
+ if n_label == 0:
87
+ return tensor2im(label_tensor, imtype)
88
+ label_tensor = label_tensor.cpu().float()
89
+ if label_tensor.size()[0] > 1:
90
+ label_tensor = label_tensor.max(0, keepdim=True)[1]
91
+ label_tensor = Colorize(n_label)(label_tensor)
92
+ label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
93
+ result = label_numpy.astype(imtype)
94
+ return result
95
+
96
+
97
+ def save_image(image_numpy, image_path, create_dir=False):
98
+ if create_dir:
99
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
100
+ if len(image_numpy.shape) == 2:
101
+ image_numpy = np.expand_dims(image_numpy, axis=2)
102
+ if image_numpy.shape[2] == 1:
103
+ image_numpy = np.repeat(image_numpy, 3, 2)
104
+ image_pil = Image.fromarray(image_numpy)
105
+
106
+ # save to png
107
+ image_pil.save(image_path.replace(".jpg", ".png"))
108
+
109
+
110
+ def mkdirs(paths):
111
+ if isinstance(paths, list) and not isinstance(paths, str):
112
+ for path in paths:
113
+ mkdir(path)
114
+ else:
115
+ mkdir(paths)
116
+
117
+
118
+ def mkdir(path):
119
+ if not os.path.exists(path):
120
+ os.makedirs(path)
121
+
122
+
123
+ def atoi(text):
124
+ return int(text) if text.isdigit() else text
125
+
126
+
127
+ def natural_keys(text):
128
+ """
129
+ alist.sort(key=natural_keys) sorts in human order
130
+ http://nedbatchelder.com/blog/200712/human_sorting.html
131
+ (See Toothy's implementation in the comments)
132
+ """
133
+ return [atoi(c) for c in re.split("(\d+)", text)]
134
+
135
+
136
+ def natural_sort(items):
137
+ items.sort(key=natural_keys)
138
+
139
+
140
+ def str2bool(v):
141
+ if v.lower() in ("yes", "true", "t", "y", "1"):
142
+ return True
143
+ elif v.lower() in ("no", "false", "f", "n", "0"):
144
+ return False
145
+ else:
146
+ raise argparse.ArgumentTypeError("Boolean value expected.")
147
+
148
+
149
+ def find_class_in_module(target_cls_name, module):
150
+ target_cls_name = target_cls_name.replace("_", "").lower()
151
+ clslib = importlib.import_module(module)
152
+ cls = None
153
+ for name, clsobj in clslib.__dict__.items():
154
+ if name.lower() == target_cls_name:
155
+ cls = clsobj
156
+
157
+ if cls is None:
158
+ print(
159
+ "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
160
+ % (module, target_cls_name)
161
+ )
162
+ exit(0)
163
+
164
+ return cls
165
+
166
+
167
+ def save_network(net, label, epoch, opt):
168
+ save_filename = "%s_net_%s.pth" % (epoch, label)
169
+ save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
170
+ torch.save(net.cpu().state_dict(), save_path)
171
+ if len(opt.gpu_ids) and torch.cuda.is_available():
172
+ net.cuda()
173
+
174
+
175
+ def load_network(net, label, epoch, opt):
176
+ save_filename = "%s_net_%s.pth" % (epoch, label)
177
+ save_dir = os.path.join(opt.checkpoints_dir, opt.name)
178
+ save_path = os.path.join(save_dir, save_filename)
179
+ if os.path.exists(save_path):
180
+ weights = torch.load(save_path)
181
+ net.load_state_dict(weights)
182
+ return net
183
+
184
+
185
+ ###############################################################################
186
+ # Code from
187
+ # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
188
+ # Modified so it complies with the Citscape label map colors
189
+ ###############################################################################
190
+ def uint82bin(n, count=8):
191
+ """returns the binary of integer n, count refers to amount of bits"""
192
+ return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
193
+
194
+
195
+ class Colorize(object):
196
+ def __init__(self, n=35):
197
+ self.cmap = labelcolormap(n)
198
+ self.cmap = torch.from_numpy(self.cmap[:n])
199
+
200
+ def __call__(self, gray_image):
201
+ size = gray_image.size()
202
+ color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
203
+
204
+ for label in range(0, len(self.cmap)):
205
+ mask = (label == gray_image[0]).cpu()
206
+ color_image[0][mask] = self.cmap[label][0]
207
+ color_image[1][mask] = self.cmap[label][1]
208
+ color_image[2][mask] = self.cmap[label][2]
209
+
210
+ return color_image
Face_Enhancement/util/visualizer.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import os
5
+ import ntpath
6
+ import time
7
+ from . import util
8
+ import scipy.misc
9
+
10
+ try:
11
+ from StringIO import StringIO # Python 2.7
12
+ except ImportError:
13
+ from io import BytesIO # Python 3.x
14
+ import torchvision.utils as vutils
15
+ from tensorboardX import SummaryWriter
16
+ import torch
17
+ import numpy as np
18
+
19
+
20
+ class Visualizer:
21
+ def __init__(self, opt):
22
+ self.opt = opt
23
+ self.tf_log = opt.isTrain and opt.tf_log
24
+
25
+ self.tensorboard_log = opt.tensorboard_log
26
+
27
+ self.win_size = opt.display_winsize
28
+ self.name = opt.name
29
+ if self.tensorboard_log:
30
+
31
+ if self.opt.isTrain:
32
+ self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
33
+ if not os.path.exists(self.log_dir):
34
+ os.makedirs(self.log_dir)
35
+ self.writer = SummaryWriter(log_dir=self.log_dir)
36
+ else:
37
+ print("hi :)")
38
+ self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
39
+ if not os.path.exists(self.log_dir):
40
+ os.makedirs(self.log_dir)
41
+
42
+ if opt.isTrain:
43
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
44
+ with open(self.log_name, "a") as log_file:
45
+ now = time.strftime("%c")
46
+ log_file.write("================ Training Loss (%s) ================\n" % now)
47
+
48
+ # |visuals|: dictionary of images to display or save
49
+ def display_current_results(self, visuals, epoch, step):
50
+
51
+ all_tensor = []
52
+ if self.tensorboard_log:
53
+
54
+ for key, tensor in visuals.items():
55
+ all_tensor.append((tensor.data.cpu() + 1) / 2)
56
+
57
+ output = torch.cat(all_tensor, 0)
58
+ img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
59
+
60
+ if self.opt.isTrain:
61
+ self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
62
+ else:
63
+ vutils.save_image(
64
+ output,
65
+ os.path.join(self.log_dir, str(step) + ".png"),
66
+ nrow=self.opt.batchSize,
67
+ padding=0,
68
+ normalize=False,
69
+ )
70
+
71
+ # errors: dictionary of error labels and values
72
+ def plot_current_errors(self, errors, step):
73
+ if self.tf_log:
74
+ for tag, value in errors.items():
75
+ value = value.mean().float()
76
+ summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
77
+ self.writer.add_summary(summary, step)
78
+
79
+ if self.tensorboard_log:
80
+
81
+ self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step)
82
+ self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step)
83
+ self.writer.add_scalars(
84
+ "Loss/GAN",
85
+ {
86
+ "G": errors["GAN"].mean().float(),
87
+ "D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2,
88
+ },
89
+ step,
90
+ )
91
+
92
+ # errors: same format as |errors| of plotCurrentErrors
93
+ def print_current_errors(self, epoch, i, errors, t):
94
+ message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
95
+ for k, v in errors.items():
96
+ v = v.mean().float()
97
+ message += "%s: %.3f " % (k, v)
98
+
99
+ print(message)
100
+ with open(self.log_name, "a") as log_file:
101
+ log_file.write("%s\n" % message)
102
+
103
+ def convert_visuals_to_numpy(self, visuals):
104
+ for key, t in visuals.items():
105
+ tile = self.opt.batchSize > 8
106
+ if "input_label" == key:
107
+ t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy
108
+ else:
109
+ t = util.tensor2im(t, tile=tile)
110
+ visuals[key] = t
111
+ return visuals
112
+
113
+ # save image to the disk
114
+ def save_images(self, webpage, visuals, image_path):
115
+ visuals = self.convert_visuals_to_numpy(visuals)
116
+
117
+ image_dir = webpage.get_image_dir()
118
+ short_path = ntpath.basename(image_path[0])
119
+ name = os.path.splitext(short_path)[0]
120
+
121
+ webpage.add_header(name)
122
+ ims = []
123
+ txts = []
124
+ links = []
125
+
126
+ for label, image_numpy in visuals.items():
127
+ image_name = os.path.join(label, "%s.png" % (name))
128
+ save_path = os.path.join(image_dir, image_name)
129
+ util.save_image(image_numpy, save_path, create_dir=True)
130
+
131
+ ims.append(image_name)
132
+ txts.append(label)
133
+ links.append(image_name)
134
+ webpage.add_images(ims, txts, links, width=self.win_size)
GUI.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import PySimpleGUI as sg
4
+ import os.path
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import shutil
9
+ from subprocess import call
10
+
11
+ def modify(image_filename=None, cv2_frame=None):
12
+
13
+ def run_cmd(command):
14
+ try:
15
+ call(command, shell=True)
16
+ except KeyboardInterrupt:
17
+ print("Process interrupted")
18
+ sys.exit(1)
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--input_folder", type=str,
22
+ default= image_filename, help="Test images")
23
+ parser.add_argument(
24
+ "--output_folder",
25
+ type=str,
26
+ default="./output",
27
+ help="Restored images, please use the absolute path",
28
+ )
29
+ parser.add_argument("--GPU", type=str, default="-1", help="0,1,2")
30
+ parser.add_argument(
31
+ "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint"
32
+ )
33
+ parser.add_argument("--with_scratch",default="--with_scratch" ,action="store_true")
34
+ opts = parser.parse_args()
35
+
36
+ gpu1 = opts.GPU
37
+
38
+ # resolve relative paths before changing directory
39
+ opts.input_folder = os.path.abspath(opts.input_folder)
40
+ opts.output_folder = os.path.abspath(opts.output_folder)
41
+ if not os.path.exists(opts.output_folder):
42
+ os.makedirs(opts.output_folder)
43
+
44
+ main_environment = os.getcwd()
45
+
46
+ # Stage 1: Overall Quality Improve
47
+ print("Running Stage 1: Overall restoration")
48
+ os.chdir("./Global")
49
+ stage_1_input_dir = opts.input_folder
50
+ stage_1_output_dir = os.path.join(
51
+ opts.output_folder, "stage_1_restore_output")
52
+ if not os.path.exists(stage_1_output_dir):
53
+ os.makedirs(stage_1_output_dir)
54
+
55
+ if not opts.with_scratch:
56
+ stage_1_command = (
57
+ "python test.py --test_mode Full --Quality_restore --test_input "
58
+ + stage_1_input_dir
59
+ + " --outputs_dir "
60
+ + stage_1_output_dir
61
+ + " --gpu_ids "
62
+ + gpu1
63
+ )
64
+ run_cmd(stage_1_command)
65
+ else:
66
+
67
+ mask_dir = os.path.join(stage_1_output_dir, "masks")
68
+ new_input = os.path.join(mask_dir, "input")
69
+ new_mask = os.path.join(mask_dir, "mask")
70
+ stage_1_command_1 = (
71
+ "python detection.py --test_path "
72
+ + stage_1_input_dir
73
+ + " --output_dir "
74
+ + mask_dir
75
+ + " --input_size full_size"
76
+ + " --GPU "
77
+ + gpu1
78
+ )
79
+ stage_1_command_2 = (
80
+ "python test.py --Scratch_and_Quality_restore --test_input "
81
+ + new_input
82
+ + " --test_mask "
83
+ + new_mask
84
+ + " --outputs_dir "
85
+ + stage_1_output_dir
86
+ + " --gpu_ids "
87
+ + gpu1
88
+ )
89
+ run_cmd(stage_1_command_1)
90
+ run_cmd(stage_1_command_2)
91
+
92
+ # Solve the case when there is no face in the old photo
93
+ stage_1_results = os.path.join(stage_1_output_dir, "restored_image")
94
+ stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
95
+ if not os.path.exists(stage_4_output_dir):
96
+ os.makedirs(stage_4_output_dir)
97
+ for x in os.listdir(stage_1_results):
98
+ img_dir = os.path.join(stage_1_results, x)
99
+ shutil.copy(img_dir, stage_4_output_dir)
100
+
101
+ print("Finish Stage 1 ...")
102
+ print("\n")
103
+
104
+ # Stage 2: Face Detection
105
+
106
+ print("Running Stage 2: Face Detection")
107
+ os.chdir(".././Face_Detection")
108
+ stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image")
109
+ stage_2_output_dir = os.path.join(
110
+ opts.output_folder, "stage_2_detection_output")
111
+ if not os.path.exists(stage_2_output_dir):
112
+ os.makedirs(stage_2_output_dir)
113
+ stage_2_command = (
114
+ "python detect_all_dlib.py --url " + stage_2_input_dir +
115
+ " --save_url " + stage_2_output_dir
116
+ )
117
+ run_cmd(stage_2_command)
118
+ print("Finish Stage 2 ...")
119
+ print("\n")
120
+
121
+ # Stage 3: Face Restore
122
+ print("Running Stage 3: Face Enhancement")
123
+ os.chdir(".././Face_Enhancement")
124
+ stage_3_input_mask = "./"
125
+ stage_3_input_face = stage_2_output_dir
126
+ stage_3_output_dir = os.path.join(
127
+ opts.output_folder, "stage_3_face_output")
128
+ if not os.path.exists(stage_3_output_dir):
129
+ os.makedirs(stage_3_output_dir)
130
+ stage_3_command = (
131
+ "python test_face.py --old_face_folder "
132
+ + stage_3_input_face
133
+ + " --old_face_label_folder "
134
+ + stage_3_input_mask
135
+ + " --tensorboard_log --name "
136
+ + opts.checkpoint_name
137
+ + " --gpu_ids "
138
+ + gpu1
139
+ + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir "
140
+ + stage_3_output_dir
141
+ + " --no_parsing_map"
142
+ )
143
+ run_cmd(stage_3_command)
144
+ print("Finish Stage 3 ...")
145
+ print("\n")
146
+
147
+ # Stage 4: Warp back
148
+ print("Running Stage 4: Blending")
149
+ os.chdir(".././Face_Detection")
150
+ stage_4_input_image_dir = os.path.join(
151
+ stage_1_output_dir, "restored_image")
152
+ stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img")
153
+ stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
154
+ if not os.path.exists(stage_4_output_dir):
155
+ os.makedirs(stage_4_output_dir)
156
+ stage_4_command = (
157
+ "python align_warp_back_multiple_dlib.py --origin_url "
158
+ + stage_4_input_image_dir
159
+ + " --replace_url "
160
+ + stage_4_input_face_dir
161
+ + " --save_url "
162
+ + stage_4_output_dir
163
+ )
164
+ run_cmd(stage_4_command)
165
+ print("Finish Stage 4 ...")
166
+ print("\n")
167
+
168
+ print("All the processing is done. Please check the results.")
169
+
170
+ # --------------------------------- The GUI ---------------------------------
171
+
172
+ # First the window layout...
173
+
174
+ images_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()],
175
+ [sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')],
176
+ [sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')],]
177
+ # ----- Full layout -----
178
+ layout = [[sg.VSeperator(), sg.Column(images_col)]]
179
+
180
+ # ----- Make the window -----
181
+ window = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True)
182
+
183
+ # ----- Run the Event Loop -----
184
+ prev_filename = colorized = cap = None
185
+ while True:
186
+ event, values = window.read()
187
+ if event in (None, 'Exit'):
188
+ break
189
+
190
+ elif event == '-MPHOTO-':
191
+ try:
192
+ n1 = filename.split("/")[-2]
193
+ n2 = filename.split("/")[-3]
194
+ n3 = filename.split("/")[-1]
195
+ filename= str(f"./{n2}/{n1}")
196
+ modify(filename)
197
+
198
+ global f_image
199
+ f_image = f'./output/final_output/{n3}'
200
+ image = cv2.imread(f_image)
201
+ window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes())
202
+
203
+ except:
204
+ continue
205
+
206
+ elif event == '-IN FILE-': # A single filename was chosen
207
+ filename = values['-IN FILE-']
208
+ if filename != prev_filename:
209
+ prev_filename = filename
210
+ try:
211
+ image = cv2.imread(filename)
212
+ window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes())
213
+ except:
214
+ continue
215
+
216
+ # ----- Exit program -----
217
+ window.close()
Global/data/Create_Bigfile.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import os
5
+ import struct
6
+ from PIL import Image
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+
18
+ def make_dataset(dir):
19
+ images = []
20
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
21
+
22
+ for root, _, fnames in sorted(os.walk(dir)):
23
+ for fname in fnames:
24
+ if is_image_file(fname):
25
+ #print(fname)
26
+ path = os.path.join(root, fname)
27
+ images.append(path)
28
+
29
+ return images
30
+
31
+ ### Modify these 3 lines in your own environment
32
+ indir="/home/ziyuwan/workspace/data/temp_old"
33
+ target_folders=['VOC','Real_L_old','Real_RGB_old']
34
+ out_dir ="/home/ziyuwan/workspace/data/temp_old"
35
+ ###
36
+
37
+ if os.path.exists(out_dir) is False:
38
+ os.makedirs(out_dir)
39
+
40
+ #
41
+ for target_folder in target_folders:
42
+ curr_indir = os.path.join(indir, target_folder)
43
+ curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder)))
44
+ image_lists = make_dataset(curr_indir)
45
+ image_lists.sort()
46
+ with open(curr_out_file, 'wb') as wfid:
47
+ # write total image number
48
+ wfid.write(struct.pack('i', len(image_lists)))
49
+ for i, img_path in enumerate(image_lists):
50
+ # write file name first
51
+ img_name = os.path.basename(img_path)
52
+ img_name_bytes = img_name.encode('utf-8')
53
+ wfid.write(struct.pack('i', len(img_name_bytes)))
54
+ wfid.write(img_name_bytes)
55
+ #
56
+ # # write image data in
57
+ with open(img_path, 'rb') as img_fid:
58
+ img_bytes = img_fid.read()
59
+ wfid.write(struct.pack('i', len(img_bytes)))
60
+ wfid.write(img_bytes)
61
+
62
+ if i % 1000 == 0:
63
+ print('write %d images done' % i)
Global/data/Load_Bigfile.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import io
5
+ import os
6
+ import struct
7
+ from PIL import Image
8
+
9
+ class BigFileMemoryLoader(object):
10
+ def __load_bigfile(self):
11
+ print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
12
+ with open(self.file_path, 'rb') as fid:
13
+ self.img_num = struct.unpack('i', fid.read(4))[0]
14
+ self.img_names = []
15
+ self.img_bytes = []
16
+ print('find total %d images' % self.img_num)
17
+ for i in range(self.img_num):
18
+ img_name_len = struct.unpack('i', fid.read(4))[0]
19
+ img_name = fid.read(img_name_len).decode('utf-8')
20
+ self.img_names.append(img_name)
21
+ img_bytes_len = struct.unpack('i', fid.read(4))[0]
22
+ self.img_bytes.append(fid.read(img_bytes_len))
23
+ if i % 5000 == 0:
24
+ print('load %d images done' % i)
25
+ print('load all %d images done' % self.img_num)
26
+
27
+ def __init__(self, file_path):
28
+ super(BigFileMemoryLoader, self).__init__()
29
+ self.file_path = file_path
30
+ self.__load_bigfile()
31
+
32
+ def __getitem__(self, index):
33
+ try:
34
+ img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
35
+ return self.img_names[index], img
36
+ except Exception:
37
+ print('Image read error for index %d: %s' % (index, self.img_names[index]))
38
+ return self.__getitem__((index+1)%self.img_num)
39
+
40
+
41
+ def __len__(self):
42
+ return self.img_num
Global/data/__init__.py ADDED
File without changes