Spaces:
Runtime error
Runtime error
manhkhanhUIT
commited on
Commit
·
7fab858
1
Parent(s):
a8b38c0
Add code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- CODE_OF_CONDUCT.md +9 -0
- Dockerfile +43 -0
- Face_Detection/align_warp_back_multiple_dlib.py +437 -0
- Face_Detection/align_warp_back_multiple_dlib_HR.py +437 -0
- Face_Detection/detect_all_dlib.py +184 -0
- Face_Detection/detect_all_dlib_HR.py +184 -0
- Face_Enhancement/data/__init__.py +22 -0
- Face_Enhancement/data/base_dataset.py +125 -0
- Face_Enhancement/data/custom_dataset.py +56 -0
- Face_Enhancement/data/face_dataset.py +102 -0
- Face_Enhancement/data/image_folder.py +101 -0
- Face_Enhancement/data/pix2pix_dataset.py +108 -0
- Face_Enhancement/models/__init__.py +44 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE +21 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md +118 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py +14 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py +412 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py +74 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py +137 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py +94 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py +29 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py +56 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py +62 -0
- Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py +114 -0
- Face_Enhancement/models/networks/__init__.py +58 -0
- Face_Enhancement/models/networks/architecture.py +173 -0
- Face_Enhancement/models/networks/base_network.py +58 -0
- Face_Enhancement/models/networks/encoder.py +53 -0
- Face_Enhancement/models/networks/generator.py +233 -0
- Face_Enhancement/models/networks/normalization.py +100 -0
- Face_Enhancement/models/networks/sync_batchnorm/__init__.py +14 -0
- Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py +412 -0
- Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py +74 -0
- Face_Enhancement/models/networks/sync_batchnorm/comm.py +137 -0
- Face_Enhancement/models/networks/sync_batchnorm/replicate.py +94 -0
- Face_Enhancement/models/networks/sync_batchnorm/unittest.py +29 -0
- Face_Enhancement/models/pix2pix_model.py +246 -0
- Face_Enhancement/options/__init__.py +2 -0
- Face_Enhancement/options/base_options.py +294 -0
- Face_Enhancement/options/test_options.py +26 -0
- Face_Enhancement/requirements.txt +9 -0
- Face_Enhancement/test_face.py +45 -0
- Face_Enhancement/util/__init__.py +2 -0
- Face_Enhancement/util/iter_counter.py +74 -0
- Face_Enhancement/util/util.py +210 -0
- Face_Enhancement/util/visualizer.py +134 -0
- GUI.py +217 -0
- Global/data/Create_Bigfile.py +63 -0
- Global/data/Load_Bigfile.py +42 -0
- Global/data/__init__.py +0 -0
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Microsoft Open Source Code of Conduct
|
2 |
+
|
3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
4 |
+
|
5 |
+
Resources:
|
6 |
+
|
7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
9 |
+
- Contact [[email protected]](mailto:[email protected]) with questions or concerns
|
Dockerfile
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.1-base-ubuntu20.04
|
2 |
+
|
3 |
+
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev -yq
|
4 |
+
ADD . /app
|
5 |
+
WORKDIR /app
|
6 |
+
RUN cd Face_Enhancement/models/networks/ &&\
|
7 |
+
git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
|
8 |
+
cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\
|
9 |
+
cd ../../../
|
10 |
+
|
11 |
+
RUN cd Global/detection_models &&\
|
12 |
+
git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
|
13 |
+
cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\
|
14 |
+
cd ../../
|
15 |
+
|
16 |
+
RUN cd Face_Detection/ &&\
|
17 |
+
wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 &&\
|
18 |
+
bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 &&\
|
19 |
+
cd ../
|
20 |
+
|
21 |
+
RUN cd Face_Enhancement/ &&\
|
22 |
+
wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip &&\
|
23 |
+
unzip checkpoints.zip &&\
|
24 |
+
cd ../ &&\
|
25 |
+
cd Global/ &&\
|
26 |
+
wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip &&\
|
27 |
+
unzip checkpoints.zip &&\
|
28 |
+
rm -f checkpoints.zip &&\
|
29 |
+
cd ../
|
30 |
+
|
31 |
+
RUN pip3 install numpy
|
32 |
+
|
33 |
+
RUN pip3 install dlib
|
34 |
+
|
35 |
+
RUN pip3 install -r requirements.txt
|
36 |
+
|
37 |
+
RUN git clone https://github.com/NVlabs/SPADE.git
|
38 |
+
|
39 |
+
RUN cd SPADE/ && pip3 install -r requirements.txt
|
40 |
+
|
41 |
+
RUN cd ..
|
42 |
+
|
43 |
+
CMD ["python3", "run.py"]
|
Face_Detection/align_warp_back_multiple_dlib.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import skimage.io as io
|
7 |
+
|
8 |
+
# from face_sdk import FaceDetection
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from matplotlib.patches import Rectangle
|
11 |
+
from skimage.transform import SimilarityTransform
|
12 |
+
from skimage.transform import warp
|
13 |
+
from PIL import Image, ImageFilter
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torchvision as tv
|
16 |
+
import torchvision.utils as vutils
|
17 |
+
import time
|
18 |
+
import cv2
|
19 |
+
import os
|
20 |
+
from skimage import img_as_ubyte
|
21 |
+
import json
|
22 |
+
import argparse
|
23 |
+
import dlib
|
24 |
+
|
25 |
+
|
26 |
+
def calculate_cdf(histogram):
|
27 |
+
"""
|
28 |
+
This method calculates the cumulative distribution function
|
29 |
+
:param array histogram: The values of the histogram
|
30 |
+
:return: normalized_cdf: The normalized cumulative distribution function
|
31 |
+
:rtype: array
|
32 |
+
"""
|
33 |
+
# Get the cumulative sum of the elements
|
34 |
+
cdf = histogram.cumsum()
|
35 |
+
|
36 |
+
# Normalize the cdf
|
37 |
+
normalized_cdf = cdf / float(cdf.max())
|
38 |
+
|
39 |
+
return normalized_cdf
|
40 |
+
|
41 |
+
|
42 |
+
def calculate_lookup(src_cdf, ref_cdf):
|
43 |
+
"""
|
44 |
+
This method creates the lookup table
|
45 |
+
:param array src_cdf: The cdf for the source image
|
46 |
+
:param array ref_cdf: The cdf for the reference image
|
47 |
+
:return: lookup_table: The lookup table
|
48 |
+
:rtype: array
|
49 |
+
"""
|
50 |
+
lookup_table = np.zeros(256)
|
51 |
+
lookup_val = 0
|
52 |
+
for src_pixel_val in range(len(src_cdf)):
|
53 |
+
lookup_val
|
54 |
+
for ref_pixel_val in range(len(ref_cdf)):
|
55 |
+
if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
|
56 |
+
lookup_val = ref_pixel_val
|
57 |
+
break
|
58 |
+
lookup_table[src_pixel_val] = lookup_val
|
59 |
+
return lookup_table
|
60 |
+
|
61 |
+
|
62 |
+
def match_histograms(src_image, ref_image):
|
63 |
+
"""
|
64 |
+
This method matches the source image histogram to the
|
65 |
+
reference signal
|
66 |
+
:param image src_image: The original source image
|
67 |
+
:param image ref_image: The reference image
|
68 |
+
:return: image_after_matching
|
69 |
+
:rtype: image (array)
|
70 |
+
"""
|
71 |
+
# Split the images into the different color channels
|
72 |
+
# b means blue, g means green and r means red
|
73 |
+
src_b, src_g, src_r = cv2.split(src_image)
|
74 |
+
ref_b, ref_g, ref_r = cv2.split(ref_image)
|
75 |
+
|
76 |
+
# Compute the b, g, and r histograms separately
|
77 |
+
# The flatten() Numpy method returns a copy of the array c
|
78 |
+
# collapsed into one dimension.
|
79 |
+
src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
|
80 |
+
src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
|
81 |
+
src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
|
82 |
+
ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
|
83 |
+
ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
|
84 |
+
ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
|
85 |
+
|
86 |
+
# Compute the normalized cdf for the source and reference image
|
87 |
+
src_cdf_blue = calculate_cdf(src_hist_blue)
|
88 |
+
src_cdf_green = calculate_cdf(src_hist_green)
|
89 |
+
src_cdf_red = calculate_cdf(src_hist_red)
|
90 |
+
ref_cdf_blue = calculate_cdf(ref_hist_blue)
|
91 |
+
ref_cdf_green = calculate_cdf(ref_hist_green)
|
92 |
+
ref_cdf_red = calculate_cdf(ref_hist_red)
|
93 |
+
|
94 |
+
# Make a separate lookup table for each color
|
95 |
+
blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
|
96 |
+
green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
|
97 |
+
red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
|
98 |
+
|
99 |
+
# Use the lookup function to transform the colors of the original
|
100 |
+
# source image
|
101 |
+
blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
|
102 |
+
green_after_transform = cv2.LUT(src_g, green_lookup_table)
|
103 |
+
red_after_transform = cv2.LUT(src_r, red_lookup_table)
|
104 |
+
|
105 |
+
# Put the image back together
|
106 |
+
image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
|
107 |
+
image_after_matching = cv2.convertScaleAbs(image_after_matching)
|
108 |
+
|
109 |
+
return image_after_matching
|
110 |
+
|
111 |
+
|
112 |
+
def _standard_face_pts():
|
113 |
+
pts = (
|
114 |
+
np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
|
115 |
+
- 1.0
|
116 |
+
)
|
117 |
+
|
118 |
+
return np.reshape(pts, (5, 2))
|
119 |
+
|
120 |
+
|
121 |
+
def _origin_face_pts():
|
122 |
+
pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
|
123 |
+
|
124 |
+
return np.reshape(pts, (5, 2))
|
125 |
+
|
126 |
+
|
127 |
+
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
|
128 |
+
|
129 |
+
std_pts = _standard_face_pts() # [-1,1]
|
130 |
+
target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
|
131 |
+
|
132 |
+
# print(target_pts)
|
133 |
+
|
134 |
+
h, w, c = img.shape
|
135 |
+
if normalize == True:
|
136 |
+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
|
137 |
+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
|
138 |
+
|
139 |
+
# print(landmark)
|
140 |
+
|
141 |
+
affine = SimilarityTransform()
|
142 |
+
|
143 |
+
affine.estimate(target_pts, landmark)
|
144 |
+
|
145 |
+
return affine
|
146 |
+
|
147 |
+
|
148 |
+
def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
|
149 |
+
|
150 |
+
std_pts = _standard_face_pts() # [-1,1]
|
151 |
+
target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
|
152 |
+
|
153 |
+
# print(target_pts)
|
154 |
+
|
155 |
+
h, w, c = img.shape
|
156 |
+
if normalize == True:
|
157 |
+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
|
158 |
+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
|
159 |
+
|
160 |
+
# print(landmark)
|
161 |
+
|
162 |
+
affine = SimilarityTransform()
|
163 |
+
|
164 |
+
affine.estimate(landmark, target_pts)
|
165 |
+
|
166 |
+
return affine
|
167 |
+
|
168 |
+
|
169 |
+
def show_detection(image, box, landmark):
|
170 |
+
plt.imshow(image)
|
171 |
+
print(box[2] - box[0])
|
172 |
+
plt.gca().add_patch(
|
173 |
+
Rectangle(
|
174 |
+
(box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
|
175 |
+
)
|
176 |
+
)
|
177 |
+
plt.scatter(landmark[0][0], landmark[0][1])
|
178 |
+
plt.scatter(landmark[1][0], landmark[1][1])
|
179 |
+
plt.scatter(landmark[2][0], landmark[2][1])
|
180 |
+
plt.scatter(landmark[3][0], landmark[3][1])
|
181 |
+
plt.scatter(landmark[4][0], landmark[4][1])
|
182 |
+
plt.show()
|
183 |
+
|
184 |
+
|
185 |
+
def affine2theta(affine, input_w, input_h, target_w, target_h):
|
186 |
+
# param = np.linalg.inv(affine)
|
187 |
+
param = affine
|
188 |
+
theta = np.zeros([2, 3])
|
189 |
+
theta[0, 0] = param[0, 0] * input_h / target_h
|
190 |
+
theta[0, 1] = param[0, 1] * input_w / target_h
|
191 |
+
theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
|
192 |
+
theta[1, 0] = param[1, 0] * input_h / target_w
|
193 |
+
theta[1, 1] = param[1, 1] * input_w / target_w
|
194 |
+
theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
|
195 |
+
return theta
|
196 |
+
|
197 |
+
|
198 |
+
def blur_blending(im1, im2, mask):
|
199 |
+
|
200 |
+
mask *= 255.0
|
201 |
+
|
202 |
+
kernel = np.ones((10, 10), np.uint8)
|
203 |
+
mask = cv2.erode(mask, kernel, iterations=1)
|
204 |
+
|
205 |
+
mask = Image.fromarray(mask.astype("uint8")).convert("L")
|
206 |
+
im1 = Image.fromarray(im1.astype("uint8"))
|
207 |
+
im2 = Image.fromarray(im2.astype("uint8"))
|
208 |
+
|
209 |
+
mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
|
210 |
+
im = Image.composite(im1, im2, mask)
|
211 |
+
|
212 |
+
im = Image.composite(im, im2, mask_blur)
|
213 |
+
|
214 |
+
return np.array(im) / 255.0
|
215 |
+
|
216 |
+
|
217 |
+
def blur_blending_cv2(im1, im2, mask):
|
218 |
+
|
219 |
+
mask *= 255.0
|
220 |
+
|
221 |
+
kernel = np.ones((9, 9), np.uint8)
|
222 |
+
mask = cv2.erode(mask, kernel, iterations=3)
|
223 |
+
|
224 |
+
mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
|
225 |
+
mask_blur /= 255.0
|
226 |
+
|
227 |
+
im = im1 * mask_blur + (1 - mask_blur) * im2
|
228 |
+
|
229 |
+
im /= 255.0
|
230 |
+
im = np.clip(im, 0.0, 1.0)
|
231 |
+
|
232 |
+
return im
|
233 |
+
|
234 |
+
|
235 |
+
# def Poisson_blending(im1,im2,mask):
|
236 |
+
|
237 |
+
|
238 |
+
# Image.composite(
|
239 |
+
def Poisson_blending(im1, im2, mask):
|
240 |
+
|
241 |
+
# mask=1-mask
|
242 |
+
mask *= 255
|
243 |
+
kernel = np.ones((10, 10), np.uint8)
|
244 |
+
mask = cv2.erode(mask, kernel, iterations=1)
|
245 |
+
mask /= 255
|
246 |
+
mask = 1 - mask
|
247 |
+
mask *= 255
|
248 |
+
|
249 |
+
mask = mask[:, :, 0]
|
250 |
+
width, height, channels = im1.shape
|
251 |
+
center = (int(height / 2), int(width / 2))
|
252 |
+
result = cv2.seamlessClone(
|
253 |
+
im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
|
254 |
+
)
|
255 |
+
|
256 |
+
return result / 255.0
|
257 |
+
|
258 |
+
|
259 |
+
def Poisson_B(im1, im2, mask, center):
|
260 |
+
|
261 |
+
mask *= 255
|
262 |
+
|
263 |
+
result = cv2.seamlessClone(
|
264 |
+
im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
|
265 |
+
)
|
266 |
+
|
267 |
+
return result / 255
|
268 |
+
|
269 |
+
|
270 |
+
def seamless_clone(old_face, new_face, raw_mask):
|
271 |
+
|
272 |
+
height, width, _ = old_face.shape
|
273 |
+
height = height // 2
|
274 |
+
width = width // 2
|
275 |
+
|
276 |
+
y_indices, x_indices, _ = np.nonzero(raw_mask)
|
277 |
+
y_crop = slice(np.min(y_indices), np.max(y_indices))
|
278 |
+
x_crop = slice(np.min(x_indices), np.max(x_indices))
|
279 |
+
y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
|
280 |
+
x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
|
281 |
+
|
282 |
+
insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
|
283 |
+
insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
|
284 |
+
insertion_mask[insertion_mask != 0] = 255
|
285 |
+
prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
|
286 |
+
"uint8"
|
287 |
+
)
|
288 |
+
# if np.sum(insertion_mask) == 0:
|
289 |
+
n_mask = insertion_mask[1:-1, 1:-1, :]
|
290 |
+
n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
|
291 |
+
print(n_mask.shape)
|
292 |
+
x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
|
293 |
+
if w < 4 or h < 4:
|
294 |
+
blended = prior
|
295 |
+
else:
|
296 |
+
blended = cv2.seamlessClone(
|
297 |
+
insertion, # pylint: disable=no-member
|
298 |
+
prior,
|
299 |
+
insertion_mask,
|
300 |
+
(x_center, y_center),
|
301 |
+
cv2.NORMAL_CLONE,
|
302 |
+
) # pylint: disable=no-member
|
303 |
+
|
304 |
+
blended = blended[height:-height, width:-width]
|
305 |
+
|
306 |
+
return blended.astype("float32") / 255.0
|
307 |
+
|
308 |
+
|
309 |
+
def get_landmark(face_landmarks, id):
|
310 |
+
part = face_landmarks.part(id)
|
311 |
+
x = part.x
|
312 |
+
y = part.y
|
313 |
+
|
314 |
+
return (x, y)
|
315 |
+
|
316 |
+
|
317 |
+
def search(face_landmarks):
|
318 |
+
|
319 |
+
x1, y1 = get_landmark(face_landmarks, 36)
|
320 |
+
x2, y2 = get_landmark(face_landmarks, 39)
|
321 |
+
x3, y3 = get_landmark(face_landmarks, 42)
|
322 |
+
x4, y4 = get_landmark(face_landmarks, 45)
|
323 |
+
|
324 |
+
x_nose, y_nose = get_landmark(face_landmarks, 30)
|
325 |
+
|
326 |
+
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
|
327 |
+
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
|
328 |
+
|
329 |
+
x_left_eye = int((x1 + x2) / 2)
|
330 |
+
y_left_eye = int((y1 + y2) / 2)
|
331 |
+
x_right_eye = int((x3 + x4) / 2)
|
332 |
+
y_right_eye = int((y3 + y4) / 2)
|
333 |
+
|
334 |
+
results = np.array(
|
335 |
+
[
|
336 |
+
[x_left_eye, y_left_eye],
|
337 |
+
[x_right_eye, y_right_eye],
|
338 |
+
[x_nose, y_nose],
|
339 |
+
[x_left_mouth, y_left_mouth],
|
340 |
+
[x_right_mouth, y_right_mouth],
|
341 |
+
]
|
342 |
+
)
|
343 |
+
|
344 |
+
return results
|
345 |
+
|
346 |
+
|
347 |
+
if __name__ == "__main__":
|
348 |
+
|
349 |
+
parser = argparse.ArgumentParser()
|
350 |
+
parser.add_argument("--origin_url", type=str, default="./", help="origin images")
|
351 |
+
parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
|
352 |
+
parser.add_argument("--save_url", type=str, default="./save")
|
353 |
+
opts = parser.parse_args()
|
354 |
+
|
355 |
+
origin_url = opts.origin_url
|
356 |
+
replace_url = opts.replace_url
|
357 |
+
save_url = opts.save_url
|
358 |
+
|
359 |
+
if not os.path.exists(save_url):
|
360 |
+
os.makedirs(save_url)
|
361 |
+
|
362 |
+
face_detector = dlib.get_frontal_face_detector()
|
363 |
+
landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
364 |
+
|
365 |
+
count = 0
|
366 |
+
|
367 |
+
for x in os.listdir(origin_url):
|
368 |
+
img_url = os.path.join(origin_url, x)
|
369 |
+
pil_img = Image.open(img_url).convert("RGB")
|
370 |
+
|
371 |
+
origin_width, origin_height = pil_img.size
|
372 |
+
image = np.array(pil_img)
|
373 |
+
|
374 |
+
start = time.time()
|
375 |
+
faces = face_detector(image)
|
376 |
+
done = time.time()
|
377 |
+
|
378 |
+
if len(faces) == 0:
|
379 |
+
print("Warning: There is no face in %s" % (x))
|
380 |
+
continue
|
381 |
+
|
382 |
+
blended = image
|
383 |
+
for face_id in range(len(faces)):
|
384 |
+
|
385 |
+
current_face = faces[face_id]
|
386 |
+
face_landmarks = landmark_locator(image, current_face)
|
387 |
+
current_fl = search(face_landmarks)
|
388 |
+
|
389 |
+
forward_mask = np.ones_like(image).astype("uint8")
|
390 |
+
affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
|
391 |
+
aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True)
|
392 |
+
forward_mask = warp(
|
393 |
+
forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True
|
394 |
+
)
|
395 |
+
|
396 |
+
affine_inverse = affine.inverse
|
397 |
+
cur_face = aligned_face
|
398 |
+
if replace_url != "":
|
399 |
+
|
400 |
+
face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
|
401 |
+
cur_url = os.path.join(replace_url, face_name)
|
402 |
+
restored_face = Image.open(cur_url).convert("RGB")
|
403 |
+
restored_face = np.array(restored_face)
|
404 |
+
cur_face = restored_face
|
405 |
+
|
406 |
+
## Histogram Color matching
|
407 |
+
A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
|
408 |
+
B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
|
409 |
+
B = match_histograms(B, A)
|
410 |
+
cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)
|
411 |
+
|
412 |
+
warped_back = warp(
|
413 |
+
cur_face,
|
414 |
+
affine_inverse,
|
415 |
+
output_shape=(origin_height, origin_width, 3),
|
416 |
+
order=3,
|
417 |
+
preserve_range=True,
|
418 |
+
)
|
419 |
+
|
420 |
+
backward_mask = warp(
|
421 |
+
forward_mask,
|
422 |
+
affine_inverse,
|
423 |
+
output_shape=(origin_height, origin_width, 3),
|
424 |
+
order=0,
|
425 |
+
preserve_range=True,
|
426 |
+
) ## Nearest neighbour
|
427 |
+
|
428 |
+
blended = blur_blending_cv2(warped_back, blended, backward_mask)
|
429 |
+
blended *= 255.0
|
430 |
+
|
431 |
+
io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))
|
432 |
+
|
433 |
+
count += 1
|
434 |
+
|
435 |
+
if count % 1000 == 0:
|
436 |
+
print("%d have finished ..." % (count))
|
437 |
+
|
Face_Detection/align_warp_back_multiple_dlib_HR.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import skimage.io as io
|
7 |
+
|
8 |
+
# from face_sdk import FaceDetection
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from matplotlib.patches import Rectangle
|
11 |
+
from skimage.transform import SimilarityTransform
|
12 |
+
from skimage.transform import warp
|
13 |
+
from PIL import Image, ImageFilter
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torchvision as tv
|
16 |
+
import torchvision.utils as vutils
|
17 |
+
import time
|
18 |
+
import cv2
|
19 |
+
import os
|
20 |
+
from skimage import img_as_ubyte
|
21 |
+
import json
|
22 |
+
import argparse
|
23 |
+
import dlib
|
24 |
+
|
25 |
+
|
26 |
+
def calculate_cdf(histogram):
|
27 |
+
"""
|
28 |
+
This method calculates the cumulative distribution function
|
29 |
+
:param array histogram: The values of the histogram
|
30 |
+
:return: normalized_cdf: The normalized cumulative distribution function
|
31 |
+
:rtype: array
|
32 |
+
"""
|
33 |
+
# Get the cumulative sum of the elements
|
34 |
+
cdf = histogram.cumsum()
|
35 |
+
|
36 |
+
# Normalize the cdf
|
37 |
+
normalized_cdf = cdf / float(cdf.max())
|
38 |
+
|
39 |
+
return normalized_cdf
|
40 |
+
|
41 |
+
|
42 |
+
def calculate_lookup(src_cdf, ref_cdf):
|
43 |
+
"""
|
44 |
+
This method creates the lookup table
|
45 |
+
:param array src_cdf: The cdf for the source image
|
46 |
+
:param array ref_cdf: The cdf for the reference image
|
47 |
+
:return: lookup_table: The lookup table
|
48 |
+
:rtype: array
|
49 |
+
"""
|
50 |
+
lookup_table = np.zeros(256)
|
51 |
+
lookup_val = 0
|
52 |
+
for src_pixel_val in range(len(src_cdf)):
|
53 |
+
lookup_val
|
54 |
+
for ref_pixel_val in range(len(ref_cdf)):
|
55 |
+
if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
|
56 |
+
lookup_val = ref_pixel_val
|
57 |
+
break
|
58 |
+
lookup_table[src_pixel_val] = lookup_val
|
59 |
+
return lookup_table
|
60 |
+
|
61 |
+
|
62 |
+
def match_histograms(src_image, ref_image):
|
63 |
+
"""
|
64 |
+
This method matches the source image histogram to the
|
65 |
+
reference signal
|
66 |
+
:param image src_image: The original source image
|
67 |
+
:param image ref_image: The reference image
|
68 |
+
:return: image_after_matching
|
69 |
+
:rtype: image (array)
|
70 |
+
"""
|
71 |
+
# Split the images into the different color channels
|
72 |
+
# b means blue, g means green and r means red
|
73 |
+
src_b, src_g, src_r = cv2.split(src_image)
|
74 |
+
ref_b, ref_g, ref_r = cv2.split(ref_image)
|
75 |
+
|
76 |
+
# Compute the b, g, and r histograms separately
|
77 |
+
# The flatten() Numpy method returns a copy of the array c
|
78 |
+
# collapsed into one dimension.
|
79 |
+
src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
|
80 |
+
src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
|
81 |
+
src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
|
82 |
+
ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
|
83 |
+
ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
|
84 |
+
ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
|
85 |
+
|
86 |
+
# Compute the normalized cdf for the source and reference image
|
87 |
+
src_cdf_blue = calculate_cdf(src_hist_blue)
|
88 |
+
src_cdf_green = calculate_cdf(src_hist_green)
|
89 |
+
src_cdf_red = calculate_cdf(src_hist_red)
|
90 |
+
ref_cdf_blue = calculate_cdf(ref_hist_blue)
|
91 |
+
ref_cdf_green = calculate_cdf(ref_hist_green)
|
92 |
+
ref_cdf_red = calculate_cdf(ref_hist_red)
|
93 |
+
|
94 |
+
# Make a separate lookup table for each color
|
95 |
+
blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
|
96 |
+
green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
|
97 |
+
red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
|
98 |
+
|
99 |
+
# Use the lookup function to transform the colors of the original
|
100 |
+
# source image
|
101 |
+
blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
|
102 |
+
green_after_transform = cv2.LUT(src_g, green_lookup_table)
|
103 |
+
red_after_transform = cv2.LUT(src_r, red_lookup_table)
|
104 |
+
|
105 |
+
# Put the image back together
|
106 |
+
image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
|
107 |
+
image_after_matching = cv2.convertScaleAbs(image_after_matching)
|
108 |
+
|
109 |
+
return image_after_matching
|
110 |
+
|
111 |
+
|
112 |
+
def _standard_face_pts():
|
113 |
+
pts = (
|
114 |
+
np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
|
115 |
+
- 1.0
|
116 |
+
)
|
117 |
+
|
118 |
+
return np.reshape(pts, (5, 2))
|
119 |
+
|
120 |
+
|
121 |
+
def _origin_face_pts():
|
122 |
+
pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
|
123 |
+
|
124 |
+
return np.reshape(pts, (5, 2))
|
125 |
+
|
126 |
+
|
127 |
+
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
|
128 |
+
|
129 |
+
std_pts = _standard_face_pts() # [-1,1]
|
130 |
+
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
|
131 |
+
|
132 |
+
# print(target_pts)
|
133 |
+
|
134 |
+
h, w, c = img.shape
|
135 |
+
if normalize == True:
|
136 |
+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
|
137 |
+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
|
138 |
+
|
139 |
+
# print(landmark)
|
140 |
+
|
141 |
+
affine = SimilarityTransform()
|
142 |
+
|
143 |
+
affine.estimate(target_pts, landmark)
|
144 |
+
|
145 |
+
return affine
|
146 |
+
|
147 |
+
|
148 |
+
def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
|
149 |
+
|
150 |
+
std_pts = _standard_face_pts() # [-1,1]
|
151 |
+
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
|
152 |
+
|
153 |
+
# print(target_pts)
|
154 |
+
|
155 |
+
h, w, c = img.shape
|
156 |
+
if normalize == True:
|
157 |
+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
|
158 |
+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
|
159 |
+
|
160 |
+
# print(landmark)
|
161 |
+
|
162 |
+
affine = SimilarityTransform()
|
163 |
+
|
164 |
+
affine.estimate(landmark, target_pts)
|
165 |
+
|
166 |
+
return affine
|
167 |
+
|
168 |
+
|
169 |
+
def show_detection(image, box, landmark):
|
170 |
+
plt.imshow(image)
|
171 |
+
print(box[2] - box[0])
|
172 |
+
plt.gca().add_patch(
|
173 |
+
Rectangle(
|
174 |
+
(box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
|
175 |
+
)
|
176 |
+
)
|
177 |
+
plt.scatter(landmark[0][0], landmark[0][1])
|
178 |
+
plt.scatter(landmark[1][0], landmark[1][1])
|
179 |
+
plt.scatter(landmark[2][0], landmark[2][1])
|
180 |
+
plt.scatter(landmark[3][0], landmark[3][1])
|
181 |
+
plt.scatter(landmark[4][0], landmark[4][1])
|
182 |
+
plt.show()
|
183 |
+
|
184 |
+
|
185 |
+
def affine2theta(affine, input_w, input_h, target_w, target_h):
|
186 |
+
# param = np.linalg.inv(affine)
|
187 |
+
param = affine
|
188 |
+
theta = np.zeros([2, 3])
|
189 |
+
theta[0, 0] = param[0, 0] * input_h / target_h
|
190 |
+
theta[0, 1] = param[0, 1] * input_w / target_h
|
191 |
+
theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
|
192 |
+
theta[1, 0] = param[1, 0] * input_h / target_w
|
193 |
+
theta[1, 1] = param[1, 1] * input_w / target_w
|
194 |
+
theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
|
195 |
+
return theta
|
196 |
+
|
197 |
+
|
198 |
+
def blur_blending(im1, im2, mask):
|
199 |
+
|
200 |
+
mask *= 255.0
|
201 |
+
|
202 |
+
kernel = np.ones((10, 10), np.uint8)
|
203 |
+
mask = cv2.erode(mask, kernel, iterations=1)
|
204 |
+
|
205 |
+
mask = Image.fromarray(mask.astype("uint8")).convert("L")
|
206 |
+
im1 = Image.fromarray(im1.astype("uint8"))
|
207 |
+
im2 = Image.fromarray(im2.astype("uint8"))
|
208 |
+
|
209 |
+
mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
|
210 |
+
im = Image.composite(im1, im2, mask)
|
211 |
+
|
212 |
+
im = Image.composite(im, im2, mask_blur)
|
213 |
+
|
214 |
+
return np.array(im) / 255.0
|
215 |
+
|
216 |
+
|
217 |
+
def blur_blending_cv2(im1, im2, mask):
|
218 |
+
|
219 |
+
mask *= 255.0
|
220 |
+
|
221 |
+
kernel = np.ones((9, 9), np.uint8)
|
222 |
+
mask = cv2.erode(mask, kernel, iterations=3)
|
223 |
+
|
224 |
+
mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
|
225 |
+
mask_blur /= 255.0
|
226 |
+
|
227 |
+
im = im1 * mask_blur + (1 - mask_blur) * im2
|
228 |
+
|
229 |
+
im /= 255.0
|
230 |
+
im = np.clip(im, 0.0, 1.0)
|
231 |
+
|
232 |
+
return im
|
233 |
+
|
234 |
+
|
235 |
+
# def Poisson_blending(im1,im2,mask):
|
236 |
+
|
237 |
+
|
238 |
+
# Image.composite(
|
239 |
+
def Poisson_blending(im1, im2, mask):
|
240 |
+
|
241 |
+
# mask=1-mask
|
242 |
+
mask *= 255
|
243 |
+
kernel = np.ones((10, 10), np.uint8)
|
244 |
+
mask = cv2.erode(mask, kernel, iterations=1)
|
245 |
+
mask /= 255
|
246 |
+
mask = 1 - mask
|
247 |
+
mask *= 255
|
248 |
+
|
249 |
+
mask = mask[:, :, 0]
|
250 |
+
width, height, channels = im1.shape
|
251 |
+
center = (int(height / 2), int(width / 2))
|
252 |
+
result = cv2.seamlessClone(
|
253 |
+
im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
|
254 |
+
)
|
255 |
+
|
256 |
+
return result / 255.0
|
257 |
+
|
258 |
+
|
259 |
+
def Poisson_B(im1, im2, mask, center):
|
260 |
+
|
261 |
+
mask *= 255
|
262 |
+
|
263 |
+
result = cv2.seamlessClone(
|
264 |
+
im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
|
265 |
+
)
|
266 |
+
|
267 |
+
return result / 255
|
268 |
+
|
269 |
+
|
270 |
+
def seamless_clone(old_face, new_face, raw_mask):
|
271 |
+
|
272 |
+
height, width, _ = old_face.shape
|
273 |
+
height = height // 2
|
274 |
+
width = width // 2
|
275 |
+
|
276 |
+
y_indices, x_indices, _ = np.nonzero(raw_mask)
|
277 |
+
y_crop = slice(np.min(y_indices), np.max(y_indices))
|
278 |
+
x_crop = slice(np.min(x_indices), np.max(x_indices))
|
279 |
+
y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
|
280 |
+
x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
|
281 |
+
|
282 |
+
insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
|
283 |
+
insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
|
284 |
+
insertion_mask[insertion_mask != 0] = 255
|
285 |
+
prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
|
286 |
+
"uint8"
|
287 |
+
)
|
288 |
+
# if np.sum(insertion_mask) == 0:
|
289 |
+
n_mask = insertion_mask[1:-1, 1:-1, :]
|
290 |
+
n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
|
291 |
+
print(n_mask.shape)
|
292 |
+
x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
|
293 |
+
if w < 4 or h < 4:
|
294 |
+
blended = prior
|
295 |
+
else:
|
296 |
+
blended = cv2.seamlessClone(
|
297 |
+
insertion, # pylint: disable=no-member
|
298 |
+
prior,
|
299 |
+
insertion_mask,
|
300 |
+
(x_center, y_center),
|
301 |
+
cv2.NORMAL_CLONE,
|
302 |
+
) # pylint: disable=no-member
|
303 |
+
|
304 |
+
blended = blended[height:-height, width:-width]
|
305 |
+
|
306 |
+
return blended.astype("float32") / 255.0
|
307 |
+
|
308 |
+
|
309 |
+
def get_landmark(face_landmarks, id):
|
310 |
+
part = face_landmarks.part(id)
|
311 |
+
x = part.x
|
312 |
+
y = part.y
|
313 |
+
|
314 |
+
return (x, y)
|
315 |
+
|
316 |
+
|
317 |
+
def search(face_landmarks):
|
318 |
+
|
319 |
+
x1, y1 = get_landmark(face_landmarks, 36)
|
320 |
+
x2, y2 = get_landmark(face_landmarks, 39)
|
321 |
+
x3, y3 = get_landmark(face_landmarks, 42)
|
322 |
+
x4, y4 = get_landmark(face_landmarks, 45)
|
323 |
+
|
324 |
+
x_nose, y_nose = get_landmark(face_landmarks, 30)
|
325 |
+
|
326 |
+
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
|
327 |
+
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
|
328 |
+
|
329 |
+
x_left_eye = int((x1 + x2) / 2)
|
330 |
+
y_left_eye = int((y1 + y2) / 2)
|
331 |
+
x_right_eye = int((x3 + x4) / 2)
|
332 |
+
y_right_eye = int((y3 + y4) / 2)
|
333 |
+
|
334 |
+
results = np.array(
|
335 |
+
[
|
336 |
+
[x_left_eye, y_left_eye],
|
337 |
+
[x_right_eye, y_right_eye],
|
338 |
+
[x_nose, y_nose],
|
339 |
+
[x_left_mouth, y_left_mouth],
|
340 |
+
[x_right_mouth, y_right_mouth],
|
341 |
+
]
|
342 |
+
)
|
343 |
+
|
344 |
+
return results
|
345 |
+
|
346 |
+
|
347 |
+
if __name__ == "__main__":
|
348 |
+
|
349 |
+
parser = argparse.ArgumentParser()
|
350 |
+
parser.add_argument("--origin_url", type=str, default="./", help="origin images")
|
351 |
+
parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
|
352 |
+
parser.add_argument("--save_url", type=str, default="./save")
|
353 |
+
opts = parser.parse_args()
|
354 |
+
|
355 |
+
origin_url = opts.origin_url
|
356 |
+
replace_url = opts.replace_url
|
357 |
+
save_url = opts.save_url
|
358 |
+
|
359 |
+
if not os.path.exists(save_url):
|
360 |
+
os.makedirs(save_url)
|
361 |
+
|
362 |
+
face_detector = dlib.get_frontal_face_detector()
|
363 |
+
landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
364 |
+
|
365 |
+
count = 0
|
366 |
+
|
367 |
+
for x in os.listdir(origin_url):
|
368 |
+
img_url = os.path.join(origin_url, x)
|
369 |
+
pil_img = Image.open(img_url).convert("RGB")
|
370 |
+
|
371 |
+
origin_width, origin_height = pil_img.size
|
372 |
+
image = np.array(pil_img)
|
373 |
+
|
374 |
+
start = time.time()
|
375 |
+
faces = face_detector(image)
|
376 |
+
done = time.time()
|
377 |
+
|
378 |
+
if len(faces) == 0:
|
379 |
+
print("Warning: There is no face in %s" % (x))
|
380 |
+
continue
|
381 |
+
|
382 |
+
blended = image
|
383 |
+
for face_id in range(len(faces)):
|
384 |
+
|
385 |
+
current_face = faces[face_id]
|
386 |
+
face_landmarks = landmark_locator(image, current_face)
|
387 |
+
current_fl = search(face_landmarks)
|
388 |
+
|
389 |
+
forward_mask = np.ones_like(image).astype("uint8")
|
390 |
+
affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
|
391 |
+
aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True)
|
392 |
+
forward_mask = warp(
|
393 |
+
forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True
|
394 |
+
)
|
395 |
+
|
396 |
+
affine_inverse = affine.inverse
|
397 |
+
cur_face = aligned_face
|
398 |
+
if replace_url != "":
|
399 |
+
|
400 |
+
face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
|
401 |
+
cur_url = os.path.join(replace_url, face_name)
|
402 |
+
restored_face = Image.open(cur_url).convert("RGB")
|
403 |
+
restored_face = np.array(restored_face)
|
404 |
+
cur_face = restored_face
|
405 |
+
|
406 |
+
## Histogram Color matching
|
407 |
+
A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
|
408 |
+
B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
|
409 |
+
B = match_histograms(B, A)
|
410 |
+
cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)
|
411 |
+
|
412 |
+
warped_back = warp(
|
413 |
+
cur_face,
|
414 |
+
affine_inverse,
|
415 |
+
output_shape=(origin_height, origin_width, 3),
|
416 |
+
order=3,
|
417 |
+
preserve_range=True,
|
418 |
+
)
|
419 |
+
|
420 |
+
backward_mask = warp(
|
421 |
+
forward_mask,
|
422 |
+
affine_inverse,
|
423 |
+
output_shape=(origin_height, origin_width, 3),
|
424 |
+
order=0,
|
425 |
+
preserve_range=True,
|
426 |
+
) ## Nearest neighbour
|
427 |
+
|
428 |
+
blended = blur_blending_cv2(warped_back, blended, backward_mask)
|
429 |
+
blended *= 255.0
|
430 |
+
|
431 |
+
io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))
|
432 |
+
|
433 |
+
count += 1
|
434 |
+
|
435 |
+
if count % 1000 == 0:
|
436 |
+
print("%d have finished ..." % (count))
|
437 |
+
|
Face_Detection/detect_all_dlib.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import skimage.io as io
|
7 |
+
|
8 |
+
# from FaceSDK.face_sdk import FaceDetection
|
9 |
+
# from face_sdk import FaceDetection
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from matplotlib.patches import Rectangle
|
12 |
+
from skimage.transform import SimilarityTransform
|
13 |
+
from skimage.transform import warp
|
14 |
+
from PIL import Image
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torchvision as tv
|
17 |
+
import torchvision.utils as vutils
|
18 |
+
import time
|
19 |
+
import cv2
|
20 |
+
import os
|
21 |
+
from skimage import img_as_ubyte
|
22 |
+
import json
|
23 |
+
import argparse
|
24 |
+
import dlib
|
25 |
+
|
26 |
+
|
27 |
+
def _standard_face_pts():
|
28 |
+
pts = (
|
29 |
+
np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
|
30 |
+
- 1.0
|
31 |
+
)
|
32 |
+
|
33 |
+
return np.reshape(pts, (5, 2))
|
34 |
+
|
35 |
+
|
36 |
+
def _origin_face_pts():
|
37 |
+
pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
|
38 |
+
|
39 |
+
return np.reshape(pts, (5, 2))
|
40 |
+
|
41 |
+
|
42 |
+
def get_landmark(face_landmarks, id):
|
43 |
+
part = face_landmarks.part(id)
|
44 |
+
x = part.x
|
45 |
+
y = part.y
|
46 |
+
|
47 |
+
return (x, y)
|
48 |
+
|
49 |
+
|
50 |
+
def search(face_landmarks):
|
51 |
+
|
52 |
+
x1, y1 = get_landmark(face_landmarks, 36)
|
53 |
+
x2, y2 = get_landmark(face_landmarks, 39)
|
54 |
+
x3, y3 = get_landmark(face_landmarks, 42)
|
55 |
+
x4, y4 = get_landmark(face_landmarks, 45)
|
56 |
+
|
57 |
+
x_nose, y_nose = get_landmark(face_landmarks, 30)
|
58 |
+
|
59 |
+
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
|
60 |
+
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
|
61 |
+
|
62 |
+
x_left_eye = int((x1 + x2) / 2)
|
63 |
+
y_left_eye = int((y1 + y2) / 2)
|
64 |
+
x_right_eye = int((x3 + x4) / 2)
|
65 |
+
y_right_eye = int((y3 + y4) / 2)
|
66 |
+
|
67 |
+
results = np.array(
|
68 |
+
[
|
69 |
+
[x_left_eye, y_left_eye],
|
70 |
+
[x_right_eye, y_right_eye],
|
71 |
+
[x_nose, y_nose],
|
72 |
+
[x_left_mouth, y_left_mouth],
|
73 |
+
[x_right_mouth, y_right_mouth],
|
74 |
+
]
|
75 |
+
)
|
76 |
+
|
77 |
+
return results
|
78 |
+
|
79 |
+
|
80 |
+
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
|
81 |
+
|
82 |
+
std_pts = _standard_face_pts() # [-1,1]
|
83 |
+
target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
|
84 |
+
|
85 |
+
# print(target_pts)
|
86 |
+
|
87 |
+
h, w, c = img.shape
|
88 |
+
if normalize == True:
|
89 |
+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
|
90 |
+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
|
91 |
+
|
92 |
+
# print(landmark)
|
93 |
+
|
94 |
+
affine = SimilarityTransform()
|
95 |
+
|
96 |
+
affine.estimate(target_pts, landmark)
|
97 |
+
|
98 |
+
return affine.params
|
99 |
+
|
100 |
+
|
101 |
+
def show_detection(image, box, landmark):
|
102 |
+
plt.imshow(image)
|
103 |
+
print(box[2] - box[0])
|
104 |
+
plt.gca().add_patch(
|
105 |
+
Rectangle(
|
106 |
+
(box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
|
107 |
+
)
|
108 |
+
)
|
109 |
+
plt.scatter(landmark[0][0], landmark[0][1])
|
110 |
+
plt.scatter(landmark[1][0], landmark[1][1])
|
111 |
+
plt.scatter(landmark[2][0], landmark[2][1])
|
112 |
+
plt.scatter(landmark[3][0], landmark[3][1])
|
113 |
+
plt.scatter(landmark[4][0], landmark[4][1])
|
114 |
+
plt.show()
|
115 |
+
|
116 |
+
|
117 |
+
def affine2theta(affine, input_w, input_h, target_w, target_h):
|
118 |
+
# param = np.linalg.inv(affine)
|
119 |
+
param = affine
|
120 |
+
theta = np.zeros([2, 3])
|
121 |
+
theta[0, 0] = param[0, 0] * input_h / target_h
|
122 |
+
theta[0, 1] = param[0, 1] * input_w / target_h
|
123 |
+
theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
|
124 |
+
theta[1, 0] = param[1, 0] * input_h / target_w
|
125 |
+
theta[1, 1] = param[1, 1] * input_w / target_w
|
126 |
+
theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
|
127 |
+
return theta
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
|
132 |
+
parser = argparse.ArgumentParser()
|
133 |
+
parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
|
134 |
+
parser.add_argument(
|
135 |
+
"--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
|
136 |
+
)
|
137 |
+
opts = parser.parse_args()
|
138 |
+
|
139 |
+
url = opts.url
|
140 |
+
save_url = opts.save_url
|
141 |
+
|
142 |
+
### If the origin url is None, then we don't need to reid the origin image
|
143 |
+
|
144 |
+
os.makedirs(url, exist_ok=True)
|
145 |
+
os.makedirs(save_url, exist_ok=True)
|
146 |
+
|
147 |
+
face_detector = dlib.get_frontal_face_detector()
|
148 |
+
landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
149 |
+
|
150 |
+
count = 0
|
151 |
+
|
152 |
+
map_id = {}
|
153 |
+
for x in os.listdir(url):
|
154 |
+
img_url = os.path.join(url, x)
|
155 |
+
pil_img = Image.open(img_url).convert("RGB")
|
156 |
+
|
157 |
+
image = np.array(pil_img)
|
158 |
+
|
159 |
+
start = time.time()
|
160 |
+
faces = face_detector(image)
|
161 |
+
done = time.time()
|
162 |
+
|
163 |
+
if len(faces) == 0:
|
164 |
+
print("Warning: There is no face in %s" % (x))
|
165 |
+
continue
|
166 |
+
|
167 |
+
print(len(faces))
|
168 |
+
|
169 |
+
if len(faces) > 0:
|
170 |
+
for face_id in range(len(faces)):
|
171 |
+
current_face = faces[face_id]
|
172 |
+
face_landmarks = landmark_locator(image, current_face)
|
173 |
+
current_fl = search(face_landmarks)
|
174 |
+
|
175 |
+
affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
|
176 |
+
aligned_face = warp(image, affine, output_shape=(256, 256, 3))
|
177 |
+
img_name = x[:-4] + "_" + str(face_id + 1)
|
178 |
+
io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))
|
179 |
+
|
180 |
+
count += 1
|
181 |
+
|
182 |
+
if count % 1000 == 0:
|
183 |
+
print("%d have finished ..." % (count))
|
184 |
+
|
Face_Detection/detect_all_dlib_HR.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import skimage.io as io
|
7 |
+
|
8 |
+
# from FaceSDK.face_sdk import FaceDetection
|
9 |
+
# from face_sdk import FaceDetection
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from matplotlib.patches import Rectangle
|
12 |
+
from skimage.transform import SimilarityTransform
|
13 |
+
from skimage.transform import warp
|
14 |
+
from PIL import Image
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torchvision as tv
|
17 |
+
import torchvision.utils as vutils
|
18 |
+
import time
|
19 |
+
import cv2
|
20 |
+
import os
|
21 |
+
from skimage import img_as_ubyte
|
22 |
+
import json
|
23 |
+
import argparse
|
24 |
+
import dlib
|
25 |
+
|
26 |
+
|
27 |
+
def _standard_face_pts():
|
28 |
+
pts = (
|
29 |
+
np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
|
30 |
+
- 1.0
|
31 |
+
)
|
32 |
+
|
33 |
+
return np.reshape(pts, (5, 2))
|
34 |
+
|
35 |
+
|
36 |
+
def _origin_face_pts():
|
37 |
+
pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
|
38 |
+
|
39 |
+
return np.reshape(pts, (5, 2))
|
40 |
+
|
41 |
+
|
42 |
+
def get_landmark(face_landmarks, id):
|
43 |
+
part = face_landmarks.part(id)
|
44 |
+
x = part.x
|
45 |
+
y = part.y
|
46 |
+
|
47 |
+
return (x, y)
|
48 |
+
|
49 |
+
|
50 |
+
def search(face_landmarks):
|
51 |
+
|
52 |
+
x1, y1 = get_landmark(face_landmarks, 36)
|
53 |
+
x2, y2 = get_landmark(face_landmarks, 39)
|
54 |
+
x3, y3 = get_landmark(face_landmarks, 42)
|
55 |
+
x4, y4 = get_landmark(face_landmarks, 45)
|
56 |
+
|
57 |
+
x_nose, y_nose = get_landmark(face_landmarks, 30)
|
58 |
+
|
59 |
+
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
|
60 |
+
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
|
61 |
+
|
62 |
+
x_left_eye = int((x1 + x2) / 2)
|
63 |
+
y_left_eye = int((y1 + y2) / 2)
|
64 |
+
x_right_eye = int((x3 + x4) / 2)
|
65 |
+
y_right_eye = int((y3 + y4) / 2)
|
66 |
+
|
67 |
+
results = np.array(
|
68 |
+
[
|
69 |
+
[x_left_eye, y_left_eye],
|
70 |
+
[x_right_eye, y_right_eye],
|
71 |
+
[x_nose, y_nose],
|
72 |
+
[x_left_mouth, y_left_mouth],
|
73 |
+
[x_right_mouth, y_right_mouth],
|
74 |
+
]
|
75 |
+
)
|
76 |
+
|
77 |
+
return results
|
78 |
+
|
79 |
+
|
80 |
+
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
|
81 |
+
|
82 |
+
std_pts = _standard_face_pts() # [-1,1]
|
83 |
+
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
|
84 |
+
|
85 |
+
# print(target_pts)
|
86 |
+
|
87 |
+
h, w, c = img.shape
|
88 |
+
if normalize == True:
|
89 |
+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
|
90 |
+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
|
91 |
+
|
92 |
+
# print(landmark)
|
93 |
+
|
94 |
+
affine = SimilarityTransform()
|
95 |
+
|
96 |
+
affine.estimate(target_pts, landmark)
|
97 |
+
|
98 |
+
return affine.params
|
99 |
+
|
100 |
+
|
101 |
+
def show_detection(image, box, landmark):
|
102 |
+
plt.imshow(image)
|
103 |
+
print(box[2] - box[0])
|
104 |
+
plt.gca().add_patch(
|
105 |
+
Rectangle(
|
106 |
+
(box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
|
107 |
+
)
|
108 |
+
)
|
109 |
+
plt.scatter(landmark[0][0], landmark[0][1])
|
110 |
+
plt.scatter(landmark[1][0], landmark[1][1])
|
111 |
+
plt.scatter(landmark[2][0], landmark[2][1])
|
112 |
+
plt.scatter(landmark[3][0], landmark[3][1])
|
113 |
+
plt.scatter(landmark[4][0], landmark[4][1])
|
114 |
+
plt.show()
|
115 |
+
|
116 |
+
|
117 |
+
def affine2theta(affine, input_w, input_h, target_w, target_h):
|
118 |
+
# param = np.linalg.inv(affine)
|
119 |
+
param = affine
|
120 |
+
theta = np.zeros([2, 3])
|
121 |
+
theta[0, 0] = param[0, 0] * input_h / target_h
|
122 |
+
theta[0, 1] = param[0, 1] * input_w / target_h
|
123 |
+
theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
|
124 |
+
theta[1, 0] = param[1, 0] * input_h / target_w
|
125 |
+
theta[1, 1] = param[1, 1] * input_w / target_w
|
126 |
+
theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
|
127 |
+
return theta
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
|
132 |
+
parser = argparse.ArgumentParser()
|
133 |
+
parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
|
134 |
+
parser.add_argument(
|
135 |
+
"--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
|
136 |
+
)
|
137 |
+
opts = parser.parse_args()
|
138 |
+
|
139 |
+
url = opts.url
|
140 |
+
save_url = opts.save_url
|
141 |
+
|
142 |
+
### If the origin url is None, then we don't need to reid the origin image
|
143 |
+
|
144 |
+
os.makedirs(url, exist_ok=True)
|
145 |
+
os.makedirs(save_url, exist_ok=True)
|
146 |
+
|
147 |
+
face_detector = dlib.get_frontal_face_detector()
|
148 |
+
landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
149 |
+
|
150 |
+
count = 0
|
151 |
+
|
152 |
+
map_id = {}
|
153 |
+
for x in os.listdir(url):
|
154 |
+
img_url = os.path.join(url, x)
|
155 |
+
pil_img = Image.open(img_url).convert("RGB")
|
156 |
+
|
157 |
+
image = np.array(pil_img)
|
158 |
+
|
159 |
+
start = time.time()
|
160 |
+
faces = face_detector(image)
|
161 |
+
done = time.time()
|
162 |
+
|
163 |
+
if len(faces) == 0:
|
164 |
+
print("Warning: There is no face in %s" % (x))
|
165 |
+
continue
|
166 |
+
|
167 |
+
print(len(faces))
|
168 |
+
|
169 |
+
if len(faces) > 0:
|
170 |
+
for face_id in range(len(faces)):
|
171 |
+
current_face = faces[face_id]
|
172 |
+
face_landmarks = landmark_locator(image, current_face)
|
173 |
+
current_fl = search(face_landmarks)
|
174 |
+
|
175 |
+
affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
|
176 |
+
aligned_face = warp(image, affine, output_shape=(512, 512, 3))
|
177 |
+
img_name = x[:-4] + "_" + str(face_id + 1)
|
178 |
+
io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))
|
179 |
+
|
180 |
+
count += 1
|
181 |
+
|
182 |
+
if count % 1000 == 0:
|
183 |
+
print("%d have finished ..." % (count))
|
184 |
+
|
Face_Enhancement/data/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import importlib
|
5 |
+
import torch.utils.data
|
6 |
+
from data.base_dataset import BaseDataset
|
7 |
+
from data.face_dataset import FaceTestDataset
|
8 |
+
|
9 |
+
|
10 |
+
def create_dataloader(opt):
|
11 |
+
|
12 |
+
instance = FaceTestDataset()
|
13 |
+
instance.initialize(opt)
|
14 |
+
print("dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
|
15 |
+
dataloader = torch.utils.data.DataLoader(
|
16 |
+
instance,
|
17 |
+
batch_size=opt.batchSize,
|
18 |
+
shuffle=not opt.serial_batches,
|
19 |
+
num_workers=int(opt.nThreads),
|
20 |
+
drop_last=opt.isTrain,
|
21 |
+
)
|
22 |
+
return dataloader
|
Face_Enhancement/data/base_dataset.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch.utils.data as data
|
5 |
+
from PIL import Image
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
|
10 |
+
|
11 |
+
class BaseDataset(data.Dataset):
|
12 |
+
def __init__(self):
|
13 |
+
super(BaseDataset, self).__init__()
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def modify_commandline_options(parser, is_train):
|
17 |
+
return parser
|
18 |
+
|
19 |
+
def initialize(self, opt):
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
def get_params(opt, size):
|
24 |
+
w, h = size
|
25 |
+
new_h = h
|
26 |
+
new_w = w
|
27 |
+
if opt.preprocess_mode == "resize_and_crop":
|
28 |
+
new_h = new_w = opt.load_size
|
29 |
+
elif opt.preprocess_mode == "scale_width_and_crop":
|
30 |
+
new_w = opt.load_size
|
31 |
+
new_h = opt.load_size * h // w
|
32 |
+
elif opt.preprocess_mode == "scale_shortside_and_crop":
|
33 |
+
ss, ls = min(w, h), max(w, h) # shortside and longside
|
34 |
+
width_is_shorter = w == ss
|
35 |
+
ls = int(opt.load_size * ls / ss)
|
36 |
+
new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
|
37 |
+
|
38 |
+
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
|
39 |
+
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
|
40 |
+
|
41 |
+
flip = random.random() > 0.5
|
42 |
+
return {"crop_pos": (x, y), "flip": flip}
|
43 |
+
|
44 |
+
|
45 |
+
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
|
46 |
+
transform_list = []
|
47 |
+
if "resize" in opt.preprocess_mode:
|
48 |
+
osize = [opt.load_size, opt.load_size]
|
49 |
+
transform_list.append(transforms.Resize(osize, interpolation=method))
|
50 |
+
elif "scale_width" in opt.preprocess_mode:
|
51 |
+
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
|
52 |
+
elif "scale_shortside" in opt.preprocess_mode:
|
53 |
+
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
|
54 |
+
|
55 |
+
if "crop" in opt.preprocess_mode:
|
56 |
+
transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))
|
57 |
+
|
58 |
+
if opt.preprocess_mode == "none":
|
59 |
+
base = 32
|
60 |
+
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
|
61 |
+
|
62 |
+
if opt.preprocess_mode == "fixed":
|
63 |
+
w = opt.crop_size
|
64 |
+
h = round(opt.crop_size / opt.aspect_ratio)
|
65 |
+
transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
|
66 |
+
|
67 |
+
if opt.isTrain and not opt.no_flip:
|
68 |
+
transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))
|
69 |
+
|
70 |
+
if toTensor:
|
71 |
+
transform_list += [transforms.ToTensor()]
|
72 |
+
|
73 |
+
if normalize:
|
74 |
+
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
75 |
+
return transforms.Compose(transform_list)
|
76 |
+
|
77 |
+
|
78 |
+
def normalize():
|
79 |
+
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
80 |
+
|
81 |
+
|
82 |
+
def __resize(img, w, h, method=Image.BICUBIC):
|
83 |
+
return img.resize((w, h), method)
|
84 |
+
|
85 |
+
|
86 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
|
87 |
+
ow, oh = img.size
|
88 |
+
h = int(round(oh / base) * base)
|
89 |
+
w = int(round(ow / base) * base)
|
90 |
+
if (h == oh) and (w == ow):
|
91 |
+
return img
|
92 |
+
return img.resize((w, h), method)
|
93 |
+
|
94 |
+
|
95 |
+
def __scale_width(img, target_width, method=Image.BICUBIC):
|
96 |
+
ow, oh = img.size
|
97 |
+
if ow == target_width:
|
98 |
+
return img
|
99 |
+
w = target_width
|
100 |
+
h = int(target_width * oh / ow)
|
101 |
+
return img.resize((w, h), method)
|
102 |
+
|
103 |
+
|
104 |
+
def __scale_shortside(img, target_width, method=Image.BICUBIC):
|
105 |
+
ow, oh = img.size
|
106 |
+
ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
|
107 |
+
width_is_shorter = ow == ss
|
108 |
+
if ss == target_width:
|
109 |
+
return img
|
110 |
+
ls = int(target_width * ls / ss)
|
111 |
+
nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
|
112 |
+
return img.resize((nw, nh), method)
|
113 |
+
|
114 |
+
|
115 |
+
def __crop(img, pos, size):
|
116 |
+
ow, oh = img.size
|
117 |
+
x1, y1 = pos
|
118 |
+
tw = th = size
|
119 |
+
return img.crop((x1, y1, x1 + tw, y1 + th))
|
120 |
+
|
121 |
+
|
122 |
+
def __flip(img, flip):
|
123 |
+
if flip:
|
124 |
+
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
125 |
+
return img
|
Face_Enhancement/data/custom_dataset.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from data.pix2pix_dataset import Pix2pixDataset
|
5 |
+
from data.image_folder import make_dataset
|
6 |
+
|
7 |
+
|
8 |
+
class CustomDataset(Pix2pixDataset):
|
9 |
+
""" Dataset that loads images from directories
|
10 |
+
Use option --label_dir, --image_dir, --instance_dir to specify the directories.
|
11 |
+
The images in the directories are sorted in alphabetical order and paired in order.
|
12 |
+
"""
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def modify_commandline_options(parser, is_train):
|
16 |
+
parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
|
17 |
+
parser.set_defaults(preprocess_mode="resize_and_crop")
|
18 |
+
load_size = 286 if is_train else 256
|
19 |
+
parser.set_defaults(load_size=load_size)
|
20 |
+
parser.set_defaults(crop_size=256)
|
21 |
+
parser.set_defaults(display_winsize=256)
|
22 |
+
parser.set_defaults(label_nc=13)
|
23 |
+
parser.set_defaults(contain_dontcare_label=False)
|
24 |
+
|
25 |
+
parser.add_argument(
|
26 |
+
"--label_dir", type=str, required=True, help="path to the directory that contains label images"
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--image_dir", type=str, required=True, help="path to the directory that contains photo images"
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--instance_dir",
|
33 |
+
type=str,
|
34 |
+
default="",
|
35 |
+
help="path to the directory that contains instance maps. Leave black if not exists",
|
36 |
+
)
|
37 |
+
return parser
|
38 |
+
|
39 |
+
def get_paths(self, opt):
|
40 |
+
label_dir = opt.label_dir
|
41 |
+
label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
|
42 |
+
|
43 |
+
image_dir = opt.image_dir
|
44 |
+
image_paths = make_dataset(image_dir, recursive=False, read_cache=True)
|
45 |
+
|
46 |
+
if len(opt.instance_dir) > 0:
|
47 |
+
instance_dir = opt.instance_dir
|
48 |
+
instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
|
49 |
+
else:
|
50 |
+
instance_paths = []
|
51 |
+
|
52 |
+
assert len(label_paths) == len(
|
53 |
+
image_paths
|
54 |
+
), "The #images in %s and %s do not match. Is there something wrong?"
|
55 |
+
|
56 |
+
return label_paths, image_paths, instance_paths
|
Face_Enhancement/data/face_dataset.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from data.base_dataset import BaseDataset, get_params, get_transform
|
5 |
+
from PIL import Image
|
6 |
+
import util.util as util
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class FaceTestDataset(BaseDataset):
|
12 |
+
@staticmethod
|
13 |
+
def modify_commandline_options(parser, is_train):
|
14 |
+
parser.add_argument(
|
15 |
+
"--no_pairing_check",
|
16 |
+
action="store_true",
|
17 |
+
help="If specified, skip sanity check of correct label-image file pairing",
|
18 |
+
)
|
19 |
+
# parser.set_defaults(contain_dontcare_label=False)
|
20 |
+
# parser.set_defaults(no_instance=True)
|
21 |
+
return parser
|
22 |
+
|
23 |
+
def initialize(self, opt):
|
24 |
+
self.opt = opt
|
25 |
+
|
26 |
+
image_path = os.path.join(opt.dataroot, opt.old_face_folder)
|
27 |
+
label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)
|
28 |
+
|
29 |
+
image_list = os.listdir(image_path)
|
30 |
+
image_list = sorted(image_list)
|
31 |
+
# image_list=image_list[:opt.max_dataset_size]
|
32 |
+
|
33 |
+
self.label_paths = label_path ## Just the root dir
|
34 |
+
self.image_paths = image_list ## All the image name
|
35 |
+
|
36 |
+
self.parts = [
|
37 |
+
"skin",
|
38 |
+
"hair",
|
39 |
+
"l_brow",
|
40 |
+
"r_brow",
|
41 |
+
"l_eye",
|
42 |
+
"r_eye",
|
43 |
+
"eye_g",
|
44 |
+
"l_ear",
|
45 |
+
"r_ear",
|
46 |
+
"ear_r",
|
47 |
+
"nose",
|
48 |
+
"mouth",
|
49 |
+
"u_lip",
|
50 |
+
"l_lip",
|
51 |
+
"neck",
|
52 |
+
"neck_l",
|
53 |
+
"cloth",
|
54 |
+
"hat",
|
55 |
+
]
|
56 |
+
|
57 |
+
size = len(self.image_paths)
|
58 |
+
self.dataset_size = size
|
59 |
+
|
60 |
+
def __getitem__(self, index):
|
61 |
+
|
62 |
+
params = get_params(self.opt, (-1, -1))
|
63 |
+
image_name = self.image_paths[index]
|
64 |
+
image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)
|
65 |
+
image = Image.open(image_path)
|
66 |
+
image = image.convert("RGB")
|
67 |
+
|
68 |
+
transform_image = get_transform(self.opt, params)
|
69 |
+
image_tensor = transform_image(image)
|
70 |
+
|
71 |
+
img_name = image_name[:-4]
|
72 |
+
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
|
73 |
+
full_label = []
|
74 |
+
|
75 |
+
cnt = 0
|
76 |
+
|
77 |
+
for each_part in self.parts:
|
78 |
+
part_name = img_name + "_" + each_part + ".png"
|
79 |
+
part_url = os.path.join(self.label_paths, part_name)
|
80 |
+
|
81 |
+
if os.path.exists(part_url):
|
82 |
+
label = Image.open(part_url).convert("RGB")
|
83 |
+
label_tensor = transform_label(label) ## 3 channels and pixel [0,1]
|
84 |
+
full_label.append(label_tensor[0])
|
85 |
+
else:
|
86 |
+
current_part = torch.zeros((self.opt.load_size, self.opt.load_size))
|
87 |
+
full_label.append(current_part)
|
88 |
+
cnt += 1
|
89 |
+
|
90 |
+
full_label_tensor = torch.stack(full_label, 0)
|
91 |
+
|
92 |
+
input_dict = {
|
93 |
+
"label": full_label_tensor,
|
94 |
+
"image": image_tensor,
|
95 |
+
"path": image_path,
|
96 |
+
}
|
97 |
+
|
98 |
+
return input_dict
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return self.dataset_size
|
102 |
+
|
Face_Enhancement/data/image_folder.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch.utils.data as data
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
|
8 |
+
IMG_EXTENSIONS = [
|
9 |
+
".jpg",
|
10 |
+
".JPG",
|
11 |
+
".jpeg",
|
12 |
+
".JPEG",
|
13 |
+
".png",
|
14 |
+
".PNG",
|
15 |
+
".ppm",
|
16 |
+
".PPM",
|
17 |
+
".bmp",
|
18 |
+
".BMP",
|
19 |
+
".tiff",
|
20 |
+
".webp",
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
def is_image_file(filename):
|
25 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
26 |
+
|
27 |
+
|
28 |
+
def make_dataset_rec(dir, images):
|
29 |
+
assert os.path.isdir(dir), "%s is not a valid directory" % dir
|
30 |
+
|
31 |
+
for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
|
32 |
+
for fname in fnames:
|
33 |
+
if is_image_file(fname):
|
34 |
+
path = os.path.join(root, fname)
|
35 |
+
images.append(path)
|
36 |
+
|
37 |
+
|
38 |
+
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
|
39 |
+
images = []
|
40 |
+
|
41 |
+
if read_cache:
|
42 |
+
possible_filelist = os.path.join(dir, "files.list")
|
43 |
+
if os.path.isfile(possible_filelist):
|
44 |
+
with open(possible_filelist, "r") as f:
|
45 |
+
images = f.read().splitlines()
|
46 |
+
return images
|
47 |
+
|
48 |
+
if recursive:
|
49 |
+
make_dataset_rec(dir, images)
|
50 |
+
else:
|
51 |
+
assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir
|
52 |
+
|
53 |
+
for root, dnames, fnames in sorted(os.walk(dir)):
|
54 |
+
for fname in fnames:
|
55 |
+
if is_image_file(fname):
|
56 |
+
path = os.path.join(root, fname)
|
57 |
+
images.append(path)
|
58 |
+
|
59 |
+
if write_cache:
|
60 |
+
filelist_cache = os.path.join(dir, "files.list")
|
61 |
+
with open(filelist_cache, "w") as f:
|
62 |
+
for path in images:
|
63 |
+
f.write("%s\n" % path)
|
64 |
+
print("wrote filelist cache at %s" % filelist_cache)
|
65 |
+
|
66 |
+
return images
|
67 |
+
|
68 |
+
|
69 |
+
def default_loader(path):
|
70 |
+
return Image.open(path).convert("RGB")
|
71 |
+
|
72 |
+
|
73 |
+
class ImageFolder(data.Dataset):
|
74 |
+
def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
|
75 |
+
imgs = make_dataset(root)
|
76 |
+
if len(imgs) == 0:
|
77 |
+
raise (
|
78 |
+
RuntimeError(
|
79 |
+
"Found 0 images in: " + root + "\n"
|
80 |
+
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
|
81 |
+
)
|
82 |
+
)
|
83 |
+
|
84 |
+
self.root = root
|
85 |
+
self.imgs = imgs
|
86 |
+
self.transform = transform
|
87 |
+
self.return_paths = return_paths
|
88 |
+
self.loader = loader
|
89 |
+
|
90 |
+
def __getitem__(self, index):
|
91 |
+
path = self.imgs[index]
|
92 |
+
img = self.loader(path)
|
93 |
+
if self.transform is not None:
|
94 |
+
img = self.transform(img)
|
95 |
+
if self.return_paths:
|
96 |
+
return img, path
|
97 |
+
else:
|
98 |
+
return img
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return len(self.imgs)
|
Face_Enhancement/data/pix2pix_dataset.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from data.base_dataset import BaseDataset, get_params, get_transform
|
5 |
+
from PIL import Image
|
6 |
+
import util.util as util
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
class Pix2pixDataset(BaseDataset):
|
11 |
+
@staticmethod
|
12 |
+
def modify_commandline_options(parser, is_train):
|
13 |
+
parser.add_argument(
|
14 |
+
"--no_pairing_check",
|
15 |
+
action="store_true",
|
16 |
+
help="If specified, skip sanity check of correct label-image file pairing",
|
17 |
+
)
|
18 |
+
return parser
|
19 |
+
|
20 |
+
def initialize(self, opt):
|
21 |
+
self.opt = opt
|
22 |
+
|
23 |
+
label_paths, image_paths, instance_paths = self.get_paths(opt)
|
24 |
+
|
25 |
+
util.natural_sort(label_paths)
|
26 |
+
util.natural_sort(image_paths)
|
27 |
+
if not opt.no_instance:
|
28 |
+
util.natural_sort(instance_paths)
|
29 |
+
|
30 |
+
label_paths = label_paths[: opt.max_dataset_size]
|
31 |
+
image_paths = image_paths[: opt.max_dataset_size]
|
32 |
+
instance_paths = instance_paths[: opt.max_dataset_size]
|
33 |
+
|
34 |
+
if not opt.no_pairing_check:
|
35 |
+
for path1, path2 in zip(label_paths, image_paths):
|
36 |
+
assert self.paths_match(path1, path2), (
|
37 |
+
"The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
|
38 |
+
% (path1, path2)
|
39 |
+
)
|
40 |
+
|
41 |
+
self.label_paths = label_paths
|
42 |
+
self.image_paths = image_paths
|
43 |
+
self.instance_paths = instance_paths
|
44 |
+
|
45 |
+
size = len(self.label_paths)
|
46 |
+
self.dataset_size = size
|
47 |
+
|
48 |
+
def get_paths(self, opt):
|
49 |
+
label_paths = []
|
50 |
+
image_paths = []
|
51 |
+
instance_paths = []
|
52 |
+
assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
|
53 |
+
return label_paths, image_paths, instance_paths
|
54 |
+
|
55 |
+
def paths_match(self, path1, path2):
|
56 |
+
filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
|
57 |
+
filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
|
58 |
+
return filename1_without_ext == filename2_without_ext
|
59 |
+
|
60 |
+
def __getitem__(self, index):
|
61 |
+
# Label Image
|
62 |
+
label_path = self.label_paths[index]
|
63 |
+
label = Image.open(label_path)
|
64 |
+
params = get_params(self.opt, label.size)
|
65 |
+
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
|
66 |
+
label_tensor = transform_label(label) * 255.0
|
67 |
+
label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc
|
68 |
+
|
69 |
+
# input image (real images)
|
70 |
+
image_path = self.image_paths[index]
|
71 |
+
assert self.paths_match(
|
72 |
+
label_path, image_path
|
73 |
+
), "The label_path %s and image_path %s don't match." % (label_path, image_path)
|
74 |
+
image = Image.open(image_path)
|
75 |
+
image = image.convert("RGB")
|
76 |
+
|
77 |
+
transform_image = get_transform(self.opt, params)
|
78 |
+
image_tensor = transform_image(image)
|
79 |
+
|
80 |
+
# if using instance maps
|
81 |
+
if self.opt.no_instance:
|
82 |
+
instance_tensor = 0
|
83 |
+
else:
|
84 |
+
instance_path = self.instance_paths[index]
|
85 |
+
instance = Image.open(instance_path)
|
86 |
+
if instance.mode == "L":
|
87 |
+
instance_tensor = transform_label(instance) * 255
|
88 |
+
instance_tensor = instance_tensor.long()
|
89 |
+
else:
|
90 |
+
instance_tensor = transform_label(instance)
|
91 |
+
|
92 |
+
input_dict = {
|
93 |
+
"label": label_tensor,
|
94 |
+
"instance": instance_tensor,
|
95 |
+
"image": image_tensor,
|
96 |
+
"path": image_path,
|
97 |
+
}
|
98 |
+
|
99 |
+
# Give subclasses a chance to modify the final output
|
100 |
+
self.postprocess(input_dict)
|
101 |
+
|
102 |
+
return input_dict
|
103 |
+
|
104 |
+
def postprocess(self, input_dict):
|
105 |
+
return input_dict
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return self.dataset_size
|
Face_Enhancement/models/__init__.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import importlib
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def find_model_using_name(model_name):
|
9 |
+
# Given the option --model [modelname],
|
10 |
+
# the file "models/modelname_model.py"
|
11 |
+
# will be imported.
|
12 |
+
model_filename = "models." + model_name + "_model"
|
13 |
+
modellib = importlib.import_module(model_filename)
|
14 |
+
|
15 |
+
# In the file, the class called ModelNameModel() will
|
16 |
+
# be instantiated. It has to be a subclass of torch.nn.Module,
|
17 |
+
# and it is case-insensitive.
|
18 |
+
model = None
|
19 |
+
target_model_name = model_name.replace("_", "") + "model"
|
20 |
+
for name, cls in modellib.__dict__.items():
|
21 |
+
if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module):
|
22 |
+
model = cls
|
23 |
+
|
24 |
+
if model is None:
|
25 |
+
print(
|
26 |
+
"In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
|
27 |
+
% (model_filename, target_model_name)
|
28 |
+
)
|
29 |
+
exit(0)
|
30 |
+
|
31 |
+
return model
|
32 |
+
|
33 |
+
|
34 |
+
def get_option_setter(model_name):
|
35 |
+
model_class = find_model_using_name(model_name)
|
36 |
+
return model_class.modify_commandline_options
|
37 |
+
|
38 |
+
|
39 |
+
def create_model(opt):
|
40 |
+
model = find_model_using_name(opt.model)
|
41 |
+
instance = model(opt)
|
42 |
+
print("model [%s] was created" % (type(instance).__name__))
|
43 |
+
|
44 |
+
return instance
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2018 Jiayuan MAO
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Synchronized-BatchNorm-PyTorch
|
2 |
+
|
3 |
+
**IMPORTANT: Please read the "Implementation details and highlights" section before use.**
|
4 |
+
|
5 |
+
Synchronized Batch Normalization implementation in PyTorch.
|
6 |
+
|
7 |
+
This module differs from the built-in PyTorch BatchNorm as the mean and
|
8 |
+
standard-deviation are reduced across all devices during training.
|
9 |
+
|
10 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
11 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
12 |
+
the statistics only on that device, which accelerated the computation and
|
13 |
+
is also easy to implement, but the statistics might be inaccurate.
|
14 |
+
Instead, in this synchronized version, the statistics will be computed
|
15 |
+
over all training samples distributed on multiple devices.
|
16 |
+
|
17 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
18 |
+
as the built-in PyTorch implementation.
|
19 |
+
|
20 |
+
This module is currently only a prototype version for research usages. As mentioned below,
|
21 |
+
it has its limitations and may even suffer from some design problems. If you have any
|
22 |
+
questions or suggestions, please feel free to
|
23 |
+
[open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or
|
24 |
+
[submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues).
|
25 |
+
|
26 |
+
## Why Synchronized BatchNorm?
|
27 |
+
|
28 |
+
Although the typical implementation of BatchNorm working on multiple devices (GPUs)
|
29 |
+
is fast (with no communication overhead), it inevitably reduces the size of batch size,
|
30 |
+
which potentially degenerates the performance. This is not a significant issue in some
|
31 |
+
standard vision tasks such as ImageNet classification (as the batch size per device
|
32 |
+
is usually large enough to obtain good statistics). However, it will hurt the performance
|
33 |
+
in some tasks that the batch size is usually very small (e.g., 1 per GPU).
|
34 |
+
|
35 |
+
For example, the importance of synchronized batch normalization in object detection has been recently proved with a
|
36 |
+
an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240).
|
37 |
+
|
38 |
+
## Usage
|
39 |
+
|
40 |
+
To use the Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight
|
41 |
+
difference with typical usage of the `nn.DataParallel`.
|
42 |
+
|
43 |
+
Use it with a provided, customized data parallel wrapper:
|
44 |
+
|
45 |
+
```python
|
46 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback
|
47 |
+
|
48 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
49 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
50 |
+
```
|
51 |
+
|
52 |
+
Or, if you are using a customized data parallel module, you can use this library as a monkey patching.
|
53 |
+
|
54 |
+
```python
|
55 |
+
from torch.nn import DataParallel # or your customized DataParallel module
|
56 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback
|
57 |
+
|
58 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
59 |
+
sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
60 |
+
patch_replication_callback(sync_bn) # monkey-patching
|
61 |
+
```
|
62 |
+
|
63 |
+
You can use `convert_model` to convert your model to use Synchronized BatchNorm easily.
|
64 |
+
|
65 |
+
```python
|
66 |
+
import torch.nn as nn
|
67 |
+
from torchvision import models
|
68 |
+
from sync_batchnorm import convert_model
|
69 |
+
# m is a standard pytorch model
|
70 |
+
m = models.resnet18(True)
|
71 |
+
m = nn.DataParallel(m)
|
72 |
+
# after convert, m is using SyncBN
|
73 |
+
m = convert_model(m)
|
74 |
+
```
|
75 |
+
|
76 |
+
See also `tests/test_sync_batchnorm.py` for numeric result comparison.
|
77 |
+
|
78 |
+
## Implementation details and highlights
|
79 |
+
|
80 |
+
If you are interested in how batch statistics are reduced and broadcasted among multiple devices, please take a look
|
81 |
+
at the code with detailed comments. Here we only emphasize some highlights of the implementation:
|
82 |
+
|
83 |
+
- This implementation is in pure-python. No C++ extra extension libs.
|
84 |
+
- Easy to use as demonstrated above.
|
85 |
+
- It uses unbiased variance to update the moving average, and use `sqrt(max(var, eps))` instead of `sqrt(var + eps)`.
|
86 |
+
- The implementation requires that each module on different devices should invoke the `batchnorm` for exactly SAME
|
87 |
+
amount of times in each forward pass. For example, you can not only call `batchnorm` on GPU0 but not on GPU1. The `#i
|
88 |
+
(i = 1, 2, 3, ...)` calls of the `batchnorm` on each device will be viewed as a whole and the statistics will be reduced.
|
89 |
+
This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although sounds complicated, this
|
90 |
+
will usually not be the issue for most of the models.
|
91 |
+
|
92 |
+
## Known issues
|
93 |
+
|
94 |
+
#### Runtime error on backward pass.
|
95 |
+
|
96 |
+
Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger an `RuntimeError` with messages like:
|
97 |
+
|
98 |
+
```
|
99 |
+
Assertion `pos >= 0 && pos < buffer.size()` failed.
|
100 |
+
```
|
101 |
+
|
102 |
+
This has already been solved in the newest PyTorch repo, which, unfortunately, has not been pushed to the official and anaconda binary release. Thus, you are required to build the PyTorch package from the source according to the
|
103 |
+
instructions [here](https://github.com/pytorch/pytorch#from-source).
|
104 |
+
|
105 |
+
#### Numeric error.
|
106 |
+
|
107 |
+
Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less
|
108 |
+
numerically stable compared to the original PyTorch implementation. Detailed analysis can be found in
|
109 |
+
`tests/test_sync_batchnorm.py`.
|
110 |
+
|
111 |
+
## Authors and License:
|
112 |
+
|
113 |
+
Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz).
|
114 |
+
|
115 |
+
**Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant).
|
116 |
+
|
117 |
+
Distributed under **MIT License** (See LICENSE)
|
118 |
+
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import set_sbn_eps_mode
|
12 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
13 |
+
from .batchnorm import patch_sync_batchnorm, convert_model
|
14 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
import contextlib
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
18 |
+
|
19 |
+
try:
|
20 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
21 |
+
except ImportError:
|
22 |
+
ReduceAddCoalesced = Broadcast = None
|
23 |
+
|
24 |
+
try:
|
25 |
+
from jactorch.parallel.comm import SyncMaster
|
26 |
+
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
|
27 |
+
except ImportError:
|
28 |
+
from .comm import SyncMaster
|
29 |
+
from .replicate import DataParallelWithCallback
|
30 |
+
|
31 |
+
__all__ = [
|
32 |
+
'set_sbn_eps_mode',
|
33 |
+
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
|
34 |
+
'patch_sync_batchnorm', 'convert_model'
|
35 |
+
]
|
36 |
+
|
37 |
+
|
38 |
+
SBN_EPS_MODE = 'clamp'
|
39 |
+
|
40 |
+
|
41 |
+
def set_sbn_eps_mode(mode):
|
42 |
+
global SBN_EPS_MODE
|
43 |
+
assert mode in ('clamp', 'plus')
|
44 |
+
SBN_EPS_MODE = mode
|
45 |
+
|
46 |
+
|
47 |
+
def _sum_ft(tensor):
|
48 |
+
"""sum over the first and last dimention"""
|
49 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
50 |
+
|
51 |
+
|
52 |
+
def _unsqueeze_ft(tensor):
|
53 |
+
"""add new dimensions at the front and the tail"""
|
54 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
55 |
+
|
56 |
+
|
57 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
58 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
59 |
+
|
60 |
+
|
61 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
62 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
|
63 |
+
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
|
64 |
+
|
65 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
|
66 |
+
track_running_stats=track_running_stats)
|
67 |
+
|
68 |
+
if not self.track_running_stats:
|
69 |
+
import warnings
|
70 |
+
warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
|
71 |
+
|
72 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
73 |
+
|
74 |
+
self._is_parallel = False
|
75 |
+
self._parallel_id = None
|
76 |
+
self._slave_pipe = None
|
77 |
+
|
78 |
+
def forward(self, input):
|
79 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
80 |
+
if not (self._is_parallel and self.training):
|
81 |
+
return F.batch_norm(
|
82 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
83 |
+
self.training, self.momentum, self.eps)
|
84 |
+
|
85 |
+
# Resize the input to (B, C, -1).
|
86 |
+
input_shape = input.size()
|
87 |
+
assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
|
88 |
+
input = input.view(input.size(0), self.num_features, -1)
|
89 |
+
|
90 |
+
# Compute the sum and square-sum.
|
91 |
+
sum_size = input.size(0) * input.size(2)
|
92 |
+
input_sum = _sum_ft(input)
|
93 |
+
input_ssum = _sum_ft(input ** 2)
|
94 |
+
|
95 |
+
# Reduce-and-broadcast the statistics.
|
96 |
+
if self._parallel_id == 0:
|
97 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
98 |
+
else:
|
99 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
100 |
+
|
101 |
+
# Compute the output.
|
102 |
+
if self.affine:
|
103 |
+
# MJY:: Fuse the multiplication for speed.
|
104 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
105 |
+
else:
|
106 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
107 |
+
|
108 |
+
# Reshape it.
|
109 |
+
return output.view(input_shape)
|
110 |
+
|
111 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
112 |
+
self._is_parallel = True
|
113 |
+
self._parallel_id = copy_id
|
114 |
+
|
115 |
+
# parallel_id == 0 means master device.
|
116 |
+
if self._parallel_id == 0:
|
117 |
+
ctx.sync_master = self._sync_master
|
118 |
+
else:
|
119 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
120 |
+
|
121 |
+
def _data_parallel_master(self, intermediates):
|
122 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
123 |
+
|
124 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
125 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
126 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
127 |
+
|
128 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
129 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
130 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
131 |
+
|
132 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
133 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
134 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
135 |
+
|
136 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
137 |
+
|
138 |
+
outputs = []
|
139 |
+
for i, rec in enumerate(intermediates):
|
140 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
141 |
+
|
142 |
+
return outputs
|
143 |
+
|
144 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
145 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
146 |
+
also maintains the moving average on the master device."""
|
147 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
148 |
+
mean = sum_ / size
|
149 |
+
sumvar = ssum - sum_ * mean
|
150 |
+
unbias_var = sumvar / (size - 1)
|
151 |
+
bias_var = sumvar / size
|
152 |
+
|
153 |
+
if hasattr(torch, 'no_grad'):
|
154 |
+
with torch.no_grad():
|
155 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
156 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
157 |
+
else:
|
158 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
159 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
160 |
+
|
161 |
+
if SBN_EPS_MODE == 'clamp':
|
162 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
163 |
+
elif SBN_EPS_MODE == 'plus':
|
164 |
+
return mean, (bias_var + self.eps) ** -0.5
|
165 |
+
else:
|
166 |
+
raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
|
167 |
+
|
168 |
+
|
169 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
170 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
171 |
+
mini-batch.
|
172 |
+
|
173 |
+
.. math::
|
174 |
+
|
175 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
176 |
+
|
177 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
178 |
+
standard-deviation are reduced across all devices during training.
|
179 |
+
|
180 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
181 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
182 |
+
the statistics only on that device, which accelerated the computation and
|
183 |
+
is also easy to implement, but the statistics might be inaccurate.
|
184 |
+
Instead, in this synchronized version, the statistics will be computed
|
185 |
+
over all training samples distributed on multiple devices.
|
186 |
+
|
187 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
188 |
+
as the built-in PyTorch implementation.
|
189 |
+
|
190 |
+
The mean and standard-deviation are calculated per-dimension over
|
191 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
192 |
+
of size C (where C is the input size).
|
193 |
+
|
194 |
+
During training, this layer keeps a running estimate of its computed mean
|
195 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
196 |
+
|
197 |
+
During evaluation, this running mean/variance is used for normalization.
|
198 |
+
|
199 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
200 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
201 |
+
|
202 |
+
Args:
|
203 |
+
num_features: num_features from an expected input of size
|
204 |
+
`batch_size x num_features [x width]`
|
205 |
+
eps: a value added to the denominator for numerical stability.
|
206 |
+
Default: 1e-5
|
207 |
+
momentum: the value used for the running_mean and running_var
|
208 |
+
computation. Default: 0.1
|
209 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
210 |
+
affine parameters. Default: ``True``
|
211 |
+
|
212 |
+
Shape::
|
213 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
214 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
215 |
+
|
216 |
+
Examples:
|
217 |
+
>>> # With Learnable Parameters
|
218 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
219 |
+
>>> # Without Learnable Parameters
|
220 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
221 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
222 |
+
>>> output = m(input)
|
223 |
+
"""
|
224 |
+
|
225 |
+
def _check_input_dim(self, input):
|
226 |
+
if input.dim() != 2 and input.dim() != 3:
|
227 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
228 |
+
.format(input.dim()))
|
229 |
+
|
230 |
+
|
231 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
232 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
233 |
+
of 3d inputs
|
234 |
+
|
235 |
+
.. math::
|
236 |
+
|
237 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
238 |
+
|
239 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
240 |
+
standard-deviation are reduced across all devices during training.
|
241 |
+
|
242 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
243 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
244 |
+
the statistics only on that device, which accelerated the computation and
|
245 |
+
is also easy to implement, but the statistics might be inaccurate.
|
246 |
+
Instead, in this synchronized version, the statistics will be computed
|
247 |
+
over all training samples distributed on multiple devices.
|
248 |
+
|
249 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
250 |
+
as the built-in PyTorch implementation.
|
251 |
+
|
252 |
+
The mean and standard-deviation are calculated per-dimension over
|
253 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
254 |
+
of size C (where C is the input size).
|
255 |
+
|
256 |
+
During training, this layer keeps a running estimate of its computed mean
|
257 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
258 |
+
|
259 |
+
During evaluation, this running mean/variance is used for normalization.
|
260 |
+
|
261 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
262 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
263 |
+
|
264 |
+
Args:
|
265 |
+
num_features: num_features from an expected input of
|
266 |
+
size batch_size x num_features x height x width
|
267 |
+
eps: a value added to the denominator for numerical stability.
|
268 |
+
Default: 1e-5
|
269 |
+
momentum: the value used for the running_mean and running_var
|
270 |
+
computation. Default: 0.1
|
271 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
272 |
+
affine parameters. Default: ``True``
|
273 |
+
|
274 |
+
Shape::
|
275 |
+
- Input: :math:`(N, C, H, W)`
|
276 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
277 |
+
|
278 |
+
Examples:
|
279 |
+
>>> # With Learnable Parameters
|
280 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
281 |
+
>>> # Without Learnable Parameters
|
282 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
283 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
284 |
+
>>> output = m(input)
|
285 |
+
"""
|
286 |
+
|
287 |
+
def _check_input_dim(self, input):
|
288 |
+
if input.dim() != 4:
|
289 |
+
raise ValueError('expected 4D input (got {}D input)'
|
290 |
+
.format(input.dim()))
|
291 |
+
|
292 |
+
|
293 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
294 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
295 |
+
of 4d inputs
|
296 |
+
|
297 |
+
.. math::
|
298 |
+
|
299 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
300 |
+
|
301 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
302 |
+
standard-deviation are reduced across all devices during training.
|
303 |
+
|
304 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
305 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
306 |
+
the statistics only on that device, which accelerated the computation and
|
307 |
+
is also easy to implement, but the statistics might be inaccurate.
|
308 |
+
Instead, in this synchronized version, the statistics will be computed
|
309 |
+
over all training samples distributed on multiple devices.
|
310 |
+
|
311 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
312 |
+
as the built-in PyTorch implementation.
|
313 |
+
|
314 |
+
The mean and standard-deviation are calculated per-dimension over
|
315 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
316 |
+
of size C (where C is the input size).
|
317 |
+
|
318 |
+
During training, this layer keeps a running estimate of its computed mean
|
319 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
320 |
+
|
321 |
+
During evaluation, this running mean/variance is used for normalization.
|
322 |
+
|
323 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
324 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
325 |
+
or Spatio-temporal BatchNorm
|
326 |
+
|
327 |
+
Args:
|
328 |
+
num_features: num_features from an expected input of
|
329 |
+
size batch_size x num_features x depth x height x width
|
330 |
+
eps: a value added to the denominator for numerical stability.
|
331 |
+
Default: 1e-5
|
332 |
+
momentum: the value used for the running_mean and running_var
|
333 |
+
computation. Default: 0.1
|
334 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
335 |
+
affine parameters. Default: ``True``
|
336 |
+
|
337 |
+
Shape::
|
338 |
+
- Input: :math:`(N, C, D, H, W)`
|
339 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
340 |
+
|
341 |
+
Examples:
|
342 |
+
>>> # With Learnable Parameters
|
343 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
344 |
+
>>> # Without Learnable Parameters
|
345 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
346 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
347 |
+
>>> output = m(input)
|
348 |
+
"""
|
349 |
+
|
350 |
+
def _check_input_dim(self, input):
|
351 |
+
if input.dim() != 5:
|
352 |
+
raise ValueError('expected 5D input (got {}D input)'
|
353 |
+
.format(input.dim()))
|
354 |
+
|
355 |
+
|
356 |
+
@contextlib.contextmanager
|
357 |
+
def patch_sync_batchnorm():
|
358 |
+
import torch.nn as nn
|
359 |
+
|
360 |
+
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
|
361 |
+
|
362 |
+
nn.BatchNorm1d = SynchronizedBatchNorm1d
|
363 |
+
nn.BatchNorm2d = SynchronizedBatchNorm2d
|
364 |
+
nn.BatchNorm3d = SynchronizedBatchNorm3d
|
365 |
+
|
366 |
+
yield
|
367 |
+
|
368 |
+
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
|
369 |
+
|
370 |
+
|
371 |
+
def convert_model(module):
|
372 |
+
"""Traverse the input module and its child recursively
|
373 |
+
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
|
374 |
+
to SynchronizedBatchNorm*N*d
|
375 |
+
|
376 |
+
Args:
|
377 |
+
module: the input module needs to be convert to SyncBN model
|
378 |
+
|
379 |
+
Examples:
|
380 |
+
>>> import torch.nn as nn
|
381 |
+
>>> import torchvision
|
382 |
+
>>> # m is a standard pytorch model
|
383 |
+
>>> m = torchvision.models.resnet18(True)
|
384 |
+
>>> m = nn.DataParallel(m)
|
385 |
+
>>> # after convert, m is using SyncBN
|
386 |
+
>>> m = convert_model(m)
|
387 |
+
"""
|
388 |
+
if isinstance(module, torch.nn.DataParallel):
|
389 |
+
mod = module.module
|
390 |
+
mod = convert_model(mod)
|
391 |
+
mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
|
392 |
+
return mod
|
393 |
+
|
394 |
+
mod = module
|
395 |
+
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
|
396 |
+
torch.nn.modules.batchnorm.BatchNorm2d,
|
397 |
+
torch.nn.modules.batchnorm.BatchNorm3d],
|
398 |
+
[SynchronizedBatchNorm1d,
|
399 |
+
SynchronizedBatchNorm2d,
|
400 |
+
SynchronizedBatchNorm3d]):
|
401 |
+
if isinstance(module, pth_module):
|
402 |
+
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
|
403 |
+
mod.running_mean = module.running_mean
|
404 |
+
mod.running_var = module.running_var
|
405 |
+
if module.affine:
|
406 |
+
mod.weight.data = module.weight.data.clone().detach()
|
407 |
+
mod.bias.data = module.bias.data.clone().detach()
|
408 |
+
|
409 |
+
for name, child in module.named_children():
|
410 |
+
mod.add_module(name, convert_model(child))
|
411 |
+
|
412 |
+
return mod
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : batchnorm_reimpl.py
|
4 |
+
# Author : acgtyrant
|
5 |
+
# Date : 11/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.init as init
|
14 |
+
|
15 |
+
__all__ = ['BatchNorm2dReimpl']
|
16 |
+
|
17 |
+
|
18 |
+
class BatchNorm2dReimpl(nn.Module):
|
19 |
+
"""
|
20 |
+
A re-implementation of batch normalization, used for testing the numerical
|
21 |
+
stability.
|
22 |
+
|
23 |
+
Author: acgtyrant
|
24 |
+
See also:
|
25 |
+
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
|
26 |
+
"""
|
27 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.num_features = num_features
|
31 |
+
self.eps = eps
|
32 |
+
self.momentum = momentum
|
33 |
+
self.weight = nn.Parameter(torch.empty(num_features))
|
34 |
+
self.bias = nn.Parameter(torch.empty(num_features))
|
35 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
36 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
37 |
+
self.reset_parameters()
|
38 |
+
|
39 |
+
def reset_running_stats(self):
|
40 |
+
self.running_mean.zero_()
|
41 |
+
self.running_var.fill_(1)
|
42 |
+
|
43 |
+
def reset_parameters(self):
|
44 |
+
self.reset_running_stats()
|
45 |
+
init.uniform_(self.weight)
|
46 |
+
init.zeros_(self.bias)
|
47 |
+
|
48 |
+
def forward(self, input_):
|
49 |
+
batchsize, channels, height, width = input_.size()
|
50 |
+
numel = batchsize * height * width
|
51 |
+
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
|
52 |
+
sum_ = input_.sum(1)
|
53 |
+
sum_of_square = input_.pow(2).sum(1)
|
54 |
+
mean = sum_ / numel
|
55 |
+
sumvar = sum_of_square - sum_ * mean
|
56 |
+
|
57 |
+
self.running_mean = (
|
58 |
+
(1 - self.momentum) * self.running_mean
|
59 |
+
+ self.momentum * mean.detach()
|
60 |
+
)
|
61 |
+
unbias_var = sumvar / (numel - 1)
|
62 |
+
self.running_var = (
|
63 |
+
(1 - self.momentum) * self.running_var
|
64 |
+
+ self.momentum * unbias_var.detach()
|
65 |
+
)
|
66 |
+
|
67 |
+
bias_var = sumvar / numel
|
68 |
+
inv_std = 1 / (bias_var + self.eps).pow(0.5)
|
69 |
+
output = (
|
70 |
+
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
|
71 |
+
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
|
72 |
+
|
73 |
+
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
|
74 |
+
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
return {'master_callback': self._master_callback}
|
80 |
+
|
81 |
+
def __setstate__(self, state):
|
82 |
+
self.__init__(state['master_callback'])
|
83 |
+
|
84 |
+
def register_slave(self, identifier):
|
85 |
+
"""
|
86 |
+
Register an slave device.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
identifier: an identifier, usually is the device id.
|
90 |
+
|
91 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
92 |
+
|
93 |
+
"""
|
94 |
+
if self._activated:
|
95 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
96 |
+
self._activated = False
|
97 |
+
self._registry.clear()
|
98 |
+
future = FutureResult()
|
99 |
+
self._registry[identifier] = _MasterRegistry(future)
|
100 |
+
return SlavePipe(identifier, self._queue, future)
|
101 |
+
|
102 |
+
def run_master(self, master_msg):
|
103 |
+
"""
|
104 |
+
Main entry for the master device in each forward pass.
|
105 |
+
The messages were first collected from each devices (including the master device), and then
|
106 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
107 |
+
(including the master device).
|
108 |
+
|
109 |
+
Args:
|
110 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
111 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
112 |
+
|
113 |
+
Returns: the message to be sent back to the master device.
|
114 |
+
|
115 |
+
"""
|
116 |
+
self._activated = True
|
117 |
+
|
118 |
+
intermediates = [(0, master_msg)]
|
119 |
+
for i in range(self.nr_slaves):
|
120 |
+
intermediates.append(self._queue.get())
|
121 |
+
|
122 |
+
results = self._master_callback(intermediates)
|
123 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
124 |
+
|
125 |
+
for i, res in results:
|
126 |
+
if i == 0:
|
127 |
+
continue
|
128 |
+
self._registry[i].result.put(res)
|
129 |
+
|
130 |
+
for i in range(self.nr_slaves):
|
131 |
+
assert self._queue.get() is True
|
132 |
+
|
133 |
+
return results[0][1]
|
134 |
+
|
135 |
+
@property
|
136 |
+
def nr_slaves(self):
|
137 |
+
return len(self._registry)
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
|
31 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
32 |
+
|
33 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
34 |
+
(shared among multiple copies of this module on different devices).
|
35 |
+
Through this context, different copies can share some information.
|
36 |
+
|
37 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
38 |
+
of any slave copies.
|
39 |
+
"""
|
40 |
+
master_copy = modules[0]
|
41 |
+
nr_modules = len(list(master_copy.modules()))
|
42 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
43 |
+
|
44 |
+
for i, module in enumerate(modules):
|
45 |
+
for j, m in enumerate(module.modules()):
|
46 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
47 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
48 |
+
|
49 |
+
|
50 |
+
class DataParallelWithCallback(DataParallel):
|
51 |
+
"""
|
52 |
+
Data Parallel with a replication callback.
|
53 |
+
|
54 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
55 |
+
original `replicate` function.
|
56 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
60 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
61 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def replicate(self, module, device_ids):
|
65 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
66 |
+
execute_replication_callbacks(modules)
|
67 |
+
return modules
|
68 |
+
|
69 |
+
|
70 |
+
def patch_replication_callback(data_parallel):
|
71 |
+
"""
|
72 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
73 |
+
Useful when you have customized `DataParallel` implementation.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
77 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
78 |
+
> patch_replication_callback(sync_bn)
|
79 |
+
# this is equivalent to
|
80 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert isinstance(data_parallel, DataParallel)
|
85 |
+
|
86 |
+
old_replicate = data_parallel.replicate
|
87 |
+
|
88 |
+
@functools.wraps(old_replicate)
|
89 |
+
def new_replicate(module, device_ids):
|
90 |
+
modules = old_replicate(module, device_ids)
|
91 |
+
execute_replication_callbacks(modules)
|
92 |
+
return modules
|
93 |
+
|
94 |
+
data_parallel.replicate = new_replicate
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
class TorchTestCase(unittest.TestCase):
|
16 |
+
def assertTensorClose(self, x, y):
|
17 |
+
adiff = float((x - y).abs().max())
|
18 |
+
if (y == 0).all():
|
19 |
+
rdiff = 'NaN'
|
20 |
+
else:
|
21 |
+
rdiff = float((adiff / y).abs().max())
|
22 |
+
|
23 |
+
message = (
|
24 |
+
'Tensor close check failed\n'
|
25 |
+
'adiff={}\n'
|
26 |
+
'rdiff={}\n'
|
27 |
+
).format(adiff, rdiff)
|
28 |
+
self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
|
29 |
+
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : test_numeric_batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
|
9 |
+
import unittest
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
from sync_batchnorm.unittest import TorchTestCase
|
16 |
+
|
17 |
+
|
18 |
+
def handy_var(a, unbias=True):
|
19 |
+
n = a.size(0)
|
20 |
+
asum = a.sum(dim=0)
|
21 |
+
as_sum = (a ** 2).sum(dim=0) # a square sum
|
22 |
+
sumvar = as_sum - asum * asum / n
|
23 |
+
if unbias:
|
24 |
+
return sumvar / (n - 1)
|
25 |
+
else:
|
26 |
+
return sumvar / n
|
27 |
+
|
28 |
+
|
29 |
+
class NumericTestCase(TorchTestCase):
|
30 |
+
def testNumericBatchNorm(self):
|
31 |
+
a = torch.rand(16, 10)
|
32 |
+
bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)
|
33 |
+
bn.train()
|
34 |
+
|
35 |
+
a_var1 = Variable(a, requires_grad=True)
|
36 |
+
b_var1 = bn(a_var1)
|
37 |
+
loss1 = b_var1.sum()
|
38 |
+
loss1.backward()
|
39 |
+
|
40 |
+
a_var2 = Variable(a, requires_grad=True)
|
41 |
+
a_mean2 = a_var2.mean(dim=0, keepdim=True)
|
42 |
+
a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
|
43 |
+
# a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
|
44 |
+
b_var2 = (a_var2 - a_mean2) / a_std2
|
45 |
+
loss2 = b_var2.sum()
|
46 |
+
loss2.backward()
|
47 |
+
|
48 |
+
self.assertTensorClose(bn.running_mean, a.mean(dim=0))
|
49 |
+
self.assertTensorClose(bn.running_var, handy_var(a))
|
50 |
+
self.assertTensorClose(a_var1.data, a_var2.data)
|
51 |
+
self.assertTensorClose(b_var1.data, b_var2.data)
|
52 |
+
self.assertTensorClose(a_var1.grad, a_var2.grad)
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
unittest.main()
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : test_numeric_batchnorm_v2.py
|
4 |
+
# Author : Jiayuan Mao
|
5 |
+
# Email : [email protected]
|
6 |
+
# Date : 11/01/2018
|
7 |
+
#
|
8 |
+
# Distributed under terms of the MIT license.
|
9 |
+
|
10 |
+
"""
|
11 |
+
Test the numerical implementation of batch normalization.
|
12 |
+
|
13 |
+
Author: acgtyrant.
|
14 |
+
See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
|
15 |
+
"""
|
16 |
+
|
17 |
+
import unittest
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.optim as optim
|
22 |
+
|
23 |
+
from sync_batchnorm.unittest import TorchTestCase
|
24 |
+
from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl
|
25 |
+
|
26 |
+
|
27 |
+
class NumericTestCasev2(TorchTestCase):
|
28 |
+
def testNumericBatchNorm(self):
|
29 |
+
CHANNELS = 16
|
30 |
+
batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1)
|
31 |
+
optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01)
|
32 |
+
|
33 |
+
batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1)
|
34 |
+
batchnorm2.weight.data.copy_(batchnorm1.weight.data)
|
35 |
+
batchnorm2.bias.data.copy_(batchnorm1.bias.data)
|
36 |
+
optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01)
|
37 |
+
|
38 |
+
for _ in range(100):
|
39 |
+
input_ = torch.rand(16, CHANNELS, 16, 16)
|
40 |
+
|
41 |
+
input1 = input_.clone().requires_grad_(True)
|
42 |
+
output1 = batchnorm1(input1)
|
43 |
+
output1.sum().backward()
|
44 |
+
optimizer1.step()
|
45 |
+
|
46 |
+
input2 = input_.clone().requires_grad_(True)
|
47 |
+
output2 = batchnorm2(input2)
|
48 |
+
output2.sum().backward()
|
49 |
+
optimizer2.step()
|
50 |
+
|
51 |
+
self.assertTensorClose(input1, input2)
|
52 |
+
self.assertTensorClose(output1, output2)
|
53 |
+
self.assertTensorClose(input1.grad, input2.grad)
|
54 |
+
self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad)
|
55 |
+
self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad)
|
56 |
+
self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean)
|
57 |
+
self.assertTensorClose(batchnorm2.running_mean, batchnorm2.running_mean)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
unittest.main()
|
62 |
+
|
Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : test_sync_batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
|
9 |
+
import unittest
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
from sync_batchnorm import set_sbn_eps_mode
|
16 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
|
17 |
+
from sync_batchnorm.unittest import TorchTestCase
|
18 |
+
|
19 |
+
set_sbn_eps_mode('plus')
|
20 |
+
|
21 |
+
|
22 |
+
def handy_var(a, unbias=True):
|
23 |
+
n = a.size(0)
|
24 |
+
asum = a.sum(dim=0)
|
25 |
+
as_sum = (a ** 2).sum(dim=0) # a square sum
|
26 |
+
sumvar = as_sum - asum * asum / n
|
27 |
+
if unbias:
|
28 |
+
return sumvar / (n - 1)
|
29 |
+
else:
|
30 |
+
return sumvar / n
|
31 |
+
|
32 |
+
|
33 |
+
def _find_bn(module):
|
34 |
+
for m in module.modules():
|
35 |
+
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
|
36 |
+
return m
|
37 |
+
|
38 |
+
|
39 |
+
class SyncTestCase(TorchTestCase):
|
40 |
+
def _syncParameters(self, bn1, bn2):
|
41 |
+
bn1.reset_parameters()
|
42 |
+
bn2.reset_parameters()
|
43 |
+
if bn1.affine and bn2.affine:
|
44 |
+
bn2.weight.data.copy_(bn1.weight.data)
|
45 |
+
bn2.bias.data.copy_(bn1.bias.data)
|
46 |
+
|
47 |
+
def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
|
48 |
+
"""Check the forward and backward for the customized batch normalization."""
|
49 |
+
bn1.train(mode=is_train)
|
50 |
+
bn2.train(mode=is_train)
|
51 |
+
|
52 |
+
if cuda:
|
53 |
+
input = input.cuda()
|
54 |
+
|
55 |
+
self._syncParameters(_find_bn(bn1), _find_bn(bn2))
|
56 |
+
|
57 |
+
input1 = Variable(input, requires_grad=True)
|
58 |
+
output1 = bn1(input1)
|
59 |
+
output1.sum().backward()
|
60 |
+
input2 = Variable(input, requires_grad=True)
|
61 |
+
output2 = bn2(input2)
|
62 |
+
output2.sum().backward()
|
63 |
+
|
64 |
+
self.assertTensorClose(input1.data, input2.data)
|
65 |
+
self.assertTensorClose(output1.data, output2.data)
|
66 |
+
self.assertTensorClose(input1.grad, input2.grad)
|
67 |
+
self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
|
68 |
+
self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
|
69 |
+
|
70 |
+
def testSyncBatchNormNormalTrain(self):
|
71 |
+
bn = nn.BatchNorm1d(10)
|
72 |
+
sync_bn = SynchronizedBatchNorm1d(10)
|
73 |
+
|
74 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
|
75 |
+
|
76 |
+
def testSyncBatchNormNormalEval(self):
|
77 |
+
bn = nn.BatchNorm1d(10)
|
78 |
+
sync_bn = SynchronizedBatchNorm1d(10)
|
79 |
+
|
80 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
|
81 |
+
|
82 |
+
def testSyncBatchNormSyncTrain(self):
|
83 |
+
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
|
84 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
85 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
86 |
+
|
87 |
+
bn.cuda()
|
88 |
+
sync_bn.cuda()
|
89 |
+
|
90 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
|
91 |
+
|
92 |
+
def testSyncBatchNormSyncEval(self):
|
93 |
+
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
|
94 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
95 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
96 |
+
|
97 |
+
bn.cuda()
|
98 |
+
sync_bn.cuda()
|
99 |
+
|
100 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
|
101 |
+
|
102 |
+
def testSyncBatchNorm2DSyncTrain(self):
|
103 |
+
bn = nn.BatchNorm2d(10)
|
104 |
+
sync_bn = SynchronizedBatchNorm2d(10)
|
105 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
106 |
+
|
107 |
+
bn.cuda()
|
108 |
+
sync_bn.cuda()
|
109 |
+
|
110 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
unittest.main()
|
Face_Enhancement/models/networks/__init__.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from models.networks.base_network import BaseNetwork
|
6 |
+
from models.networks.generator import *
|
7 |
+
from models.networks.encoder import *
|
8 |
+
import util.util as util
|
9 |
+
|
10 |
+
|
11 |
+
def find_network_using_name(target_network_name, filename):
|
12 |
+
target_class_name = target_network_name + filename
|
13 |
+
module_name = "models.networks." + filename
|
14 |
+
network = util.find_class_in_module(target_class_name, module_name)
|
15 |
+
|
16 |
+
assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network
|
17 |
+
|
18 |
+
return network
|
19 |
+
|
20 |
+
|
21 |
+
def modify_commandline_options(parser, is_train):
|
22 |
+
opt, _ = parser.parse_known_args()
|
23 |
+
|
24 |
+
netG_cls = find_network_using_name(opt.netG, "generator")
|
25 |
+
parser = netG_cls.modify_commandline_options(parser, is_train)
|
26 |
+
if is_train:
|
27 |
+
netD_cls = find_network_using_name(opt.netD, "discriminator")
|
28 |
+
parser = netD_cls.modify_commandline_options(parser, is_train)
|
29 |
+
netE_cls = find_network_using_name("conv", "encoder")
|
30 |
+
parser = netE_cls.modify_commandline_options(parser, is_train)
|
31 |
+
|
32 |
+
return parser
|
33 |
+
|
34 |
+
|
35 |
+
def create_network(cls, opt):
|
36 |
+
net = cls(opt)
|
37 |
+
net.print_network()
|
38 |
+
if len(opt.gpu_ids) > 0:
|
39 |
+
assert torch.cuda.is_available()
|
40 |
+
net.cuda()
|
41 |
+
net.init_weights(opt.init_type, opt.init_variance)
|
42 |
+
return net
|
43 |
+
|
44 |
+
|
45 |
+
def define_G(opt):
|
46 |
+
netG_cls = find_network_using_name(opt.netG, "generator")
|
47 |
+
return create_network(netG_cls, opt)
|
48 |
+
|
49 |
+
|
50 |
+
def define_D(opt):
|
51 |
+
netD_cls = find_network_using_name(opt.netD, "discriminator")
|
52 |
+
return create_network(netD_cls, opt)
|
53 |
+
|
54 |
+
|
55 |
+
def define_E(opt):
|
56 |
+
# there exists only one encoder type
|
57 |
+
netE_cls = find_network_using_name("conv", "encoder")
|
58 |
+
return create_network(netE_cls, opt)
|
Face_Enhancement/models/networks/architecture.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchvision
|
8 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
9 |
+
from models.networks.normalization import SPADE
|
10 |
+
|
11 |
+
|
12 |
+
# ResNet block that uses SPADE.
|
13 |
+
# It differs from the ResNet block of pix2pixHD in that
|
14 |
+
# it takes in the segmentation map as input, learns the skip connection if necessary,
|
15 |
+
# and applies normalization first and then convolution.
|
16 |
+
# This architecture seemed like a standard architecture for unconditional or
|
17 |
+
# class-conditional GAN architecture using residual block.
|
18 |
+
# The code was inspired from https://github.com/LMescheder/GAN_stability.
|
19 |
+
class SPADEResnetBlock(nn.Module):
|
20 |
+
def __init__(self, fin, fout, opt):
|
21 |
+
super().__init__()
|
22 |
+
# Attributes
|
23 |
+
self.learned_shortcut = fin != fout
|
24 |
+
fmiddle = min(fin, fout)
|
25 |
+
|
26 |
+
self.opt = opt
|
27 |
+
# create conv layers
|
28 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
|
29 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
|
30 |
+
if self.learned_shortcut:
|
31 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
32 |
+
|
33 |
+
# apply spectral norm if specified
|
34 |
+
if "spectral" in opt.norm_G:
|
35 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
36 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
37 |
+
if self.learned_shortcut:
|
38 |
+
self.conv_s = spectral_norm(self.conv_s)
|
39 |
+
|
40 |
+
# define normalization layers
|
41 |
+
spade_config_str = opt.norm_G.replace("spectral", "")
|
42 |
+
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
|
43 |
+
self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
|
44 |
+
if self.learned_shortcut:
|
45 |
+
self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
|
46 |
+
|
47 |
+
# note the resnet block with SPADE also takes in |seg|,
|
48 |
+
# the semantic segmentation map as input
|
49 |
+
def forward(self, x, seg, degraded_image):
|
50 |
+
x_s = self.shortcut(x, seg, degraded_image)
|
51 |
+
|
52 |
+
dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
|
53 |
+
dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))
|
54 |
+
|
55 |
+
out = x_s + dx
|
56 |
+
|
57 |
+
return out
|
58 |
+
|
59 |
+
def shortcut(self, x, seg, degraded_image):
|
60 |
+
if self.learned_shortcut:
|
61 |
+
x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
|
62 |
+
else:
|
63 |
+
x_s = x
|
64 |
+
return x_s
|
65 |
+
|
66 |
+
def actvn(self, x):
|
67 |
+
return F.leaky_relu(x, 2e-1)
|
68 |
+
|
69 |
+
|
70 |
+
# ResNet block used in pix2pixHD
|
71 |
+
# We keep the same architecture as pix2pixHD.
|
72 |
+
class ResnetBlock(nn.Module):
|
73 |
+
def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
pw = (kernel_size - 1) // 2
|
77 |
+
self.conv_block = nn.Sequential(
|
78 |
+
nn.ReflectionPad2d(pw),
|
79 |
+
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
|
80 |
+
activation,
|
81 |
+
nn.ReflectionPad2d(pw),
|
82 |
+
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
|
83 |
+
)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
y = self.conv_block(x)
|
87 |
+
out = x + y
|
88 |
+
return out
|
89 |
+
|
90 |
+
|
91 |
+
# VGG architecter, used for the perceptual loss using a pretrained VGG network
|
92 |
+
class VGG19(torch.nn.Module):
|
93 |
+
def __init__(self, requires_grad=False):
|
94 |
+
super().__init__()
|
95 |
+
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
|
96 |
+
self.slice1 = torch.nn.Sequential()
|
97 |
+
self.slice2 = torch.nn.Sequential()
|
98 |
+
self.slice3 = torch.nn.Sequential()
|
99 |
+
self.slice4 = torch.nn.Sequential()
|
100 |
+
self.slice5 = torch.nn.Sequential()
|
101 |
+
for x in range(2):
|
102 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
103 |
+
for x in range(2, 7):
|
104 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
105 |
+
for x in range(7, 12):
|
106 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
107 |
+
for x in range(12, 21):
|
108 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
109 |
+
for x in range(21, 30):
|
110 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
if not requires_grad:
|
112 |
+
for param in self.parameters():
|
113 |
+
param.requires_grad = False
|
114 |
+
|
115 |
+
def forward(self, X):
|
116 |
+
h_relu1 = self.slice1(X)
|
117 |
+
h_relu2 = self.slice2(h_relu1)
|
118 |
+
h_relu3 = self.slice3(h_relu2)
|
119 |
+
h_relu4 = self.slice4(h_relu3)
|
120 |
+
h_relu5 = self.slice5(h_relu4)
|
121 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
122 |
+
return out
|
123 |
+
|
124 |
+
|
125 |
+
class SPADEResnetBlock_non_spade(nn.Module):
|
126 |
+
def __init__(self, fin, fout, opt):
|
127 |
+
super().__init__()
|
128 |
+
# Attributes
|
129 |
+
self.learned_shortcut = fin != fout
|
130 |
+
fmiddle = min(fin, fout)
|
131 |
+
|
132 |
+
self.opt = opt
|
133 |
+
# create conv layers
|
134 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
|
135 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
|
136 |
+
if self.learned_shortcut:
|
137 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
138 |
+
|
139 |
+
# apply spectral norm if specified
|
140 |
+
if "spectral" in opt.norm_G:
|
141 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
142 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
143 |
+
if self.learned_shortcut:
|
144 |
+
self.conv_s = spectral_norm(self.conv_s)
|
145 |
+
|
146 |
+
# define normalization layers
|
147 |
+
spade_config_str = opt.norm_G.replace("spectral", "")
|
148 |
+
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
|
149 |
+
self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
|
150 |
+
if self.learned_shortcut:
|
151 |
+
self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
|
152 |
+
|
153 |
+
# note the resnet block with SPADE also takes in |seg|,
|
154 |
+
# the semantic segmentation map as input
|
155 |
+
def forward(self, x, seg, degraded_image):
|
156 |
+
x_s = self.shortcut(x, seg, degraded_image)
|
157 |
+
|
158 |
+
dx = self.conv_0(self.actvn(x))
|
159 |
+
dx = self.conv_1(self.actvn(dx))
|
160 |
+
|
161 |
+
out = x_s + dx
|
162 |
+
|
163 |
+
return out
|
164 |
+
|
165 |
+
def shortcut(self, x, seg, degraded_image):
|
166 |
+
if self.learned_shortcut:
|
167 |
+
x_s = self.conv_s(x)
|
168 |
+
else:
|
169 |
+
x_s = x
|
170 |
+
return x_s
|
171 |
+
|
172 |
+
def actvn(self, x):
|
173 |
+
return F.leaky_relu(x, 2e-1)
|
Face_Enhancement/models/networks/base_network.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import init
|
6 |
+
|
7 |
+
|
8 |
+
class BaseNetwork(nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super(BaseNetwork, self).__init__()
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def modify_commandline_options(parser, is_train):
|
14 |
+
return parser
|
15 |
+
|
16 |
+
def print_network(self):
|
17 |
+
if isinstance(self, list):
|
18 |
+
self = self[0]
|
19 |
+
num_params = 0
|
20 |
+
for param in self.parameters():
|
21 |
+
num_params += param.numel()
|
22 |
+
print(
|
23 |
+
"Network [%s] was created. Total number of parameters: %.1f million. "
|
24 |
+
"To see the architecture, do print(network)." % (type(self).__name__, num_params / 1000000)
|
25 |
+
)
|
26 |
+
|
27 |
+
def init_weights(self, init_type="normal", gain=0.02):
|
28 |
+
def init_func(m):
|
29 |
+
classname = m.__class__.__name__
|
30 |
+
if classname.find("BatchNorm2d") != -1:
|
31 |
+
if hasattr(m, "weight") and m.weight is not None:
|
32 |
+
init.normal_(m.weight.data, 1.0, gain)
|
33 |
+
if hasattr(m, "bias") and m.bias is not None:
|
34 |
+
init.constant_(m.bias.data, 0.0)
|
35 |
+
elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
|
36 |
+
if init_type == "normal":
|
37 |
+
init.normal_(m.weight.data, 0.0, gain)
|
38 |
+
elif init_type == "xavier":
|
39 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
40 |
+
elif init_type == "xavier_uniform":
|
41 |
+
init.xavier_uniform_(m.weight.data, gain=1.0)
|
42 |
+
elif init_type == "kaiming":
|
43 |
+
init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
|
44 |
+
elif init_type == "orthogonal":
|
45 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
46 |
+
elif init_type == "none": # uses pytorch's default init method
|
47 |
+
m.reset_parameters()
|
48 |
+
else:
|
49 |
+
raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
|
50 |
+
if hasattr(m, "bias") and m.bias is not None:
|
51 |
+
init.constant_(m.bias.data, 0.0)
|
52 |
+
|
53 |
+
self.apply(init_func)
|
54 |
+
|
55 |
+
# propagate to children
|
56 |
+
for m in self.children():
|
57 |
+
if hasattr(m, "init_weights"):
|
58 |
+
m.init_weights(init_type, gain)
|
Face_Enhancement/models/networks/encoder.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from models.networks.base_network import BaseNetwork
|
8 |
+
from models.networks.normalization import get_nonspade_norm_layer
|
9 |
+
|
10 |
+
|
11 |
+
class ConvEncoder(BaseNetwork):
|
12 |
+
""" Same architecture as the image discriminator """
|
13 |
+
|
14 |
+
def __init__(self, opt):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
kw = 3
|
18 |
+
pw = int(np.ceil((kw - 1.0) / 2))
|
19 |
+
ndf = opt.ngf
|
20 |
+
norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
|
21 |
+
self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
|
22 |
+
self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
|
23 |
+
self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
|
24 |
+
self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
|
25 |
+
self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
|
26 |
+
if opt.crop_size >= 256:
|
27 |
+
self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
|
28 |
+
|
29 |
+
self.so = s0 = 4
|
30 |
+
self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
|
31 |
+
self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)
|
32 |
+
|
33 |
+
self.actvn = nn.LeakyReLU(0.2, False)
|
34 |
+
self.opt = opt
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
if x.size(2) != 256 or x.size(3) != 256:
|
38 |
+
x = F.interpolate(x, size=(256, 256), mode="bilinear")
|
39 |
+
|
40 |
+
x = self.layer1(x)
|
41 |
+
x = self.layer2(self.actvn(x))
|
42 |
+
x = self.layer3(self.actvn(x))
|
43 |
+
x = self.layer4(self.actvn(x))
|
44 |
+
x = self.layer5(self.actvn(x))
|
45 |
+
if self.opt.crop_size >= 256:
|
46 |
+
x = self.layer6(self.actvn(x))
|
47 |
+
x = self.actvn(x)
|
48 |
+
|
49 |
+
x = x.view(x.size(0), -1)
|
50 |
+
mu = self.fc_mu(x)
|
51 |
+
logvar = self.fc_var(x)
|
52 |
+
|
53 |
+
return mu, logvar
|
Face_Enhancement/models/networks/generator.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from models.networks.base_network import BaseNetwork
|
8 |
+
from models.networks.normalization import get_nonspade_norm_layer
|
9 |
+
from models.networks.architecture import ResnetBlock as ResnetBlock
|
10 |
+
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
|
11 |
+
from models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade
|
12 |
+
|
13 |
+
|
14 |
+
class SPADEGenerator(BaseNetwork):
|
15 |
+
@staticmethod
|
16 |
+
def modify_commandline_options(parser, is_train):
|
17 |
+
parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
|
18 |
+
parser.add_argument(
|
19 |
+
"--num_upsampling_layers",
|
20 |
+
choices=("normal", "more", "most"),
|
21 |
+
default="normal",
|
22 |
+
help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
|
23 |
+
)
|
24 |
+
|
25 |
+
return parser
|
26 |
+
|
27 |
+
def __init__(self, opt):
|
28 |
+
super().__init__()
|
29 |
+
self.opt = opt
|
30 |
+
nf = opt.ngf
|
31 |
+
|
32 |
+
self.sw, self.sh = self.compute_latent_vector_size(opt)
|
33 |
+
|
34 |
+
print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))
|
35 |
+
|
36 |
+
if opt.use_vae:
|
37 |
+
# In case of VAE, we will sample from random z vector
|
38 |
+
self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
|
39 |
+
else:
|
40 |
+
# Otherwise, we make the network deterministic by starting with
|
41 |
+
# downsampled segmentation map instead of random z
|
42 |
+
if self.opt.no_parsing_map:
|
43 |
+
self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
|
44 |
+
else:
|
45 |
+
self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
|
46 |
+
|
47 |
+
if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
|
48 |
+
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
|
49 |
+
else:
|
50 |
+
self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
|
51 |
+
|
52 |
+
if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
|
53 |
+
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
|
54 |
+
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
|
55 |
+
|
56 |
+
else:
|
57 |
+
self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
|
58 |
+
self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
|
59 |
+
|
60 |
+
if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
|
61 |
+
self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
|
62 |
+
else:
|
63 |
+
self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)
|
64 |
+
|
65 |
+
if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
|
66 |
+
self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
|
67 |
+
else:
|
68 |
+
self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)
|
69 |
+
|
70 |
+
if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
|
71 |
+
self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
|
72 |
+
else:
|
73 |
+
self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)
|
74 |
+
|
75 |
+
if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
|
76 |
+
self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
|
77 |
+
else:
|
78 |
+
self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)
|
79 |
+
|
80 |
+
final_nc = nf
|
81 |
+
|
82 |
+
if opt.num_upsampling_layers == "most":
|
83 |
+
self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
|
84 |
+
final_nc = nf // 2
|
85 |
+
|
86 |
+
self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
|
87 |
+
|
88 |
+
self.up = nn.Upsample(scale_factor=2)
|
89 |
+
|
90 |
+
def compute_latent_vector_size(self, opt):
|
91 |
+
if opt.num_upsampling_layers == "normal":
|
92 |
+
num_up_layers = 5
|
93 |
+
elif opt.num_upsampling_layers == "more":
|
94 |
+
num_up_layers = 6
|
95 |
+
elif opt.num_upsampling_layers == "most":
|
96 |
+
num_up_layers = 7
|
97 |
+
else:
|
98 |
+
raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)
|
99 |
+
|
100 |
+
sw = opt.load_size // (2 ** num_up_layers)
|
101 |
+
sh = round(sw / opt.aspect_ratio)
|
102 |
+
|
103 |
+
return sw, sh
|
104 |
+
|
105 |
+
def forward(self, input, degraded_image, z=None):
|
106 |
+
seg = input
|
107 |
+
|
108 |
+
if self.opt.use_vae:
|
109 |
+
# we sample z from unit normal and reshape the tensor
|
110 |
+
if z is None:
|
111 |
+
z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())
|
112 |
+
x = self.fc(z)
|
113 |
+
x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
|
114 |
+
else:
|
115 |
+
# we downsample segmap and run convolution
|
116 |
+
if self.opt.no_parsing_map:
|
117 |
+
x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
|
118 |
+
else:
|
119 |
+
x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
|
120 |
+
x = self.fc(x)
|
121 |
+
|
122 |
+
x = self.head_0(x, seg, degraded_image)
|
123 |
+
|
124 |
+
x = self.up(x)
|
125 |
+
x = self.G_middle_0(x, seg, degraded_image)
|
126 |
+
|
127 |
+
if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
|
128 |
+
x = self.up(x)
|
129 |
+
|
130 |
+
x = self.G_middle_1(x, seg, degraded_image)
|
131 |
+
|
132 |
+
x = self.up(x)
|
133 |
+
x = self.up_0(x, seg, degraded_image)
|
134 |
+
x = self.up(x)
|
135 |
+
x = self.up_1(x, seg, degraded_image)
|
136 |
+
x = self.up(x)
|
137 |
+
x = self.up_2(x, seg, degraded_image)
|
138 |
+
x = self.up(x)
|
139 |
+
x = self.up_3(x, seg, degraded_image)
|
140 |
+
|
141 |
+
if self.opt.num_upsampling_layers == "most":
|
142 |
+
x = self.up(x)
|
143 |
+
x = self.up_4(x, seg, degraded_image)
|
144 |
+
|
145 |
+
x = self.conv_img(F.leaky_relu(x, 2e-1))
|
146 |
+
x = F.tanh(x)
|
147 |
+
|
148 |
+
return x
|
149 |
+
|
150 |
+
|
151 |
+
class Pix2PixHDGenerator(BaseNetwork):
|
152 |
+
@staticmethod
|
153 |
+
def modify_commandline_options(parser, is_train):
|
154 |
+
parser.add_argument(
|
155 |
+
"--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--resnet_n_blocks",
|
159 |
+
type=int,
|
160 |
+
default=9,
|
161 |
+
help="number of residual blocks in the global generator network",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
|
168 |
+
)
|
169 |
+
# parser.set_defaults(norm_G='instance')
|
170 |
+
return parser
|
171 |
+
|
172 |
+
def __init__(self, opt):
|
173 |
+
super().__init__()
|
174 |
+
input_nc = 3
|
175 |
+
|
176 |
+
# print("xxxxx")
|
177 |
+
# print(opt.norm_G)
|
178 |
+
norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
|
179 |
+
activation = nn.ReLU(False)
|
180 |
+
|
181 |
+
model = []
|
182 |
+
|
183 |
+
# initial conv
|
184 |
+
model += [
|
185 |
+
nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
|
186 |
+
norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
|
187 |
+
activation,
|
188 |
+
]
|
189 |
+
|
190 |
+
# downsample
|
191 |
+
mult = 1
|
192 |
+
for i in range(opt.resnet_n_downsample):
|
193 |
+
model += [
|
194 |
+
norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
|
195 |
+
activation,
|
196 |
+
]
|
197 |
+
mult *= 2
|
198 |
+
|
199 |
+
# resnet blocks
|
200 |
+
for i in range(opt.resnet_n_blocks):
|
201 |
+
model += [
|
202 |
+
ResnetBlock(
|
203 |
+
opt.ngf * mult,
|
204 |
+
norm_layer=norm_layer,
|
205 |
+
activation=activation,
|
206 |
+
kernel_size=opt.resnet_kernel_size,
|
207 |
+
)
|
208 |
+
]
|
209 |
+
|
210 |
+
# upsample
|
211 |
+
for i in range(opt.resnet_n_downsample):
|
212 |
+
nc_in = int(opt.ngf * mult)
|
213 |
+
nc_out = int((opt.ngf * mult) / 2)
|
214 |
+
model += [
|
215 |
+
norm_layer(
|
216 |
+
nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
|
217 |
+
),
|
218 |
+
activation,
|
219 |
+
]
|
220 |
+
mult = mult // 2
|
221 |
+
|
222 |
+
# final output conv
|
223 |
+
model += [
|
224 |
+
nn.ReflectionPad2d(3),
|
225 |
+
nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
|
226 |
+
nn.Tanh(),
|
227 |
+
]
|
228 |
+
|
229 |
+
self.model = nn.Sequential(*model)
|
230 |
+
|
231 |
+
def forward(self, input, degraded_image, z=None):
|
232 |
+
return self.model(degraded_image)
|
233 |
+
|
Face_Enhancement/models/networks/normalization.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import re
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
|
9 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
10 |
+
|
11 |
+
|
12 |
+
def get_nonspade_norm_layer(opt, norm_type="instance"):
|
13 |
+
# helper function to get # output channels of the previous layer
|
14 |
+
def get_out_channel(layer):
|
15 |
+
if hasattr(layer, "out_channels"):
|
16 |
+
return getattr(layer, "out_channels")
|
17 |
+
return layer.weight.size(0)
|
18 |
+
|
19 |
+
# this function will be returned
|
20 |
+
def add_norm_layer(layer):
|
21 |
+
nonlocal norm_type
|
22 |
+
if norm_type.startswith("spectral"):
|
23 |
+
layer = spectral_norm(layer)
|
24 |
+
subnorm_type = norm_type[len("spectral") :]
|
25 |
+
|
26 |
+
if subnorm_type == "none" or len(subnorm_type) == 0:
|
27 |
+
return layer
|
28 |
+
|
29 |
+
# remove bias in the previous layer, which is meaningless
|
30 |
+
# since it has no effect after normalization
|
31 |
+
if getattr(layer, "bias", None) is not None:
|
32 |
+
delattr(layer, "bias")
|
33 |
+
layer.register_parameter("bias", None)
|
34 |
+
|
35 |
+
if subnorm_type == "batch":
|
36 |
+
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
|
37 |
+
elif subnorm_type == "sync_batch":
|
38 |
+
norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
|
39 |
+
elif subnorm_type == "instance":
|
40 |
+
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
|
41 |
+
else:
|
42 |
+
raise ValueError("normalization layer %s is not recognized" % subnorm_type)
|
43 |
+
|
44 |
+
return nn.Sequential(layer, norm_layer)
|
45 |
+
|
46 |
+
return add_norm_layer
|
47 |
+
|
48 |
+
|
49 |
+
class SPADE(nn.Module):
|
50 |
+
def __init__(self, config_text, norm_nc, label_nc, opt):
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
assert config_text.startswith("spade")
|
54 |
+
parsed = re.search("spade(\D+)(\d)x\d", config_text)
|
55 |
+
param_free_norm_type = str(parsed.group(1))
|
56 |
+
ks = int(parsed.group(2))
|
57 |
+
self.opt = opt
|
58 |
+
if param_free_norm_type == "instance":
|
59 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
60 |
+
elif param_free_norm_type == "syncbatch":
|
61 |
+
self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
|
62 |
+
elif param_free_norm_type == "batch":
|
63 |
+
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
|
64 |
+
else:
|
65 |
+
raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)
|
66 |
+
|
67 |
+
# The dimension of the intermediate embedding space. Yes, hardcoded.
|
68 |
+
nhidden = 128
|
69 |
+
|
70 |
+
pw = ks // 2
|
71 |
+
|
72 |
+
if self.opt.no_parsing_map:
|
73 |
+
self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
|
74 |
+
else:
|
75 |
+
self.mlp_shared = nn.Sequential(
|
76 |
+
nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
|
77 |
+
)
|
78 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
79 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
80 |
+
|
81 |
+
def forward(self, x, segmap, degraded_image):
|
82 |
+
|
83 |
+
# Part 1. generate parameter-free normalized activations
|
84 |
+
normalized = self.param_free_norm(x)
|
85 |
+
|
86 |
+
# Part 2. produce scaling and bias conditioned on semantic map
|
87 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
|
88 |
+
degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")
|
89 |
+
|
90 |
+
if self.opt.no_parsing_map:
|
91 |
+
actv = self.mlp_shared(degraded_face)
|
92 |
+
else:
|
93 |
+
actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
|
94 |
+
gamma = self.mlp_gamma(actv)
|
95 |
+
beta = self.mlp_beta(actv)
|
96 |
+
|
97 |
+
# apply scale and bias
|
98 |
+
out = normalized * (1 + gamma) + beta
|
99 |
+
|
100 |
+
return out
|
Face_Enhancement/models/networks/sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import set_sbn_eps_mode
|
12 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
13 |
+
from .batchnorm import patch_sync_batchnorm, convert_model
|
14 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
import contextlib
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
18 |
+
|
19 |
+
try:
|
20 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
21 |
+
except ImportError:
|
22 |
+
ReduceAddCoalesced = Broadcast = None
|
23 |
+
|
24 |
+
try:
|
25 |
+
from jactorch.parallel.comm import SyncMaster
|
26 |
+
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
|
27 |
+
except ImportError:
|
28 |
+
from .comm import SyncMaster
|
29 |
+
from .replicate import DataParallelWithCallback
|
30 |
+
|
31 |
+
__all__ = [
|
32 |
+
'set_sbn_eps_mode',
|
33 |
+
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
|
34 |
+
'patch_sync_batchnorm', 'convert_model'
|
35 |
+
]
|
36 |
+
|
37 |
+
|
38 |
+
SBN_EPS_MODE = 'clamp'
|
39 |
+
|
40 |
+
|
41 |
+
def set_sbn_eps_mode(mode):
|
42 |
+
global SBN_EPS_MODE
|
43 |
+
assert mode in ('clamp', 'plus')
|
44 |
+
SBN_EPS_MODE = mode
|
45 |
+
|
46 |
+
|
47 |
+
def _sum_ft(tensor):
|
48 |
+
"""sum over the first and last dimention"""
|
49 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
50 |
+
|
51 |
+
|
52 |
+
def _unsqueeze_ft(tensor):
|
53 |
+
"""add new dimensions at the front and the tail"""
|
54 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
55 |
+
|
56 |
+
|
57 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
58 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
59 |
+
|
60 |
+
|
61 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
62 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
|
63 |
+
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
|
64 |
+
|
65 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
|
66 |
+
track_running_stats=track_running_stats)
|
67 |
+
|
68 |
+
if not self.track_running_stats:
|
69 |
+
import warnings
|
70 |
+
warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
|
71 |
+
|
72 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
73 |
+
|
74 |
+
self._is_parallel = False
|
75 |
+
self._parallel_id = None
|
76 |
+
self._slave_pipe = None
|
77 |
+
|
78 |
+
def forward(self, input):
|
79 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
80 |
+
if not (self._is_parallel and self.training):
|
81 |
+
return F.batch_norm(
|
82 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
83 |
+
self.training, self.momentum, self.eps)
|
84 |
+
|
85 |
+
# Resize the input to (B, C, -1).
|
86 |
+
input_shape = input.size()
|
87 |
+
assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
|
88 |
+
input = input.view(input.size(0), self.num_features, -1)
|
89 |
+
|
90 |
+
# Compute the sum and square-sum.
|
91 |
+
sum_size = input.size(0) * input.size(2)
|
92 |
+
input_sum = _sum_ft(input)
|
93 |
+
input_ssum = _sum_ft(input ** 2)
|
94 |
+
|
95 |
+
# Reduce-and-broadcast the statistics.
|
96 |
+
if self._parallel_id == 0:
|
97 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
98 |
+
else:
|
99 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
100 |
+
|
101 |
+
# Compute the output.
|
102 |
+
if self.affine:
|
103 |
+
# MJY:: Fuse the multiplication for speed.
|
104 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
105 |
+
else:
|
106 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
107 |
+
|
108 |
+
# Reshape it.
|
109 |
+
return output.view(input_shape)
|
110 |
+
|
111 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
112 |
+
self._is_parallel = True
|
113 |
+
self._parallel_id = copy_id
|
114 |
+
|
115 |
+
# parallel_id == 0 means master device.
|
116 |
+
if self._parallel_id == 0:
|
117 |
+
ctx.sync_master = self._sync_master
|
118 |
+
else:
|
119 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
120 |
+
|
121 |
+
def _data_parallel_master(self, intermediates):
|
122 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
123 |
+
|
124 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
125 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
126 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
127 |
+
|
128 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
129 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
130 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
131 |
+
|
132 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
133 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
134 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
135 |
+
|
136 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
137 |
+
|
138 |
+
outputs = []
|
139 |
+
for i, rec in enumerate(intermediates):
|
140 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
141 |
+
|
142 |
+
return outputs
|
143 |
+
|
144 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
145 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
146 |
+
also maintains the moving average on the master device."""
|
147 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
148 |
+
mean = sum_ / size
|
149 |
+
sumvar = ssum - sum_ * mean
|
150 |
+
unbias_var = sumvar / (size - 1)
|
151 |
+
bias_var = sumvar / size
|
152 |
+
|
153 |
+
if hasattr(torch, 'no_grad'):
|
154 |
+
with torch.no_grad():
|
155 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
156 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
157 |
+
else:
|
158 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
159 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
160 |
+
|
161 |
+
if SBN_EPS_MODE == 'clamp':
|
162 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
163 |
+
elif SBN_EPS_MODE == 'plus':
|
164 |
+
return mean, (bias_var + self.eps) ** -0.5
|
165 |
+
else:
|
166 |
+
raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
|
167 |
+
|
168 |
+
|
169 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
170 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
171 |
+
mini-batch.
|
172 |
+
|
173 |
+
.. math::
|
174 |
+
|
175 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
176 |
+
|
177 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
178 |
+
standard-deviation are reduced across all devices during training.
|
179 |
+
|
180 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
181 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
182 |
+
the statistics only on that device, which accelerated the computation and
|
183 |
+
is also easy to implement, but the statistics might be inaccurate.
|
184 |
+
Instead, in this synchronized version, the statistics will be computed
|
185 |
+
over all training samples distributed on multiple devices.
|
186 |
+
|
187 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
188 |
+
as the built-in PyTorch implementation.
|
189 |
+
|
190 |
+
The mean and standard-deviation are calculated per-dimension over
|
191 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
192 |
+
of size C (where C is the input size).
|
193 |
+
|
194 |
+
During training, this layer keeps a running estimate of its computed mean
|
195 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
196 |
+
|
197 |
+
During evaluation, this running mean/variance is used for normalization.
|
198 |
+
|
199 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
200 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
201 |
+
|
202 |
+
Args:
|
203 |
+
num_features: num_features from an expected input of size
|
204 |
+
`batch_size x num_features [x width]`
|
205 |
+
eps: a value added to the denominator for numerical stability.
|
206 |
+
Default: 1e-5
|
207 |
+
momentum: the value used for the running_mean and running_var
|
208 |
+
computation. Default: 0.1
|
209 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
210 |
+
affine parameters. Default: ``True``
|
211 |
+
|
212 |
+
Shape::
|
213 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
214 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
215 |
+
|
216 |
+
Examples:
|
217 |
+
>>> # With Learnable Parameters
|
218 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
219 |
+
>>> # Without Learnable Parameters
|
220 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
221 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
222 |
+
>>> output = m(input)
|
223 |
+
"""
|
224 |
+
|
225 |
+
def _check_input_dim(self, input):
|
226 |
+
if input.dim() != 2 and input.dim() != 3:
|
227 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
228 |
+
.format(input.dim()))
|
229 |
+
|
230 |
+
|
231 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
232 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
233 |
+
of 3d inputs
|
234 |
+
|
235 |
+
.. math::
|
236 |
+
|
237 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
238 |
+
|
239 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
240 |
+
standard-deviation are reduced across all devices during training.
|
241 |
+
|
242 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
243 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
244 |
+
the statistics only on that device, which accelerated the computation and
|
245 |
+
is also easy to implement, but the statistics might be inaccurate.
|
246 |
+
Instead, in this synchronized version, the statistics will be computed
|
247 |
+
over all training samples distributed on multiple devices.
|
248 |
+
|
249 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
250 |
+
as the built-in PyTorch implementation.
|
251 |
+
|
252 |
+
The mean and standard-deviation are calculated per-dimension over
|
253 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
254 |
+
of size C (where C is the input size).
|
255 |
+
|
256 |
+
During training, this layer keeps a running estimate of its computed mean
|
257 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
258 |
+
|
259 |
+
During evaluation, this running mean/variance is used for normalization.
|
260 |
+
|
261 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
262 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
263 |
+
|
264 |
+
Args:
|
265 |
+
num_features: num_features from an expected input of
|
266 |
+
size batch_size x num_features x height x width
|
267 |
+
eps: a value added to the denominator for numerical stability.
|
268 |
+
Default: 1e-5
|
269 |
+
momentum: the value used for the running_mean and running_var
|
270 |
+
computation. Default: 0.1
|
271 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
272 |
+
affine parameters. Default: ``True``
|
273 |
+
|
274 |
+
Shape::
|
275 |
+
- Input: :math:`(N, C, H, W)`
|
276 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
277 |
+
|
278 |
+
Examples:
|
279 |
+
>>> # With Learnable Parameters
|
280 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
281 |
+
>>> # Without Learnable Parameters
|
282 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
283 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
284 |
+
>>> output = m(input)
|
285 |
+
"""
|
286 |
+
|
287 |
+
def _check_input_dim(self, input):
|
288 |
+
if input.dim() != 4:
|
289 |
+
raise ValueError('expected 4D input (got {}D input)'
|
290 |
+
.format(input.dim()))
|
291 |
+
|
292 |
+
|
293 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
294 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
295 |
+
of 4d inputs
|
296 |
+
|
297 |
+
.. math::
|
298 |
+
|
299 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
300 |
+
|
301 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
302 |
+
standard-deviation are reduced across all devices during training.
|
303 |
+
|
304 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
305 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
306 |
+
the statistics only on that device, which accelerated the computation and
|
307 |
+
is also easy to implement, but the statistics might be inaccurate.
|
308 |
+
Instead, in this synchronized version, the statistics will be computed
|
309 |
+
over all training samples distributed on multiple devices.
|
310 |
+
|
311 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
312 |
+
as the built-in PyTorch implementation.
|
313 |
+
|
314 |
+
The mean and standard-deviation are calculated per-dimension over
|
315 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
316 |
+
of size C (where C is the input size).
|
317 |
+
|
318 |
+
During training, this layer keeps a running estimate of its computed mean
|
319 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
320 |
+
|
321 |
+
During evaluation, this running mean/variance is used for normalization.
|
322 |
+
|
323 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
324 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
325 |
+
or Spatio-temporal BatchNorm
|
326 |
+
|
327 |
+
Args:
|
328 |
+
num_features: num_features from an expected input of
|
329 |
+
size batch_size x num_features x depth x height x width
|
330 |
+
eps: a value added to the denominator for numerical stability.
|
331 |
+
Default: 1e-5
|
332 |
+
momentum: the value used for the running_mean and running_var
|
333 |
+
computation. Default: 0.1
|
334 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
335 |
+
affine parameters. Default: ``True``
|
336 |
+
|
337 |
+
Shape::
|
338 |
+
- Input: :math:`(N, C, D, H, W)`
|
339 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
340 |
+
|
341 |
+
Examples:
|
342 |
+
>>> # With Learnable Parameters
|
343 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
344 |
+
>>> # Without Learnable Parameters
|
345 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
346 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
347 |
+
>>> output = m(input)
|
348 |
+
"""
|
349 |
+
|
350 |
+
def _check_input_dim(self, input):
|
351 |
+
if input.dim() != 5:
|
352 |
+
raise ValueError('expected 5D input (got {}D input)'
|
353 |
+
.format(input.dim()))
|
354 |
+
|
355 |
+
|
356 |
+
@contextlib.contextmanager
|
357 |
+
def patch_sync_batchnorm():
|
358 |
+
import torch.nn as nn
|
359 |
+
|
360 |
+
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
|
361 |
+
|
362 |
+
nn.BatchNorm1d = SynchronizedBatchNorm1d
|
363 |
+
nn.BatchNorm2d = SynchronizedBatchNorm2d
|
364 |
+
nn.BatchNorm3d = SynchronizedBatchNorm3d
|
365 |
+
|
366 |
+
yield
|
367 |
+
|
368 |
+
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
|
369 |
+
|
370 |
+
|
371 |
+
def convert_model(module):
|
372 |
+
"""Traverse the input module and its child recursively
|
373 |
+
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
|
374 |
+
to SynchronizedBatchNorm*N*d
|
375 |
+
|
376 |
+
Args:
|
377 |
+
module: the input module needs to be convert to SyncBN model
|
378 |
+
|
379 |
+
Examples:
|
380 |
+
>>> import torch.nn as nn
|
381 |
+
>>> import torchvision
|
382 |
+
>>> # m is a standard pytorch model
|
383 |
+
>>> m = torchvision.models.resnet18(True)
|
384 |
+
>>> m = nn.DataParallel(m)
|
385 |
+
>>> # after convert, m is using SyncBN
|
386 |
+
>>> m = convert_model(m)
|
387 |
+
"""
|
388 |
+
if isinstance(module, torch.nn.DataParallel):
|
389 |
+
mod = module.module
|
390 |
+
mod = convert_model(mod)
|
391 |
+
mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
|
392 |
+
return mod
|
393 |
+
|
394 |
+
mod = module
|
395 |
+
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
|
396 |
+
torch.nn.modules.batchnorm.BatchNorm2d,
|
397 |
+
torch.nn.modules.batchnorm.BatchNorm3d],
|
398 |
+
[SynchronizedBatchNorm1d,
|
399 |
+
SynchronizedBatchNorm2d,
|
400 |
+
SynchronizedBatchNorm3d]):
|
401 |
+
if isinstance(module, pth_module):
|
402 |
+
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
|
403 |
+
mod.running_mean = module.running_mean
|
404 |
+
mod.running_var = module.running_var
|
405 |
+
if module.affine:
|
406 |
+
mod.weight.data = module.weight.data.clone().detach()
|
407 |
+
mod.bias.data = module.bias.data.clone().detach()
|
408 |
+
|
409 |
+
for name, child in module.named_children():
|
410 |
+
mod.add_module(name, convert_model(child))
|
411 |
+
|
412 |
+
return mod
|
Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : batchnorm_reimpl.py
|
4 |
+
# Author : acgtyrant
|
5 |
+
# Date : 11/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.init as init
|
14 |
+
|
15 |
+
__all__ = ['BatchNorm2dReimpl']
|
16 |
+
|
17 |
+
|
18 |
+
class BatchNorm2dReimpl(nn.Module):
|
19 |
+
"""
|
20 |
+
A re-implementation of batch normalization, used for testing the numerical
|
21 |
+
stability.
|
22 |
+
|
23 |
+
Author: acgtyrant
|
24 |
+
See also:
|
25 |
+
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
|
26 |
+
"""
|
27 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.num_features = num_features
|
31 |
+
self.eps = eps
|
32 |
+
self.momentum = momentum
|
33 |
+
self.weight = nn.Parameter(torch.empty(num_features))
|
34 |
+
self.bias = nn.Parameter(torch.empty(num_features))
|
35 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
36 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
37 |
+
self.reset_parameters()
|
38 |
+
|
39 |
+
def reset_running_stats(self):
|
40 |
+
self.running_mean.zero_()
|
41 |
+
self.running_var.fill_(1)
|
42 |
+
|
43 |
+
def reset_parameters(self):
|
44 |
+
self.reset_running_stats()
|
45 |
+
init.uniform_(self.weight)
|
46 |
+
init.zeros_(self.bias)
|
47 |
+
|
48 |
+
def forward(self, input_):
|
49 |
+
batchsize, channels, height, width = input_.size()
|
50 |
+
numel = batchsize * height * width
|
51 |
+
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
|
52 |
+
sum_ = input_.sum(1)
|
53 |
+
sum_of_square = input_.pow(2).sum(1)
|
54 |
+
mean = sum_ / numel
|
55 |
+
sumvar = sum_of_square - sum_ * mean
|
56 |
+
|
57 |
+
self.running_mean = (
|
58 |
+
(1 - self.momentum) * self.running_mean
|
59 |
+
+ self.momentum * mean.detach()
|
60 |
+
)
|
61 |
+
unbias_var = sumvar / (numel - 1)
|
62 |
+
self.running_var = (
|
63 |
+
(1 - self.momentum) * self.running_var
|
64 |
+
+ self.momentum * unbias_var.detach()
|
65 |
+
)
|
66 |
+
|
67 |
+
bias_var = sumvar / numel
|
68 |
+
inv_std = 1 / (bias_var + self.eps).pow(0.5)
|
69 |
+
output = (
|
70 |
+
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
|
71 |
+
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
|
72 |
+
|
73 |
+
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
|
74 |
+
|
Face_Enhancement/models/networks/sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
return {'master_callback': self._master_callback}
|
80 |
+
|
81 |
+
def __setstate__(self, state):
|
82 |
+
self.__init__(state['master_callback'])
|
83 |
+
|
84 |
+
def register_slave(self, identifier):
|
85 |
+
"""
|
86 |
+
Register an slave device.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
identifier: an identifier, usually is the device id.
|
90 |
+
|
91 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
92 |
+
|
93 |
+
"""
|
94 |
+
if self._activated:
|
95 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
96 |
+
self._activated = False
|
97 |
+
self._registry.clear()
|
98 |
+
future = FutureResult()
|
99 |
+
self._registry[identifier] = _MasterRegistry(future)
|
100 |
+
return SlavePipe(identifier, self._queue, future)
|
101 |
+
|
102 |
+
def run_master(self, master_msg):
|
103 |
+
"""
|
104 |
+
Main entry for the master device in each forward pass.
|
105 |
+
The messages were first collected from each devices (including the master device), and then
|
106 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
107 |
+
(including the master device).
|
108 |
+
|
109 |
+
Args:
|
110 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
111 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
112 |
+
|
113 |
+
Returns: the message to be sent back to the master device.
|
114 |
+
|
115 |
+
"""
|
116 |
+
self._activated = True
|
117 |
+
|
118 |
+
intermediates = [(0, master_msg)]
|
119 |
+
for i in range(self.nr_slaves):
|
120 |
+
intermediates.append(self._queue.get())
|
121 |
+
|
122 |
+
results = self._master_callback(intermediates)
|
123 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
124 |
+
|
125 |
+
for i, res in results:
|
126 |
+
if i == 0:
|
127 |
+
continue
|
128 |
+
self._registry[i].result.put(res)
|
129 |
+
|
130 |
+
for i in range(self.nr_slaves):
|
131 |
+
assert self._queue.get() is True
|
132 |
+
|
133 |
+
return results[0][1]
|
134 |
+
|
135 |
+
@property
|
136 |
+
def nr_slaves(self):
|
137 |
+
return len(self._registry)
|
Face_Enhancement/models/networks/sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
|
31 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
32 |
+
|
33 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
34 |
+
(shared among multiple copies of this module on different devices).
|
35 |
+
Through this context, different copies can share some information.
|
36 |
+
|
37 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
38 |
+
of any slave copies.
|
39 |
+
"""
|
40 |
+
master_copy = modules[0]
|
41 |
+
nr_modules = len(list(master_copy.modules()))
|
42 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
43 |
+
|
44 |
+
for i, module in enumerate(modules):
|
45 |
+
for j, m in enumerate(module.modules()):
|
46 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
47 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
48 |
+
|
49 |
+
|
50 |
+
class DataParallelWithCallback(DataParallel):
|
51 |
+
"""
|
52 |
+
Data Parallel with a replication callback.
|
53 |
+
|
54 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
55 |
+
original `replicate` function.
|
56 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
60 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
61 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def replicate(self, module, device_ids):
|
65 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
66 |
+
execute_replication_callbacks(modules)
|
67 |
+
return modules
|
68 |
+
|
69 |
+
|
70 |
+
def patch_replication_callback(data_parallel):
|
71 |
+
"""
|
72 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
73 |
+
Useful when you have customized `DataParallel` implementation.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
77 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
78 |
+
> patch_replication_callback(sync_bn)
|
79 |
+
# this is equivalent to
|
80 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert isinstance(data_parallel, DataParallel)
|
85 |
+
|
86 |
+
old_replicate = data_parallel.replicate
|
87 |
+
|
88 |
+
@functools.wraps(old_replicate)
|
89 |
+
def new_replicate(module, device_ids):
|
90 |
+
modules = old_replicate(module, device_ids)
|
91 |
+
execute_replication_callbacks(modules)
|
92 |
+
return modules
|
93 |
+
|
94 |
+
data_parallel.replicate = new_replicate
|
Face_Enhancement/models/networks/sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
class TorchTestCase(unittest.TestCase):
|
16 |
+
def assertTensorClose(self, x, y):
|
17 |
+
adiff = float((x - y).abs().max())
|
18 |
+
if (y == 0).all():
|
19 |
+
rdiff = 'NaN'
|
20 |
+
else:
|
21 |
+
rdiff = float((adiff / y).abs().max())
|
22 |
+
|
23 |
+
message = (
|
24 |
+
'Tensor close check failed\n'
|
25 |
+
'adiff={}\n'
|
26 |
+
'rdiff={}\n'
|
27 |
+
).format(adiff, rdiff)
|
28 |
+
self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
|
29 |
+
|
Face_Enhancement/models/pix2pix_model.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import models.networks as networks
|
6 |
+
import util.util as util
|
7 |
+
|
8 |
+
|
9 |
+
class Pix2PixModel(torch.nn.Module):
|
10 |
+
@staticmethod
|
11 |
+
def modify_commandline_options(parser, is_train):
|
12 |
+
networks.modify_commandline_options(parser, is_train)
|
13 |
+
return parser
|
14 |
+
|
15 |
+
def __init__(self, opt):
|
16 |
+
super().__init__()
|
17 |
+
self.opt = opt
|
18 |
+
self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
|
19 |
+
self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor
|
20 |
+
|
21 |
+
self.netG, self.netD, self.netE = self.initialize_networks(opt)
|
22 |
+
|
23 |
+
# set loss functions
|
24 |
+
if opt.isTrain:
|
25 |
+
self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
|
26 |
+
self.criterionFeat = torch.nn.L1Loss()
|
27 |
+
if not opt.no_vgg_loss:
|
28 |
+
self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
|
29 |
+
if opt.use_vae:
|
30 |
+
self.KLDLoss = networks.KLDLoss()
|
31 |
+
|
32 |
+
# Entry point for all calls involving forward pass
|
33 |
+
# of deep networks. We used this approach since DataParallel module
|
34 |
+
# can't parallelize custom functions, we branch to different
|
35 |
+
# routines based on |mode|.
|
36 |
+
def forward(self, data, mode):
|
37 |
+
input_semantics, real_image, degraded_image = self.preprocess_input(data)
|
38 |
+
|
39 |
+
if mode == "generator":
|
40 |
+
g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
|
41 |
+
return g_loss, generated
|
42 |
+
elif mode == "discriminator":
|
43 |
+
d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
|
44 |
+
return d_loss
|
45 |
+
elif mode == "encode_only":
|
46 |
+
z, mu, logvar = self.encode_z(real_image)
|
47 |
+
return mu, logvar
|
48 |
+
elif mode == "inference":
|
49 |
+
with torch.no_grad():
|
50 |
+
fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
|
51 |
+
return fake_image
|
52 |
+
else:
|
53 |
+
raise ValueError("|mode| is invalid")
|
54 |
+
|
55 |
+
def create_optimizers(self, opt):
|
56 |
+
G_params = list(self.netG.parameters())
|
57 |
+
if opt.use_vae:
|
58 |
+
G_params += list(self.netE.parameters())
|
59 |
+
if opt.isTrain:
|
60 |
+
D_params = list(self.netD.parameters())
|
61 |
+
|
62 |
+
beta1, beta2 = opt.beta1, opt.beta2
|
63 |
+
if opt.no_TTUR:
|
64 |
+
G_lr, D_lr = opt.lr, opt.lr
|
65 |
+
else:
|
66 |
+
G_lr, D_lr = opt.lr / 2, opt.lr * 2
|
67 |
+
|
68 |
+
optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
|
69 |
+
optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
|
70 |
+
|
71 |
+
return optimizer_G, optimizer_D
|
72 |
+
|
73 |
+
def save(self, epoch):
|
74 |
+
util.save_network(self.netG, "G", epoch, self.opt)
|
75 |
+
util.save_network(self.netD, "D", epoch, self.opt)
|
76 |
+
if self.opt.use_vae:
|
77 |
+
util.save_network(self.netE, "E", epoch, self.opt)
|
78 |
+
|
79 |
+
############################################################################
|
80 |
+
# Private helper methods
|
81 |
+
############################################################################
|
82 |
+
|
83 |
+
def initialize_networks(self, opt):
|
84 |
+
netG = networks.define_G(opt)
|
85 |
+
netD = networks.define_D(opt) if opt.isTrain else None
|
86 |
+
netE = networks.define_E(opt) if opt.use_vae else None
|
87 |
+
|
88 |
+
if not opt.isTrain or opt.continue_train:
|
89 |
+
netG = util.load_network(netG, "G", opt.which_epoch, opt)
|
90 |
+
if opt.isTrain:
|
91 |
+
netD = util.load_network(netD, "D", opt.which_epoch, opt)
|
92 |
+
if opt.use_vae:
|
93 |
+
netE = util.load_network(netE, "E", opt.which_epoch, opt)
|
94 |
+
|
95 |
+
return netG, netD, netE
|
96 |
+
|
97 |
+
# preprocess the input, such as moving the tensors to GPUs and
|
98 |
+
# transforming the label map to one-hot encoding
|
99 |
+
# |data|: dictionary of the input data
|
100 |
+
|
101 |
+
def preprocess_input(self, data):
|
102 |
+
# move to GPU and change data types
|
103 |
+
# data['label'] = data['label'].long()
|
104 |
+
|
105 |
+
if not self.opt.isTrain:
|
106 |
+
if self.use_gpu():
|
107 |
+
data["label"] = data["label"].cuda()
|
108 |
+
data["image"] = data["image"].cuda()
|
109 |
+
return data["label"], data["image"], data["image"]
|
110 |
+
|
111 |
+
## While testing, the input image is the degraded face
|
112 |
+
if self.use_gpu():
|
113 |
+
data["label"] = data["label"].cuda()
|
114 |
+
data["degraded_image"] = data["degraded_image"].cuda()
|
115 |
+
data["image"] = data["image"].cuda()
|
116 |
+
|
117 |
+
# # create one-hot label map
|
118 |
+
# label_map = data['label']
|
119 |
+
# bs, _, h, w = label_map.size()
|
120 |
+
# nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
|
121 |
+
# else self.opt.label_nc
|
122 |
+
# input_label = self.FloatTensor(bs, nc, h, w).zero_()
|
123 |
+
# input_semantics = input_label.scatter_(1, label_map, 1.0)
|
124 |
+
|
125 |
+
return data["label"], data["image"], data["degraded_image"]
|
126 |
+
|
127 |
+
def compute_generator_loss(self, input_semantics, degraded_image, real_image):
|
128 |
+
G_losses = {}
|
129 |
+
|
130 |
+
fake_image, KLD_loss = self.generate_fake(
|
131 |
+
input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
|
132 |
+
)
|
133 |
+
|
134 |
+
if self.opt.use_vae:
|
135 |
+
G_losses["KLD"] = KLD_loss
|
136 |
+
|
137 |
+
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
|
138 |
+
|
139 |
+
G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)
|
140 |
+
|
141 |
+
if not self.opt.no_ganFeat_loss:
|
142 |
+
num_D = len(pred_fake)
|
143 |
+
GAN_Feat_loss = self.FloatTensor(1).fill_(0)
|
144 |
+
for i in range(num_D): # for each discriminator
|
145 |
+
# last output is the final prediction, so we exclude it
|
146 |
+
num_intermediate_outputs = len(pred_fake[i]) - 1
|
147 |
+
for j in range(num_intermediate_outputs): # for each layer output
|
148 |
+
unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
|
149 |
+
GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
|
150 |
+
G_losses["GAN_Feat"] = GAN_Feat_loss
|
151 |
+
|
152 |
+
if not self.opt.no_vgg_loss:
|
153 |
+
G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg
|
154 |
+
|
155 |
+
return G_losses, fake_image
|
156 |
+
|
157 |
+
def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
|
158 |
+
D_losses = {}
|
159 |
+
with torch.no_grad():
|
160 |
+
fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
|
161 |
+
fake_image = fake_image.detach()
|
162 |
+
fake_image.requires_grad_()
|
163 |
+
|
164 |
+
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
|
165 |
+
|
166 |
+
D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
|
167 |
+
D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)
|
168 |
+
|
169 |
+
return D_losses
|
170 |
+
|
171 |
+
def encode_z(self, real_image):
|
172 |
+
mu, logvar = self.netE(real_image)
|
173 |
+
z = self.reparameterize(mu, logvar)
|
174 |
+
return z, mu, logvar
|
175 |
+
|
176 |
+
def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
|
177 |
+
z = None
|
178 |
+
KLD_loss = None
|
179 |
+
if self.opt.use_vae:
|
180 |
+
z, mu, logvar = self.encode_z(real_image)
|
181 |
+
if compute_kld_loss:
|
182 |
+
KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
|
183 |
+
|
184 |
+
fake_image = self.netG(input_semantics, degraded_image, z=z)
|
185 |
+
|
186 |
+
assert (
|
187 |
+
not compute_kld_loss
|
188 |
+
) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
|
189 |
+
|
190 |
+
return fake_image, KLD_loss
|
191 |
+
|
192 |
+
# Given fake and real image, return the prediction of discriminator
|
193 |
+
# for each fake and real image.
|
194 |
+
|
195 |
+
def discriminate(self, input_semantics, fake_image, real_image):
|
196 |
+
|
197 |
+
if self.opt.no_parsing_map:
|
198 |
+
fake_concat = fake_image
|
199 |
+
real_concat = real_image
|
200 |
+
else:
|
201 |
+
fake_concat = torch.cat([input_semantics, fake_image], dim=1)
|
202 |
+
real_concat = torch.cat([input_semantics, real_image], dim=1)
|
203 |
+
|
204 |
+
# In Batch Normalization, the fake and real images are
|
205 |
+
# recommended to be in the same batch to avoid disparate
|
206 |
+
# statistics in fake and real images.
|
207 |
+
# So both fake and real images are fed to D all at once.
|
208 |
+
fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
|
209 |
+
|
210 |
+
discriminator_out = self.netD(fake_and_real)
|
211 |
+
|
212 |
+
pred_fake, pred_real = self.divide_pred(discriminator_out)
|
213 |
+
|
214 |
+
return pred_fake, pred_real
|
215 |
+
|
216 |
+
# Take the prediction of fake and real images from the combined batch
|
217 |
+
def divide_pred(self, pred):
|
218 |
+
# the prediction contains the intermediate outputs of multiscale GAN,
|
219 |
+
# so it's usually a list
|
220 |
+
if type(pred) == list:
|
221 |
+
fake = []
|
222 |
+
real = []
|
223 |
+
for p in pred:
|
224 |
+
fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
|
225 |
+
real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
|
226 |
+
else:
|
227 |
+
fake = pred[: pred.size(0) // 2]
|
228 |
+
real = pred[pred.size(0) // 2 :]
|
229 |
+
|
230 |
+
return fake, real
|
231 |
+
|
232 |
+
def get_edges(self, t):
|
233 |
+
edge = self.ByteTensor(t.size()).zero_()
|
234 |
+
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
235 |
+
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
236 |
+
edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
237 |
+
edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
238 |
+
return edge.float()
|
239 |
+
|
240 |
+
def reparameterize(self, mu, logvar):
|
241 |
+
std = torch.exp(0.5 * logvar)
|
242 |
+
eps = torch.randn_like(std)
|
243 |
+
return eps.mul(std) + mu
|
244 |
+
|
245 |
+
def use_gpu(self):
|
246 |
+
return len(self.opt.gpu_ids) > 0
|
Face_Enhancement/options/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
Face_Enhancement/options/base_options.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
from util import util
|
8 |
+
import torch
|
9 |
+
import models
|
10 |
+
import data
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
|
14 |
+
class BaseOptions:
|
15 |
+
def __init__(self):
|
16 |
+
self.initialized = False
|
17 |
+
|
18 |
+
def initialize(self, parser):
|
19 |
+
# experiment specifics
|
20 |
+
parser.add_argument(
|
21 |
+
"--name",
|
22 |
+
type=str,
|
23 |
+
default="label2coco",
|
24 |
+
help="name of the experiment. It decides where to store samples and models",
|
25 |
+
)
|
26 |
+
|
27 |
+
parser.add_argument(
|
28 |
+
"--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU"
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here"
|
32 |
+
)
|
33 |
+
parser.add_argument("--model", type=str, default="pix2pix", help="which model to use")
|
34 |
+
parser.add_argument(
|
35 |
+
"--norm_G",
|
36 |
+
type=str,
|
37 |
+
default="spectralinstance",
|
38 |
+
help="instance normalization or batch normalization",
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--norm_D",
|
42 |
+
type=str,
|
43 |
+
default="spectralinstance",
|
44 |
+
help="instance normalization or batch normalization",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--norm_E",
|
48 |
+
type=str,
|
49 |
+
default="spectralinstance",
|
50 |
+
help="instance normalization or batch normalization",
|
51 |
+
)
|
52 |
+
parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc")
|
53 |
+
|
54 |
+
# input/output sizes
|
55 |
+
parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
|
56 |
+
parser.add_argument(
|
57 |
+
"--preprocess_mode",
|
58 |
+
type=str,
|
59 |
+
default="scale_width_and_crop",
|
60 |
+
help="scaling and cropping of images at load time.",
|
61 |
+
choices=(
|
62 |
+
"resize_and_crop",
|
63 |
+
"crop",
|
64 |
+
"scale_width",
|
65 |
+
"scale_width_and_crop",
|
66 |
+
"scale_shortside",
|
67 |
+
"scale_shortside_and_crop",
|
68 |
+
"fixed",
|
69 |
+
"none",
|
70 |
+
"resize",
|
71 |
+
),
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--load_size",
|
75 |
+
type=int,
|
76 |
+
default=1024,
|
77 |
+
help="Scale images to this size. The final image will be cropped to --crop_size.",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--crop_size",
|
81 |
+
type=int,
|
82 |
+
default=512,
|
83 |
+
help="Crop to the width of crop_size (after initially scaling the images to load_size.)",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--aspect_ratio",
|
87 |
+
type=float,
|
88 |
+
default=1.0,
|
89 |
+
help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--label_nc",
|
93 |
+
type=int,
|
94 |
+
default=182,
|
95 |
+
help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--contain_dontcare_label",
|
99 |
+
action="store_true",
|
100 |
+
help="if the label map contains dontcare label (dontcare=255)",
|
101 |
+
)
|
102 |
+
parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels")
|
103 |
+
|
104 |
+
# for setting inputs
|
105 |
+
parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
|
106 |
+
parser.add_argument("--dataset_mode", type=str, default="coco")
|
107 |
+
parser.add_argument(
|
108 |
+
"--serial_batches",
|
109 |
+
action="store_true",
|
110 |
+
help="if true, takes images in order to make batches, otherwise takes them randomly",
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--no_flip",
|
114 |
+
action="store_true",
|
115 |
+
help="if specified, do not flip the images for data argumentation",
|
116 |
+
)
|
117 |
+
parser.add_argument("--nThreads", default=0, type=int, help="# threads for loading data")
|
118 |
+
parser.add_argument(
|
119 |
+
"--max_dataset_size",
|
120 |
+
type=int,
|
121 |
+
default=sys.maxsize,
|
122 |
+
help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--load_from_opt_file",
|
126 |
+
action="store_true",
|
127 |
+
help="load the options from checkpoints and use that as default",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--cache_filelist_write",
|
131 |
+
action="store_true",
|
132 |
+
help="saves the current filelist into a text file, so that it loads faster",
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--cache_filelist_read", action="store_true", help="reads from the file list cache"
|
136 |
+
)
|
137 |
+
|
138 |
+
# for displays
|
139 |
+
parser.add_argument("--display_winsize", type=int, default=400, help="display window size")
|
140 |
+
|
141 |
+
# for generator
|
142 |
+
parser.add_argument(
|
143 |
+
"--netG", type=str, default="spade", help="selects model to use for netG (pix2pixhd | spade)"
|
144 |
+
)
|
145 |
+
parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer")
|
146 |
+
parser.add_argument(
|
147 |
+
"--init_type",
|
148 |
+
type=str,
|
149 |
+
default="xavier",
|
150 |
+
help="network initialization [normal|xavier|kaiming|orthogonal]",
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--init_variance", type=float, default=0.02, help="variance of the initialization distribution"
|
154 |
+
)
|
155 |
+
parser.add_argument("--z_dim", type=int, default=256, help="dimension of the latent z vector")
|
156 |
+
parser.add_argument(
|
157 |
+
"--no_parsing_map", action="store_true", help="During training, we do not use the parsing map"
|
158 |
+
)
|
159 |
+
|
160 |
+
# for instance-wise features
|
161 |
+
parser.add_argument(
|
162 |
+
"--no_instance", action="store_true", help="if specified, do *not* add instance map as input"
|
163 |
+
)
|
164 |
+
parser.add_argument(
|
165 |
+
"--nef", type=int, default=16, help="# of encoder filters in the first conv layer"
|
166 |
+
)
|
167 |
+
parser.add_argument("--use_vae", action="store_true", help="enable training with an image encoder.")
|
168 |
+
parser.add_argument(
|
169 |
+
"--tensorboard_log", action="store_true", help="use tensorboard to record the resutls"
|
170 |
+
)
|
171 |
+
|
172 |
+
# parser.add_argument('--img_dir',)
|
173 |
+
parser.add_argument(
|
174 |
+
"--old_face_folder", type=str, default="", help="The folder name of input old face"
|
175 |
+
)
|
176 |
+
parser.add_argument(
|
177 |
+
"--old_face_label_folder", type=str, default="", help="The folder name of input old face label"
|
178 |
+
)
|
179 |
+
|
180 |
+
parser.add_argument("--injection_layer", type=str, default="all", help="")
|
181 |
+
|
182 |
+
self.initialized = True
|
183 |
+
return parser
|
184 |
+
|
185 |
+
def gather_options(self):
|
186 |
+
# initialize parser with basic options
|
187 |
+
if not self.initialized:
|
188 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
189 |
+
parser = self.initialize(parser)
|
190 |
+
|
191 |
+
# get the basic options
|
192 |
+
opt, unknown = parser.parse_known_args()
|
193 |
+
|
194 |
+
# modify model-related parser options
|
195 |
+
model_name = opt.model
|
196 |
+
model_option_setter = models.get_option_setter(model_name)
|
197 |
+
parser = model_option_setter(parser, self.isTrain)
|
198 |
+
|
199 |
+
# modify dataset-related parser options
|
200 |
+
# dataset_mode = opt.dataset_mode
|
201 |
+
# dataset_option_setter = data.get_option_setter(dataset_mode)
|
202 |
+
# parser = dataset_option_setter(parser, self.isTrain)
|
203 |
+
|
204 |
+
opt, unknown = parser.parse_known_args()
|
205 |
+
|
206 |
+
# if there is opt_file, load it.
|
207 |
+
# The previous default options will be overwritten
|
208 |
+
if opt.load_from_opt_file:
|
209 |
+
parser = self.update_options_from_file(parser, opt)
|
210 |
+
|
211 |
+
opt = parser.parse_args()
|
212 |
+
self.parser = parser
|
213 |
+
return opt
|
214 |
+
|
215 |
+
def print_options(self, opt):
|
216 |
+
message = ""
|
217 |
+
message += "----------------- Options ---------------\n"
|
218 |
+
for k, v in sorted(vars(opt).items()):
|
219 |
+
comment = ""
|
220 |
+
default = self.parser.get_default(k)
|
221 |
+
if v != default:
|
222 |
+
comment = "\t[default: %s]" % str(default)
|
223 |
+
message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
|
224 |
+
message += "----------------- End -------------------"
|
225 |
+
# print(message)
|
226 |
+
|
227 |
+
def option_file_path(self, opt, makedir=False):
|
228 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
229 |
+
if makedir:
|
230 |
+
util.mkdirs(expr_dir)
|
231 |
+
file_name = os.path.join(expr_dir, "opt")
|
232 |
+
return file_name
|
233 |
+
|
234 |
+
def save_options(self, opt):
|
235 |
+
file_name = self.option_file_path(opt, makedir=True)
|
236 |
+
with open(file_name + ".txt", "wt") as opt_file:
|
237 |
+
for k, v in sorted(vars(opt).items()):
|
238 |
+
comment = ""
|
239 |
+
default = self.parser.get_default(k)
|
240 |
+
if v != default:
|
241 |
+
comment = "\t[default: %s]" % str(default)
|
242 |
+
opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment))
|
243 |
+
|
244 |
+
with open(file_name + ".pkl", "wb") as opt_file:
|
245 |
+
pickle.dump(opt, opt_file)
|
246 |
+
|
247 |
+
def update_options_from_file(self, parser, opt):
|
248 |
+
new_opt = self.load_options(opt)
|
249 |
+
for k, v in sorted(vars(opt).items()):
|
250 |
+
if hasattr(new_opt, k) and v != getattr(new_opt, k):
|
251 |
+
new_val = getattr(new_opt, k)
|
252 |
+
parser.set_defaults(**{k: new_val})
|
253 |
+
return parser
|
254 |
+
|
255 |
+
def load_options(self, opt):
|
256 |
+
file_name = self.option_file_path(opt, makedir=False)
|
257 |
+
new_opt = pickle.load(open(file_name + ".pkl", "rb"))
|
258 |
+
return new_opt
|
259 |
+
|
260 |
+
def parse(self, save=False):
|
261 |
+
|
262 |
+
opt = self.gather_options()
|
263 |
+
opt.isTrain = self.isTrain # train or test
|
264 |
+
opt.contain_dontcare_label = False
|
265 |
+
|
266 |
+
self.print_options(opt)
|
267 |
+
if opt.isTrain:
|
268 |
+
self.save_options(opt)
|
269 |
+
|
270 |
+
# Set semantic_nc based on the option.
|
271 |
+
# This will be convenient in many places
|
272 |
+
opt.semantic_nc = (
|
273 |
+
opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
|
274 |
+
)
|
275 |
+
|
276 |
+
# set gpu ids
|
277 |
+
str_ids = opt.gpu_ids.split(",")
|
278 |
+
opt.gpu_ids = []
|
279 |
+
for str_id in str_ids:
|
280 |
+
int_id = int(str_id)
|
281 |
+
if int_id >= 0:
|
282 |
+
opt.gpu_ids.append(int_id)
|
283 |
+
|
284 |
+
if len(opt.gpu_ids) > 0:
|
285 |
+
print("The main GPU is ")
|
286 |
+
print(opt.gpu_ids[0])
|
287 |
+
torch.cuda.set_device(opt.gpu_ids[0])
|
288 |
+
|
289 |
+
assert (
|
290 |
+
len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0
|
291 |
+
), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids))
|
292 |
+
|
293 |
+
self.opt = opt
|
294 |
+
return self.opt
|
Face_Enhancement/options/test_options.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from .base_options import BaseOptions
|
5 |
+
|
6 |
+
|
7 |
+
class TestOptions(BaseOptions):
|
8 |
+
def initialize(self, parser):
|
9 |
+
BaseOptions.initialize(self, parser)
|
10 |
+
parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
|
11 |
+
parser.add_argument(
|
12 |
+
"--which_epoch",
|
13 |
+
type=str,
|
14 |
+
default="latest",
|
15 |
+
help="which epoch to load? set to latest to use latest cached model",
|
16 |
+
)
|
17 |
+
parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")
|
18 |
+
|
19 |
+
parser.set_defaults(
|
20 |
+
preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256
|
21 |
+
)
|
22 |
+
parser.set_defaults(serial_batches=True)
|
23 |
+
parser.set_defaults(no_flip=True)
|
24 |
+
parser.set_defaults(phase="test")
|
25 |
+
self.isTrain = False
|
26 |
+
return parser
|
Face_Enhancement/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.0.0
|
2 |
+
torchvision
|
3 |
+
dominate>=2.3.1
|
4 |
+
wandb
|
5 |
+
dill
|
6 |
+
scikit-image
|
7 |
+
tensorboardX
|
8 |
+
scipy
|
9 |
+
opencv-python
|
Face_Enhancement/test_face.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import os
|
5 |
+
from collections import OrderedDict
|
6 |
+
|
7 |
+
import data
|
8 |
+
from options.test_options import TestOptions
|
9 |
+
from models.pix2pix_model import Pix2PixModel
|
10 |
+
from util.visualizer import Visualizer
|
11 |
+
import torchvision.utils as vutils
|
12 |
+
import warnings
|
13 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
14 |
+
|
15 |
+
opt = TestOptions().parse()
|
16 |
+
|
17 |
+
dataloader = data.create_dataloader(opt)
|
18 |
+
|
19 |
+
model = Pix2PixModel(opt)
|
20 |
+
model.eval()
|
21 |
+
|
22 |
+
visualizer = Visualizer(opt)
|
23 |
+
|
24 |
+
|
25 |
+
single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img")
|
26 |
+
|
27 |
+
|
28 |
+
if not os.path.exists(single_save_url):
|
29 |
+
os.makedirs(single_save_url)
|
30 |
+
|
31 |
+
|
32 |
+
for i, data_i in enumerate(dataloader):
|
33 |
+
if i * opt.batchSize >= opt.how_many:
|
34 |
+
break
|
35 |
+
|
36 |
+
generated = model(data_i, mode="inference")
|
37 |
+
|
38 |
+
img_path = data_i["path"]
|
39 |
+
|
40 |
+
for b in range(generated.shape[0]):
|
41 |
+
img_name = os.path.split(img_path[b])[-1]
|
42 |
+
save_img_url = os.path.join(single_save_url, img_name)
|
43 |
+
|
44 |
+
vutils.save_image((generated[b] + 1) / 2, save_img_url)
|
45 |
+
|
Face_Enhancement/util/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
Face_Enhancement/util/iter_counter.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
# Helper class that keeps track of training iterations
|
10 |
+
class IterationCounter:
|
11 |
+
def __init__(self, opt, dataset_size):
|
12 |
+
self.opt = opt
|
13 |
+
self.dataset_size = dataset_size
|
14 |
+
|
15 |
+
self.first_epoch = 1
|
16 |
+
self.total_epochs = opt.niter + opt.niter_decay
|
17 |
+
self.epoch_iter = 0 # iter number within each epoch
|
18 |
+
self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt")
|
19 |
+
if opt.isTrain and opt.continue_train:
|
20 |
+
try:
|
21 |
+
self.first_epoch, self.epoch_iter = np.loadtxt(
|
22 |
+
self.iter_record_path, delimiter=",", dtype=int
|
23 |
+
)
|
24 |
+
print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter))
|
25 |
+
except:
|
26 |
+
print(
|
27 |
+
"Could not load iteration record at %s. Starting from beginning." % self.iter_record_path
|
28 |
+
)
|
29 |
+
|
30 |
+
self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
|
31 |
+
|
32 |
+
# return the iterator of epochs for the training
|
33 |
+
def training_epochs(self):
|
34 |
+
return range(self.first_epoch, self.total_epochs + 1)
|
35 |
+
|
36 |
+
def record_epoch_start(self, epoch):
|
37 |
+
self.epoch_start_time = time.time()
|
38 |
+
self.epoch_iter = 0
|
39 |
+
self.last_iter_time = time.time()
|
40 |
+
self.current_epoch = epoch
|
41 |
+
|
42 |
+
def record_one_iteration(self):
|
43 |
+
current_time = time.time()
|
44 |
+
|
45 |
+
# the last remaining batch is dropped (see data/__init__.py),
|
46 |
+
# so we can assume batch size is always opt.batchSize
|
47 |
+
self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
|
48 |
+
self.last_iter_time = current_time
|
49 |
+
self.total_steps_so_far += self.opt.batchSize
|
50 |
+
self.epoch_iter += self.opt.batchSize
|
51 |
+
|
52 |
+
def record_epoch_end(self):
|
53 |
+
current_time = time.time()
|
54 |
+
self.time_per_epoch = current_time - self.epoch_start_time
|
55 |
+
print(
|
56 |
+
"End of epoch %d / %d \t Time Taken: %d sec"
|
57 |
+
% (self.current_epoch, self.total_epochs, self.time_per_epoch)
|
58 |
+
)
|
59 |
+
if self.current_epoch % self.opt.save_epoch_freq == 0:
|
60 |
+
np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d")
|
61 |
+
print("Saved current iteration count at %s." % self.iter_record_path)
|
62 |
+
|
63 |
+
def record_current_iter(self):
|
64 |
+
np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d")
|
65 |
+
print("Saved current iteration count at %s." % self.iter_record_path)
|
66 |
+
|
67 |
+
def needs_saving(self):
|
68 |
+
return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
|
69 |
+
|
70 |
+
def needs_printing(self):
|
71 |
+
return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
|
72 |
+
|
73 |
+
def needs_displaying(self):
|
74 |
+
return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
|
Face_Enhancement/util/util.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import re
|
5 |
+
import importlib
|
6 |
+
import torch
|
7 |
+
from argparse import Namespace
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
import os
|
11 |
+
import argparse
|
12 |
+
import dill as pickle
|
13 |
+
|
14 |
+
|
15 |
+
def save_obj(obj, name):
|
16 |
+
with open(name, "wb") as f:
|
17 |
+
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
|
18 |
+
|
19 |
+
|
20 |
+
def load_obj(name):
|
21 |
+
with open(name, "rb") as f:
|
22 |
+
return pickle.load(f)
|
23 |
+
|
24 |
+
|
25 |
+
def copyconf(default_opt, **kwargs):
|
26 |
+
conf = argparse.Namespace(**vars(default_opt))
|
27 |
+
for key in kwargs:
|
28 |
+
print(key, kwargs[key])
|
29 |
+
setattr(conf, key, kwargs[key])
|
30 |
+
return conf
|
31 |
+
|
32 |
+
|
33 |
+
# Converts a Tensor into a Numpy array
|
34 |
+
# |imtype|: the desired type of the converted numpy array
|
35 |
+
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
|
36 |
+
if isinstance(image_tensor, list):
|
37 |
+
image_numpy = []
|
38 |
+
for i in range(len(image_tensor)):
|
39 |
+
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
|
40 |
+
return image_numpy
|
41 |
+
|
42 |
+
if image_tensor.dim() == 4:
|
43 |
+
# transform each image in the batch
|
44 |
+
images_np = []
|
45 |
+
for b in range(image_tensor.size(0)):
|
46 |
+
one_image = image_tensor[b]
|
47 |
+
one_image_np = tensor2im(one_image)
|
48 |
+
images_np.append(one_image_np.reshape(1, *one_image_np.shape))
|
49 |
+
images_np = np.concatenate(images_np, axis=0)
|
50 |
+
|
51 |
+
return images_np
|
52 |
+
|
53 |
+
if image_tensor.dim() == 2:
|
54 |
+
image_tensor = image_tensor.unsqueeze(0)
|
55 |
+
image_numpy = image_tensor.detach().cpu().float().numpy()
|
56 |
+
if normalize:
|
57 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
|
58 |
+
else:
|
59 |
+
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
|
60 |
+
image_numpy = np.clip(image_numpy, 0, 255)
|
61 |
+
if image_numpy.shape[2] == 1:
|
62 |
+
image_numpy = image_numpy[:, :, 0]
|
63 |
+
return image_numpy.astype(imtype)
|
64 |
+
|
65 |
+
|
66 |
+
# Converts a one-hot tensor into a colorful label map
|
67 |
+
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
|
68 |
+
if label_tensor.dim() == 4:
|
69 |
+
# transform each image in the batch
|
70 |
+
images_np = []
|
71 |
+
for b in range(label_tensor.size(0)):
|
72 |
+
one_image = label_tensor[b]
|
73 |
+
one_image_np = tensor2label(one_image, n_label, imtype)
|
74 |
+
images_np.append(one_image_np.reshape(1, *one_image_np.shape))
|
75 |
+
images_np = np.concatenate(images_np, axis=0)
|
76 |
+
# if tile:
|
77 |
+
# images_tiled = tile_images(images_np)
|
78 |
+
# return images_tiled
|
79 |
+
# else:
|
80 |
+
# images_np = images_np[0]
|
81 |
+
# return images_np
|
82 |
+
return images_np
|
83 |
+
|
84 |
+
if label_tensor.dim() == 1:
|
85 |
+
return np.zeros((64, 64, 3), dtype=np.uint8)
|
86 |
+
if n_label == 0:
|
87 |
+
return tensor2im(label_tensor, imtype)
|
88 |
+
label_tensor = label_tensor.cpu().float()
|
89 |
+
if label_tensor.size()[0] > 1:
|
90 |
+
label_tensor = label_tensor.max(0, keepdim=True)[1]
|
91 |
+
label_tensor = Colorize(n_label)(label_tensor)
|
92 |
+
label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
|
93 |
+
result = label_numpy.astype(imtype)
|
94 |
+
return result
|
95 |
+
|
96 |
+
|
97 |
+
def save_image(image_numpy, image_path, create_dir=False):
|
98 |
+
if create_dir:
|
99 |
+
os.makedirs(os.path.dirname(image_path), exist_ok=True)
|
100 |
+
if len(image_numpy.shape) == 2:
|
101 |
+
image_numpy = np.expand_dims(image_numpy, axis=2)
|
102 |
+
if image_numpy.shape[2] == 1:
|
103 |
+
image_numpy = np.repeat(image_numpy, 3, 2)
|
104 |
+
image_pil = Image.fromarray(image_numpy)
|
105 |
+
|
106 |
+
# save to png
|
107 |
+
image_pil.save(image_path.replace(".jpg", ".png"))
|
108 |
+
|
109 |
+
|
110 |
+
def mkdirs(paths):
|
111 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
112 |
+
for path in paths:
|
113 |
+
mkdir(path)
|
114 |
+
else:
|
115 |
+
mkdir(paths)
|
116 |
+
|
117 |
+
|
118 |
+
def mkdir(path):
|
119 |
+
if not os.path.exists(path):
|
120 |
+
os.makedirs(path)
|
121 |
+
|
122 |
+
|
123 |
+
def atoi(text):
|
124 |
+
return int(text) if text.isdigit() else text
|
125 |
+
|
126 |
+
|
127 |
+
def natural_keys(text):
|
128 |
+
"""
|
129 |
+
alist.sort(key=natural_keys) sorts in human order
|
130 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
131 |
+
(See Toothy's implementation in the comments)
|
132 |
+
"""
|
133 |
+
return [atoi(c) for c in re.split("(\d+)", text)]
|
134 |
+
|
135 |
+
|
136 |
+
def natural_sort(items):
|
137 |
+
items.sort(key=natural_keys)
|
138 |
+
|
139 |
+
|
140 |
+
def str2bool(v):
|
141 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
142 |
+
return True
|
143 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
144 |
+
return False
|
145 |
+
else:
|
146 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
147 |
+
|
148 |
+
|
149 |
+
def find_class_in_module(target_cls_name, module):
|
150 |
+
target_cls_name = target_cls_name.replace("_", "").lower()
|
151 |
+
clslib = importlib.import_module(module)
|
152 |
+
cls = None
|
153 |
+
for name, clsobj in clslib.__dict__.items():
|
154 |
+
if name.lower() == target_cls_name:
|
155 |
+
cls = clsobj
|
156 |
+
|
157 |
+
if cls is None:
|
158 |
+
print(
|
159 |
+
"In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
|
160 |
+
% (module, target_cls_name)
|
161 |
+
)
|
162 |
+
exit(0)
|
163 |
+
|
164 |
+
return cls
|
165 |
+
|
166 |
+
|
167 |
+
def save_network(net, label, epoch, opt):
|
168 |
+
save_filename = "%s_net_%s.pth" % (epoch, label)
|
169 |
+
save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
|
170 |
+
torch.save(net.cpu().state_dict(), save_path)
|
171 |
+
if len(opt.gpu_ids) and torch.cuda.is_available():
|
172 |
+
net.cuda()
|
173 |
+
|
174 |
+
|
175 |
+
def load_network(net, label, epoch, opt):
|
176 |
+
save_filename = "%s_net_%s.pth" % (epoch, label)
|
177 |
+
save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
178 |
+
save_path = os.path.join(save_dir, save_filename)
|
179 |
+
if os.path.exists(save_path):
|
180 |
+
weights = torch.load(save_path)
|
181 |
+
net.load_state_dict(weights)
|
182 |
+
return net
|
183 |
+
|
184 |
+
|
185 |
+
###############################################################################
|
186 |
+
# Code from
|
187 |
+
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
|
188 |
+
# Modified so it complies with the Citscape label map colors
|
189 |
+
###############################################################################
|
190 |
+
def uint82bin(n, count=8):
|
191 |
+
"""returns the binary of integer n, count refers to amount of bits"""
|
192 |
+
return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
|
193 |
+
|
194 |
+
|
195 |
+
class Colorize(object):
|
196 |
+
def __init__(self, n=35):
|
197 |
+
self.cmap = labelcolormap(n)
|
198 |
+
self.cmap = torch.from_numpy(self.cmap[:n])
|
199 |
+
|
200 |
+
def __call__(self, gray_image):
|
201 |
+
size = gray_image.size()
|
202 |
+
color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
|
203 |
+
|
204 |
+
for label in range(0, len(self.cmap)):
|
205 |
+
mask = (label == gray_image[0]).cpu()
|
206 |
+
color_image[0][mask] = self.cmap[label][0]
|
207 |
+
color_image[1][mask] = self.cmap[label][1]
|
208 |
+
color_image[2][mask] = self.cmap[label][2]
|
209 |
+
|
210 |
+
return color_image
|
Face_Enhancement/util/visualizer.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import ntpath
|
6 |
+
import time
|
7 |
+
from . import util
|
8 |
+
import scipy.misc
|
9 |
+
|
10 |
+
try:
|
11 |
+
from StringIO import StringIO # Python 2.7
|
12 |
+
except ImportError:
|
13 |
+
from io import BytesIO # Python 3.x
|
14 |
+
import torchvision.utils as vutils
|
15 |
+
from tensorboardX import SummaryWriter
|
16 |
+
import torch
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
|
20 |
+
class Visualizer:
|
21 |
+
def __init__(self, opt):
|
22 |
+
self.opt = opt
|
23 |
+
self.tf_log = opt.isTrain and opt.tf_log
|
24 |
+
|
25 |
+
self.tensorboard_log = opt.tensorboard_log
|
26 |
+
|
27 |
+
self.win_size = opt.display_winsize
|
28 |
+
self.name = opt.name
|
29 |
+
if self.tensorboard_log:
|
30 |
+
|
31 |
+
if self.opt.isTrain:
|
32 |
+
self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
|
33 |
+
if not os.path.exists(self.log_dir):
|
34 |
+
os.makedirs(self.log_dir)
|
35 |
+
self.writer = SummaryWriter(log_dir=self.log_dir)
|
36 |
+
else:
|
37 |
+
print("hi :)")
|
38 |
+
self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
|
39 |
+
if not os.path.exists(self.log_dir):
|
40 |
+
os.makedirs(self.log_dir)
|
41 |
+
|
42 |
+
if opt.isTrain:
|
43 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
|
44 |
+
with open(self.log_name, "a") as log_file:
|
45 |
+
now = time.strftime("%c")
|
46 |
+
log_file.write("================ Training Loss (%s) ================\n" % now)
|
47 |
+
|
48 |
+
# |visuals|: dictionary of images to display or save
|
49 |
+
def display_current_results(self, visuals, epoch, step):
|
50 |
+
|
51 |
+
all_tensor = []
|
52 |
+
if self.tensorboard_log:
|
53 |
+
|
54 |
+
for key, tensor in visuals.items():
|
55 |
+
all_tensor.append((tensor.data.cpu() + 1) / 2)
|
56 |
+
|
57 |
+
output = torch.cat(all_tensor, 0)
|
58 |
+
img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
|
59 |
+
|
60 |
+
if self.opt.isTrain:
|
61 |
+
self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
|
62 |
+
else:
|
63 |
+
vutils.save_image(
|
64 |
+
output,
|
65 |
+
os.path.join(self.log_dir, str(step) + ".png"),
|
66 |
+
nrow=self.opt.batchSize,
|
67 |
+
padding=0,
|
68 |
+
normalize=False,
|
69 |
+
)
|
70 |
+
|
71 |
+
# errors: dictionary of error labels and values
|
72 |
+
def plot_current_errors(self, errors, step):
|
73 |
+
if self.tf_log:
|
74 |
+
for tag, value in errors.items():
|
75 |
+
value = value.mean().float()
|
76 |
+
summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
|
77 |
+
self.writer.add_summary(summary, step)
|
78 |
+
|
79 |
+
if self.tensorboard_log:
|
80 |
+
|
81 |
+
self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step)
|
82 |
+
self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step)
|
83 |
+
self.writer.add_scalars(
|
84 |
+
"Loss/GAN",
|
85 |
+
{
|
86 |
+
"G": errors["GAN"].mean().float(),
|
87 |
+
"D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2,
|
88 |
+
},
|
89 |
+
step,
|
90 |
+
)
|
91 |
+
|
92 |
+
# errors: same format as |errors| of plotCurrentErrors
|
93 |
+
def print_current_errors(self, epoch, i, errors, t):
|
94 |
+
message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
|
95 |
+
for k, v in errors.items():
|
96 |
+
v = v.mean().float()
|
97 |
+
message += "%s: %.3f " % (k, v)
|
98 |
+
|
99 |
+
print(message)
|
100 |
+
with open(self.log_name, "a") as log_file:
|
101 |
+
log_file.write("%s\n" % message)
|
102 |
+
|
103 |
+
def convert_visuals_to_numpy(self, visuals):
|
104 |
+
for key, t in visuals.items():
|
105 |
+
tile = self.opt.batchSize > 8
|
106 |
+
if "input_label" == key:
|
107 |
+
t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy
|
108 |
+
else:
|
109 |
+
t = util.tensor2im(t, tile=tile)
|
110 |
+
visuals[key] = t
|
111 |
+
return visuals
|
112 |
+
|
113 |
+
# save image to the disk
|
114 |
+
def save_images(self, webpage, visuals, image_path):
|
115 |
+
visuals = self.convert_visuals_to_numpy(visuals)
|
116 |
+
|
117 |
+
image_dir = webpage.get_image_dir()
|
118 |
+
short_path = ntpath.basename(image_path[0])
|
119 |
+
name = os.path.splitext(short_path)[0]
|
120 |
+
|
121 |
+
webpage.add_header(name)
|
122 |
+
ims = []
|
123 |
+
txts = []
|
124 |
+
links = []
|
125 |
+
|
126 |
+
for label, image_numpy in visuals.items():
|
127 |
+
image_name = os.path.join(label, "%s.png" % (name))
|
128 |
+
save_path = os.path.join(image_dir, image_name)
|
129 |
+
util.save_image(image_numpy, save_path, create_dir=True)
|
130 |
+
|
131 |
+
ims.append(image_name)
|
132 |
+
txts.append(label)
|
133 |
+
links.append(image_name)
|
134 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
GUI.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import PySimpleGUI as sg
|
4 |
+
import os.path
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import shutil
|
9 |
+
from subprocess import call
|
10 |
+
|
11 |
+
def modify(image_filename=None, cv2_frame=None):
|
12 |
+
|
13 |
+
def run_cmd(command):
|
14 |
+
try:
|
15 |
+
call(command, shell=True)
|
16 |
+
except KeyboardInterrupt:
|
17 |
+
print("Process interrupted")
|
18 |
+
sys.exit(1)
|
19 |
+
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--input_folder", type=str,
|
22 |
+
default= image_filename, help="Test images")
|
23 |
+
parser.add_argument(
|
24 |
+
"--output_folder",
|
25 |
+
type=str,
|
26 |
+
default="./output",
|
27 |
+
help="Restored images, please use the absolute path",
|
28 |
+
)
|
29 |
+
parser.add_argument("--GPU", type=str, default="-1", help="0,1,2")
|
30 |
+
parser.add_argument(
|
31 |
+
"--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint"
|
32 |
+
)
|
33 |
+
parser.add_argument("--with_scratch",default="--with_scratch" ,action="store_true")
|
34 |
+
opts = parser.parse_args()
|
35 |
+
|
36 |
+
gpu1 = opts.GPU
|
37 |
+
|
38 |
+
# resolve relative paths before changing directory
|
39 |
+
opts.input_folder = os.path.abspath(opts.input_folder)
|
40 |
+
opts.output_folder = os.path.abspath(opts.output_folder)
|
41 |
+
if not os.path.exists(opts.output_folder):
|
42 |
+
os.makedirs(opts.output_folder)
|
43 |
+
|
44 |
+
main_environment = os.getcwd()
|
45 |
+
|
46 |
+
# Stage 1: Overall Quality Improve
|
47 |
+
print("Running Stage 1: Overall restoration")
|
48 |
+
os.chdir("./Global")
|
49 |
+
stage_1_input_dir = opts.input_folder
|
50 |
+
stage_1_output_dir = os.path.join(
|
51 |
+
opts.output_folder, "stage_1_restore_output")
|
52 |
+
if not os.path.exists(stage_1_output_dir):
|
53 |
+
os.makedirs(stage_1_output_dir)
|
54 |
+
|
55 |
+
if not opts.with_scratch:
|
56 |
+
stage_1_command = (
|
57 |
+
"python test.py --test_mode Full --Quality_restore --test_input "
|
58 |
+
+ stage_1_input_dir
|
59 |
+
+ " --outputs_dir "
|
60 |
+
+ stage_1_output_dir
|
61 |
+
+ " --gpu_ids "
|
62 |
+
+ gpu1
|
63 |
+
)
|
64 |
+
run_cmd(stage_1_command)
|
65 |
+
else:
|
66 |
+
|
67 |
+
mask_dir = os.path.join(stage_1_output_dir, "masks")
|
68 |
+
new_input = os.path.join(mask_dir, "input")
|
69 |
+
new_mask = os.path.join(mask_dir, "mask")
|
70 |
+
stage_1_command_1 = (
|
71 |
+
"python detection.py --test_path "
|
72 |
+
+ stage_1_input_dir
|
73 |
+
+ " --output_dir "
|
74 |
+
+ mask_dir
|
75 |
+
+ " --input_size full_size"
|
76 |
+
+ " --GPU "
|
77 |
+
+ gpu1
|
78 |
+
)
|
79 |
+
stage_1_command_2 = (
|
80 |
+
"python test.py --Scratch_and_Quality_restore --test_input "
|
81 |
+
+ new_input
|
82 |
+
+ " --test_mask "
|
83 |
+
+ new_mask
|
84 |
+
+ " --outputs_dir "
|
85 |
+
+ stage_1_output_dir
|
86 |
+
+ " --gpu_ids "
|
87 |
+
+ gpu1
|
88 |
+
)
|
89 |
+
run_cmd(stage_1_command_1)
|
90 |
+
run_cmd(stage_1_command_2)
|
91 |
+
|
92 |
+
# Solve the case when there is no face in the old photo
|
93 |
+
stage_1_results = os.path.join(stage_1_output_dir, "restored_image")
|
94 |
+
stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
|
95 |
+
if not os.path.exists(stage_4_output_dir):
|
96 |
+
os.makedirs(stage_4_output_dir)
|
97 |
+
for x in os.listdir(stage_1_results):
|
98 |
+
img_dir = os.path.join(stage_1_results, x)
|
99 |
+
shutil.copy(img_dir, stage_4_output_dir)
|
100 |
+
|
101 |
+
print("Finish Stage 1 ...")
|
102 |
+
print("\n")
|
103 |
+
|
104 |
+
# Stage 2: Face Detection
|
105 |
+
|
106 |
+
print("Running Stage 2: Face Detection")
|
107 |
+
os.chdir(".././Face_Detection")
|
108 |
+
stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image")
|
109 |
+
stage_2_output_dir = os.path.join(
|
110 |
+
opts.output_folder, "stage_2_detection_output")
|
111 |
+
if not os.path.exists(stage_2_output_dir):
|
112 |
+
os.makedirs(stage_2_output_dir)
|
113 |
+
stage_2_command = (
|
114 |
+
"python detect_all_dlib.py --url " + stage_2_input_dir +
|
115 |
+
" --save_url " + stage_2_output_dir
|
116 |
+
)
|
117 |
+
run_cmd(stage_2_command)
|
118 |
+
print("Finish Stage 2 ...")
|
119 |
+
print("\n")
|
120 |
+
|
121 |
+
# Stage 3: Face Restore
|
122 |
+
print("Running Stage 3: Face Enhancement")
|
123 |
+
os.chdir(".././Face_Enhancement")
|
124 |
+
stage_3_input_mask = "./"
|
125 |
+
stage_3_input_face = stage_2_output_dir
|
126 |
+
stage_3_output_dir = os.path.join(
|
127 |
+
opts.output_folder, "stage_3_face_output")
|
128 |
+
if not os.path.exists(stage_3_output_dir):
|
129 |
+
os.makedirs(stage_3_output_dir)
|
130 |
+
stage_3_command = (
|
131 |
+
"python test_face.py --old_face_folder "
|
132 |
+
+ stage_3_input_face
|
133 |
+
+ " --old_face_label_folder "
|
134 |
+
+ stage_3_input_mask
|
135 |
+
+ " --tensorboard_log --name "
|
136 |
+
+ opts.checkpoint_name
|
137 |
+
+ " --gpu_ids "
|
138 |
+
+ gpu1
|
139 |
+
+ " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir "
|
140 |
+
+ stage_3_output_dir
|
141 |
+
+ " --no_parsing_map"
|
142 |
+
)
|
143 |
+
run_cmd(stage_3_command)
|
144 |
+
print("Finish Stage 3 ...")
|
145 |
+
print("\n")
|
146 |
+
|
147 |
+
# Stage 4: Warp back
|
148 |
+
print("Running Stage 4: Blending")
|
149 |
+
os.chdir(".././Face_Detection")
|
150 |
+
stage_4_input_image_dir = os.path.join(
|
151 |
+
stage_1_output_dir, "restored_image")
|
152 |
+
stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img")
|
153 |
+
stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
|
154 |
+
if not os.path.exists(stage_4_output_dir):
|
155 |
+
os.makedirs(stage_4_output_dir)
|
156 |
+
stage_4_command = (
|
157 |
+
"python align_warp_back_multiple_dlib.py --origin_url "
|
158 |
+
+ stage_4_input_image_dir
|
159 |
+
+ " --replace_url "
|
160 |
+
+ stage_4_input_face_dir
|
161 |
+
+ " --save_url "
|
162 |
+
+ stage_4_output_dir
|
163 |
+
)
|
164 |
+
run_cmd(stage_4_command)
|
165 |
+
print("Finish Stage 4 ...")
|
166 |
+
print("\n")
|
167 |
+
|
168 |
+
print("All the processing is done. Please check the results.")
|
169 |
+
|
170 |
+
# --------------------------------- The GUI ---------------------------------
|
171 |
+
|
172 |
+
# First the window layout...
|
173 |
+
|
174 |
+
images_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()],
|
175 |
+
[sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')],
|
176 |
+
[sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')],]
|
177 |
+
# ----- Full layout -----
|
178 |
+
layout = [[sg.VSeperator(), sg.Column(images_col)]]
|
179 |
+
|
180 |
+
# ----- Make the window -----
|
181 |
+
window = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True)
|
182 |
+
|
183 |
+
# ----- Run the Event Loop -----
|
184 |
+
prev_filename = colorized = cap = None
|
185 |
+
while True:
|
186 |
+
event, values = window.read()
|
187 |
+
if event in (None, 'Exit'):
|
188 |
+
break
|
189 |
+
|
190 |
+
elif event == '-MPHOTO-':
|
191 |
+
try:
|
192 |
+
n1 = filename.split("/")[-2]
|
193 |
+
n2 = filename.split("/")[-3]
|
194 |
+
n3 = filename.split("/")[-1]
|
195 |
+
filename= str(f"./{n2}/{n1}")
|
196 |
+
modify(filename)
|
197 |
+
|
198 |
+
global f_image
|
199 |
+
f_image = f'./output/final_output/{n3}'
|
200 |
+
image = cv2.imread(f_image)
|
201 |
+
window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes())
|
202 |
+
|
203 |
+
except:
|
204 |
+
continue
|
205 |
+
|
206 |
+
elif event == '-IN FILE-': # A single filename was chosen
|
207 |
+
filename = values['-IN FILE-']
|
208 |
+
if filename != prev_filename:
|
209 |
+
prev_filename = filename
|
210 |
+
try:
|
211 |
+
image = cv2.imread(filename)
|
212 |
+
window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes())
|
213 |
+
except:
|
214 |
+
continue
|
215 |
+
|
216 |
+
# ----- Exit program -----
|
217 |
+
window.close()
|
Global/data/Create_Bigfile.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import struct
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
IMG_EXTENSIONS = [
|
9 |
+
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
10 |
+
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
11 |
+
]
|
12 |
+
|
13 |
+
|
14 |
+
def is_image_file(filename):
|
15 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
16 |
+
|
17 |
+
|
18 |
+
def make_dataset(dir):
|
19 |
+
images = []
|
20 |
+
assert os.path.isdir(dir), '%s is not a valid directory' % dir
|
21 |
+
|
22 |
+
for root, _, fnames in sorted(os.walk(dir)):
|
23 |
+
for fname in fnames:
|
24 |
+
if is_image_file(fname):
|
25 |
+
#print(fname)
|
26 |
+
path = os.path.join(root, fname)
|
27 |
+
images.append(path)
|
28 |
+
|
29 |
+
return images
|
30 |
+
|
31 |
+
### Modify these 3 lines in your own environment
|
32 |
+
indir="/home/ziyuwan/workspace/data/temp_old"
|
33 |
+
target_folders=['VOC','Real_L_old','Real_RGB_old']
|
34 |
+
out_dir ="/home/ziyuwan/workspace/data/temp_old"
|
35 |
+
###
|
36 |
+
|
37 |
+
if os.path.exists(out_dir) is False:
|
38 |
+
os.makedirs(out_dir)
|
39 |
+
|
40 |
+
#
|
41 |
+
for target_folder in target_folders:
|
42 |
+
curr_indir = os.path.join(indir, target_folder)
|
43 |
+
curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder)))
|
44 |
+
image_lists = make_dataset(curr_indir)
|
45 |
+
image_lists.sort()
|
46 |
+
with open(curr_out_file, 'wb') as wfid:
|
47 |
+
# write total image number
|
48 |
+
wfid.write(struct.pack('i', len(image_lists)))
|
49 |
+
for i, img_path in enumerate(image_lists):
|
50 |
+
# write file name first
|
51 |
+
img_name = os.path.basename(img_path)
|
52 |
+
img_name_bytes = img_name.encode('utf-8')
|
53 |
+
wfid.write(struct.pack('i', len(img_name_bytes)))
|
54 |
+
wfid.write(img_name_bytes)
|
55 |
+
#
|
56 |
+
# # write image data in
|
57 |
+
with open(img_path, 'rb') as img_fid:
|
58 |
+
img_bytes = img_fid.read()
|
59 |
+
wfid.write(struct.pack('i', len(img_bytes)))
|
60 |
+
wfid.write(img_bytes)
|
61 |
+
|
62 |
+
if i % 1000 == 0:
|
63 |
+
print('write %d images done' % i)
|
Global/data/Load_Bigfile.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import io
|
5 |
+
import os
|
6 |
+
import struct
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
class BigFileMemoryLoader(object):
|
10 |
+
def __load_bigfile(self):
|
11 |
+
print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
|
12 |
+
with open(self.file_path, 'rb') as fid:
|
13 |
+
self.img_num = struct.unpack('i', fid.read(4))[0]
|
14 |
+
self.img_names = []
|
15 |
+
self.img_bytes = []
|
16 |
+
print('find total %d images' % self.img_num)
|
17 |
+
for i in range(self.img_num):
|
18 |
+
img_name_len = struct.unpack('i', fid.read(4))[0]
|
19 |
+
img_name = fid.read(img_name_len).decode('utf-8')
|
20 |
+
self.img_names.append(img_name)
|
21 |
+
img_bytes_len = struct.unpack('i', fid.read(4))[0]
|
22 |
+
self.img_bytes.append(fid.read(img_bytes_len))
|
23 |
+
if i % 5000 == 0:
|
24 |
+
print('load %d images done' % i)
|
25 |
+
print('load all %d images done' % self.img_num)
|
26 |
+
|
27 |
+
def __init__(self, file_path):
|
28 |
+
super(BigFileMemoryLoader, self).__init__()
|
29 |
+
self.file_path = file_path
|
30 |
+
self.__load_bigfile()
|
31 |
+
|
32 |
+
def __getitem__(self, index):
|
33 |
+
try:
|
34 |
+
img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
|
35 |
+
return self.img_names[index], img
|
36 |
+
except Exception:
|
37 |
+
print('Image read error for index %d: %s' % (index, self.img_names[index]))
|
38 |
+
return self.__getitem__((index+1)%self.img_num)
|
39 |
+
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
return self.img_num
|
Global/data/__init__.py
ADDED
File without changes
|