Spaces:
Configuration error
Configuration error
Upload 45 files
Browse files- .gitattributes +1 -0
- Docs/pics/img_mask_blur_21.jpg +0 -0
- Docs/pics/img_mask_blur_41.jpg +0 -0
- Docs/pics/img_mask_blur_61.jpg +0 -0
- Docs/pics/img_mask_erode_0.jpg +0 -0
- Docs/pics/img_mask_erode_20.jpg +0 -0
- Docs/pics/img_mask_erode_40.jpg +0 -0
- README.md +178 -12
- app.py +74 -0
- app_web.py +160 -0
- configs/run_image.yaml +35 -0
- configs/run_image_specific.yaml +35 -0
- configs/run_video.yaml +36 -0
- configs/run_video_specific.yaml +36 -0
- demo_file/Iron_man.jpg +0 -0
- demo_file/multi_people.jpg +0 -0
- demo_file/multi_people_1080p.mp4 +3 -0
- demo_file/multispecific/DST_01.jpg +0 -0
- demo_file/multispecific/DST_02.jpg +0 -0
- demo_file/multispecific/DST_03.jpg +0 -0
- demo_file/multispecific/SRC_01.png +0 -0
- demo_file/multispecific/SRC_02.png +0 -0
- demo_file/multispecific/SRC_03.png +0 -0
- demo_file/specific1.png +0 -0
- demo_file/specific2.png +0 -0
- demo_file/specific3.png +0 -0
- requirements.txt +9 -0
- src/Blend/blend.py +12 -0
- src/DataManager/ImageDataManager.py +42 -0
- src/DataManager/VideoDataManager.py +73 -0
- src/DataManager/base.py +16 -0
- src/DataManager/utils.py +12 -0
- src/FaceAlign/face_align.py +244 -0
- src/FaceDetector/face_detector.py +37 -0
- src/FaceId/faceid.py +50 -0
- src/Generator/fs_networks_512.py +277 -0
- src/Generator/fs_networks_fix.py +245 -0
- src/Misc/types.py +11 -0
- src/Misc/utils.py +28 -0
- src/PostProcess/GFPGAN/gfpgan.py +341 -0
- src/PostProcess/GFPGAN/stylegan2.py +351 -0
- src/PostProcess/ParsingModel/model.py +323 -0
- src/PostProcess/ParsingModel/resnet.py +109 -0
- src/PostProcess/utils.py +122 -0
- src/model_loader.py +106 -0
- src/simswap.py +322 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
simswap-inference-pytorch-main/demo_file/multi_people_1080p.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
simswap-inference-pytorch-main/demo_file/multi_people_1080p.mp4 filter=lfs diff=lfs merge=lfs -text
|
36 |
+
demo_file/multi_people_1080p.mp4 filter=lfs diff=lfs merge=lfs -text
|
Docs/pics/img_mask_blur_21.jpg
ADDED
Docs/pics/img_mask_blur_41.jpg
ADDED
Docs/pics/img_mask_blur_61.jpg
ADDED
Docs/pics/img_mask_erode_0.jpg
ADDED
Docs/pics/img_mask_erode_20.jpg
ADDED
Docs/pics/img_mask_erode_40.jpg
ADDED
README.md
CHANGED
@@ -1,12 +1,178 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Unofficial Pytorch implementation (**inference only**) of the SimSwap: An Efficient Framework For High Fidelity Face Swapping
|
2 |
+
|
3 |
+
## Updates
|
4 |
+
- improved performance (up to 40% in some scenarios, it depends on frame resolution and number of swaps per frame).
|
5 |
+
- fixed a problem with overlapped areas from close faces (https://github.com/mike9251/simswap-inference-pytorch/issues/21)
|
6 |
+
- added support for using GFPGAN model as an additional post-processing step to improve final image quality
|
7 |
+
- added a toy gui app. Might be useful to understand how different pipeline settings affect output
|
8 |
+
|
9 |
+
## Attention
|
10 |
+
***This project is for technical and academic use only. Please do not apply it to illegal and unethical scenarios.***
|
11 |
+
|
12 |
+
***In the event of violation of the legal and ethical requirements of the user's country or region, this code repository is exempt from liability.***
|
13 |
+
|
14 |
+
## Preparation
|
15 |
+
### Installation
|
16 |
+
```
|
17 |
+
# clone project
|
18 |
+
git clone https://github.com/mike9251/simswap-inference-pytorch
|
19 |
+
cd simswap-inference-pytorch
|
20 |
+
|
21 |
+
# [OPTIONAL] create conda environment
|
22 |
+
conda create -n myenv python=3.9
|
23 |
+
conda activate myenv
|
24 |
+
|
25 |
+
# install pytorch and torchvision according to instructions
|
26 |
+
# https://pytorch.org/get-started/
|
27 |
+
|
28 |
+
# install requirements
|
29 |
+
pip install -r requirements.txt
|
30 |
+
```
|
31 |
+
|
32 |
+
### Important
|
33 |
+
Face detection will be performed on CPU. To run it on GPU you need to install onnx gpu runtime:
|
34 |
+
|
35 |
+
```pip install onnxruntime-gpu==1.11.1```
|
36 |
+
|
37 |
+
and modify one line of code in ```...Anaconda3\envs\myenv\Lib\site-packages\insightface\model_zoo\model_zoo.py```
|
38 |
+
|
39 |
+
Here, instead of passing **None** as the second argument to the onnx inference session
|
40 |
+
```angular2html
|
41 |
+
class ModelRouter:
|
42 |
+
def __init__(self, onnx_file):
|
43 |
+
self.onnx_file = onnx_file
|
44 |
+
|
45 |
+
def get_model(self):
|
46 |
+
session = onnxruntime.InferenceSession(self.onnx_file, None)
|
47 |
+
input_cfg = session.get_inputs()[0]
|
48 |
+
```
|
49 |
+
pass a list of providers
|
50 |
+
```angular2html
|
51 |
+
class ModelRouter:
|
52 |
+
def __init__(self, onnx_file):
|
53 |
+
self.onnx_file = onnx_file
|
54 |
+
|
55 |
+
def get_model(self):
|
56 |
+
session = onnxruntime.InferenceSession(self.onnx_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
57 |
+
input_cfg = session.get_inputs()[0]
|
58 |
+
```
|
59 |
+
Otherwise simply use CPU onnx runtime with only a minor performance drop.
|
60 |
+
|
61 |
+
### Weights
|
62 |
+
#### Weights for all models get downloaded automatically.
|
63 |
+
|
64 |
+
You can also download weights manually and put inside `weights` folder:
|
65 |
+
|
66 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx">face_detector_scrfd_10g_bnkps.onnx</a>
|
67 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit">arcface_net.jit</a>
|
68 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth">79999_iter.pth</a>
|
69 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth">simswap_224_latest_net_G.pth</a> - official 224x224 model
|
70 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth">simswap_512_390000_net_G.pth</a> - unofficial 512x512 model (I took it <a href="https://github.com/neuralchen/SimSwap/issues/255">here</a>).
|
71 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth">GFPGANv1.4_ema.pth</a>
|
72 |
+
- weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit">blend_module.jit</a>
|
73 |
+
|
74 |
+
## Inference
|
75 |
+
### Web App
|
76 |
+
```angular2html
|
77 |
+
streamlit run app_web.py
|
78 |
+
```
|
79 |
+
|
80 |
+
### Command line App
|
81 |
+
This repository supports inference in several modes, which can be easily configured with config files in the **configs** folder.
|
82 |
+
- **replace all faces on a target image / folder with images**
|
83 |
+
```angular2html
|
84 |
+
python app.py --config-name=run_image.yaml
|
85 |
+
```
|
86 |
+
|
87 |
+
- **replace all faces on a video**
|
88 |
+
```angular2html
|
89 |
+
python app.py --config-name=run_video.yaml
|
90 |
+
```
|
91 |
+
|
92 |
+
- **replace a specific face on a target image / folder with images**
|
93 |
+
```angular2html
|
94 |
+
python app.py --config-name=run_image_specific.yaml
|
95 |
+
```
|
96 |
+
|
97 |
+
- **replace a specific face on a video**
|
98 |
+
```angular2html
|
99 |
+
python app.py --config-name=run_video_specific.yaml
|
100 |
+
```
|
101 |
+
|
102 |
+
Config files contain two main parts:
|
103 |
+
|
104 |
+
- **data**
|
105 |
+
- *id_image* - source image, identity of this person will be transferred.
|
106 |
+
- *att_image* - target image, attributes of the person on this image will be mixed with the person's identity from the source image. Here you can also specify a folder with multiple images - identity translation will be applied to all images in the folder.
|
107 |
+
- *specific_id_image* - a specific person on the *att_image* you would like to replace, leaving others untouched (if there's any other person).
|
108 |
+
- *att_video* - the same as *att_image*
|
109 |
+
- *clean_work_dir* - whether remove temp folder with images or not (for video configs only).
|
110 |
+
|
111 |
+
|
112 |
+
- **pipeline**
|
113 |
+
- *face_detector_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
|
114 |
+
- *face_id_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
|
115 |
+
- *parsing_model_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
|
116 |
+
- *simswap_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
|
117 |
+
- *gfpgan_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
|
118 |
+
- *device* - whether you want to run the application using GPU or CPU.
|
119 |
+
- *crop_size* - size of images SimSwap models works with.
|
120 |
+
- *checkpoint_type* - the official model works with 224x224 crops and has different pre/post processings (imagenet like). Latest official repository allows you to train your own models, but the architecture and pre/post processings are slightly different (1. removed Tanh from the last layer; 2. normalization to [0...1] range). **If you run the official 224x224 model then set this parameter to "official_224", otherwise "none".**
|
121 |
+
- *face_alignment_type* - affects reference face key points coordinates. **Possible values are "ffhq" and "none". Try both of them to see which one works better for your data.**
|
122 |
+
- *smooth_mask_kernel_size* - a non-zero value. It's used for the post-processing mask size attenuation. You might want to play with this parameter.
|
123 |
+
- *smooth_mask_iter* - a non-zero value. The number of times a face mask is smoothed.
|
124 |
+
- *smooth_mask_threshold* - controls the face mask saturation. Valid values are in range [0.0...1.0]. Tune this parameter if there are artifacts around swapped faces.
|
125 |
+
- *face_detector_threshold* - values in range [0.0...1.0]. Higher value reduces probability of FP detections but increases the probability of FN.
|
126 |
+
- *specific_latent_match_threshold* - values in range [0.0...inf]. Usually takes small values around 0.05.
|
127 |
+
- *enhance_output* - whether to apply GFPGAN model or not as a post-processing step.
|
128 |
+
|
129 |
+
|
130 |
+
### Overriding parameters with CMD
|
131 |
+
Every parameter in a config file can be overridden by specifying it directly with CMD. For example:
|
132 |
+
|
133 |
+
```angular2html
|
134 |
+
python app.py --config-name=run_image.yaml data.specific_id_image="path/to/the/image" pipeline.erosion_kernel_size=20
|
135 |
+
```
|
136 |
+
|
137 |
+
## Video
|
138 |
+
|
139 |
+
<details>
|
140 |
+
<summary><b>Official 224x224 model, face alignment "none"</b></summary>
|
141 |
+
|
142 |
+
[![Video](https://i.imgur.com/iCujdRB.jpg)](https://vimeo.com/728346715)
|
143 |
+
|
144 |
+
</details>
|
145 |
+
|
146 |
+
<details>
|
147 |
+
<summary><b>Official 224x224 model, face alignment "ffhq"</b></summary>
|
148 |
+
|
149 |
+
[![Video](https://i.imgur.com/48hjJO4.jpg)](https://vimeo.com/728348520)
|
150 |
+
|
151 |
+
</details>
|
152 |
+
|
153 |
+
<details>
|
154 |
+
<summary><b>Unofficial 512x512 model, face alignment "none"</b></summary>
|
155 |
+
|
156 |
+
[![Video](https://i.imgur.com/rRltD4U.jpg)](https://vimeo.com/728346542)
|
157 |
+
|
158 |
+
</details>
|
159 |
+
|
160 |
+
<details>
|
161 |
+
<summary><b>Unofficial 512x512 model, face alignment "ffhq"</b></summary>
|
162 |
+
|
163 |
+
[![Video](https://i.imgur.com/gFkpyXS.jpg)](https://vimeo.com/728349219)
|
164 |
+
|
165 |
+
</details>
|
166 |
+
|
167 |
+
## License
|
168 |
+
For academic and non-commercial use only.The whole project is under the CC-BY-NC 4.0 license. See [LICENSE](https://github.com/neuralchen/SimSwap/blob/main/LICENSE) for additional details.
|
169 |
+
|
170 |
+
## Acknowledgements
|
171 |
+
|
172 |
+
<!--ts-->
|
173 |
+
* [SimSwap](https://github.com/neuralchen/SimSwap)
|
174 |
+
* [Insightface](https://github.com/deepinsight/insightface)
|
175 |
+
* [Face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)
|
176 |
+
* [BiSeNet](https://github.com/CoinCheung/BiSeNet)
|
177 |
+
* [GFPGAN](https://github.com/TencentARC/GFPGAN)
|
178 |
+
<!--te-->
|
app.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Optional
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import hydra
|
6 |
+
from omegaconf import DictConfig
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from src.simswap import SimSwap
|
10 |
+
from src.DataManager.ImageDataManager import ImageDataManager
|
11 |
+
from src.DataManager.VideoDataManager import VideoDataManager
|
12 |
+
from src.DataManager.utils import imread_rgb
|
13 |
+
|
14 |
+
|
15 |
+
class Application:
|
16 |
+
def __init__(self, config: DictConfig):
|
17 |
+
|
18 |
+
id_image_path = Path(config.data.id_image)
|
19 |
+
specific_id_image_path = Path(config.data.specific_id_image)
|
20 |
+
att_image_path = Path(config.data.att_image)
|
21 |
+
att_video_path = Path(config.data.att_video)
|
22 |
+
output_dir = Path(config.data.output_dir)
|
23 |
+
|
24 |
+
assert id_image_path.exists(), f"Can't find {id_image_path} file!"
|
25 |
+
|
26 |
+
self.id_image: Optional[np.ndarray] = imread_rgb(id_image_path)
|
27 |
+
self.specific_id_image: Optional[np.ndarray] = (
|
28 |
+
imread_rgb(specific_id_image_path)
|
29 |
+
if specific_id_image_path and specific_id_image_path.is_file()
|
30 |
+
else None
|
31 |
+
)
|
32 |
+
|
33 |
+
self.att_image: Optional[ImageDataManager] = None
|
34 |
+
if att_image_path and (att_image_path.is_file() or att_image_path.is_dir()):
|
35 |
+
self.att_image: Optional[ImageDataManager] = ImageDataManager(
|
36 |
+
src_data=att_image_path, output_dir=output_dir
|
37 |
+
)
|
38 |
+
|
39 |
+
self.att_video: Optional[VideoDataManager] = None
|
40 |
+
if att_video_path and att_video_path.is_file():
|
41 |
+
self.att_video: Optional[VideoDataManager] = VideoDataManager(
|
42 |
+
src_data=att_video_path, output_dir=output_dir, clean_work_dir=config.data.clean_work_dir
|
43 |
+
)
|
44 |
+
|
45 |
+
assert not (self.att_video and self.att_image), "Only one attribute source can be used!"
|
46 |
+
|
47 |
+
self.data_manager = self.att_video if self.att_video else self.att_image
|
48 |
+
|
49 |
+
self.model = SimSwap(
|
50 |
+
config=config.pipeline,
|
51 |
+
id_image=self.id_image,
|
52 |
+
specific_image=self.specific_id_image,
|
53 |
+
)
|
54 |
+
|
55 |
+
def run(self):
|
56 |
+
for _ in tqdm(range(len(self.data_manager))):
|
57 |
+
|
58 |
+
att_img = self.data_manager.get()
|
59 |
+
|
60 |
+
output = self.model(att_img)
|
61 |
+
|
62 |
+
self.data_manager.save(output)
|
63 |
+
|
64 |
+
|
65 |
+
@hydra.main(config_path="configs/", config_name="run_image.yaml")
|
66 |
+
def main(config: DictConfig):
|
67 |
+
|
68 |
+
app = Application(config)
|
69 |
+
|
70 |
+
app.run()
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
main()
|
app_web.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
from io import BytesIO
|
4 |
+
from collections import namedtuple
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from src.simswap import SimSwap
|
8 |
+
|
9 |
+
|
10 |
+
def run(model):
|
11 |
+
id_image = None
|
12 |
+
attr_image = None
|
13 |
+
specific_image = None
|
14 |
+
output = None
|
15 |
+
|
16 |
+
def get_np_image(file):
|
17 |
+
return np.array(Image.open(file))[:, :, :3]
|
18 |
+
|
19 |
+
with st.sidebar:
|
20 |
+
uploaded_file = st.file_uploader("Select an ID image")
|
21 |
+
if uploaded_file is not None:
|
22 |
+
id_image = get_np_image(uploaded_file)
|
23 |
+
|
24 |
+
uploaded_file = st.file_uploader("Select an Attribute image")
|
25 |
+
if uploaded_file is not None:
|
26 |
+
attr_image = get_np_image(uploaded_file)
|
27 |
+
|
28 |
+
uploaded_file = st.file_uploader("Select a specific person image (Optional)")
|
29 |
+
if uploaded_file is not None:
|
30 |
+
specific_image = get_np_image(uploaded_file)
|
31 |
+
|
32 |
+
face_alignment_type = st.radio("Face alignment type:", ("none", "ffhq"))
|
33 |
+
|
34 |
+
enhance_output = st.radio("Enhance output:", ("yes", "no"))
|
35 |
+
|
36 |
+
smooth_mask_iter = st.slider(
|
37 |
+
label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
|
38 |
+
)
|
39 |
+
|
40 |
+
smooth_mask_kernel_size = st.slider(
|
41 |
+
label="smooth_mask_kernel_size", min_value=1, max_value=61, step=2, value=17
|
42 |
+
)
|
43 |
+
|
44 |
+
smooth_mask_threshold = st.slider(label="smooth_mask_threshold", min_value=0.01, max_value=1.0, step=0.01, value=0.9)
|
45 |
+
|
46 |
+
specific_latent_match_threshold = st.slider(
|
47 |
+
label="specific_latent_match_threshold",
|
48 |
+
min_value=0.0,
|
49 |
+
max_value=10.0,
|
50 |
+
value=0.05,
|
51 |
+
)
|
52 |
+
|
53 |
+
num_cols = sum(
|
54 |
+
(id_image is not None, attr_image is not None, specific_image is not None)
|
55 |
+
)
|
56 |
+
cols = st.columns(num_cols if num_cols > 0 else 1)
|
57 |
+
i = 0
|
58 |
+
|
59 |
+
if id_image is not None:
|
60 |
+
with cols[i]:
|
61 |
+
i += 1
|
62 |
+
st.header("ID image")
|
63 |
+
st.image(id_image)
|
64 |
+
|
65 |
+
if attr_image is not None:
|
66 |
+
with cols[i]:
|
67 |
+
i += 1
|
68 |
+
st.header("Attribute image")
|
69 |
+
st.image(attr_image)
|
70 |
+
|
71 |
+
if specific_image is not None:
|
72 |
+
with cols[i]:
|
73 |
+
st.header("Specific image")
|
74 |
+
st.image(specific_image)
|
75 |
+
|
76 |
+
if id_image is not None and attr_image is not None:
|
77 |
+
model.set_face_alignment_type(face_alignment_type)
|
78 |
+
model.set_smooth_mask_iter(smooth_mask_iter)
|
79 |
+
model.set_smooth_mask_kernel_size(smooth_mask_kernel_size)
|
80 |
+
model.set_smooth_mask_threshold(smooth_mask_threshold)
|
81 |
+
model.set_specific_latent_match_threshold(specific_latent_match_threshold)
|
82 |
+
model.enhance_output = True if enhance_output == "yes" else False
|
83 |
+
|
84 |
+
model.specific_latent = None
|
85 |
+
model.specific_id_image = specific_image if specific_image is not None else None
|
86 |
+
|
87 |
+
model.id_latent = None
|
88 |
+
model.id_image = id_image
|
89 |
+
|
90 |
+
output = model(attr_image)
|
91 |
+
|
92 |
+
if output is not None:
|
93 |
+
with st.container():
|
94 |
+
st.header("SimSwap output")
|
95 |
+
st.image(output)
|
96 |
+
|
97 |
+
output_to_download = Image.fromarray(output.astype("uint8"), "RGB")
|
98 |
+
buf = BytesIO()
|
99 |
+
output_to_download.save(buf, format="JPEG")
|
100 |
+
|
101 |
+
st.download_button(
|
102 |
+
label="Download",
|
103 |
+
data=buf.getvalue(),
|
104 |
+
file_name="output.jpg",
|
105 |
+
mime="image/jpeg",
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
@st.cache(allow_output_mutation=True)
|
110 |
+
def load_model(config):
|
111 |
+
return SimSwap(
|
112 |
+
config=config,
|
113 |
+
id_image=None,
|
114 |
+
specific_image=None,
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
# TODO: remove it and use config files from 'configs'
|
119 |
+
Config = namedtuple(
|
120 |
+
"Config",
|
121 |
+
"face_detector_weights"
|
122 |
+
+ " face_id_weights"
|
123 |
+
+ " parsing_model_weights"
|
124 |
+
+ " simswap_weights"
|
125 |
+
+ " gfpgan_weights"
|
126 |
+
+ " blend_module_weights"
|
127 |
+
+ " device"
|
128 |
+
+ " crop_size"
|
129 |
+
+ " checkpoint_type"
|
130 |
+
+ " face_alignment_type"
|
131 |
+
+ " smooth_mask_iter"
|
132 |
+
+ " smooth_mask_kernel_size"
|
133 |
+
+ " smooth_mask_threshold"
|
134 |
+
+ " face_detector_threshold"
|
135 |
+
+ " specific_latent_match_threshold"
|
136 |
+
+ " enhance_output",
|
137 |
+
)
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
config = Config(
|
141 |
+
face_detector_weights="weights/scrfd_10g_bnkps.onnx",
|
142 |
+
face_id_weights="weights/arcface_net.jit",
|
143 |
+
parsing_model_weights="weights/79999_iter.pth",
|
144 |
+
simswap_weights="weights/latest_net_G.pth",
|
145 |
+
gfpgan_weights="weights/GFPGANv1.4_ema.pth",
|
146 |
+
blend_module_weights="weights/blend.jit",
|
147 |
+
device="cuda",
|
148 |
+
crop_size=224,
|
149 |
+
checkpoint_type="official_224",
|
150 |
+
face_alignment_type="none",
|
151 |
+
smooth_mask_iter=7,
|
152 |
+
smooth_mask_kernel_size=17,
|
153 |
+
smooth_mask_threshold=0.9,
|
154 |
+
face_detector_threshold=0.6,
|
155 |
+
specific_latent_match_threshold=0.05,
|
156 |
+
enhance_output=True
|
157 |
+
)
|
158 |
+
|
159 |
+
model = load_model(config)
|
160 |
+
run(model)
|
configs/run_image.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
|
3 |
+
att_image: "${hydra:runtime.cwd}/demo_file/multi_people.jpg"
|
4 |
+
specific_id_image: "none"
|
5 |
+
att_video: "none"
|
6 |
+
output_dir: ${hydra:runtime.cwd}/output
|
7 |
+
|
8 |
+
pipeline:
|
9 |
+
face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
|
10 |
+
face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
|
11 |
+
parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
|
12 |
+
simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
|
13 |
+
gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
|
14 |
+
blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
|
15 |
+
device: "cuda"
|
16 |
+
crop_size: 224
|
17 |
+
# it seems that the official 224 checkpoint works better with 'none' face alignment type
|
18 |
+
checkpoint_type: "official_224" #"none"
|
19 |
+
face_alignment_type: "none" #"ffhq"
|
20 |
+
smooth_mask_iter: 7
|
21 |
+
smooth_mask_kernel_size: 17
|
22 |
+
smooth_mask_threshold: 0.9
|
23 |
+
face_detector_threshold: 0.6
|
24 |
+
specific_latent_match_threshold: 0.05
|
25 |
+
enhance_output: True
|
26 |
+
|
27 |
+
defaults:
|
28 |
+
- _self_
|
29 |
+
- override hydra/hydra_logging: disabled
|
30 |
+
- override hydra/job_logging: disabled
|
31 |
+
|
32 |
+
hydra:
|
33 |
+
output_subdir: null
|
34 |
+
run:
|
35 |
+
dir: .
|
configs/run_image_specific.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
|
3 |
+
att_image: "${hydra:runtime.cwd}/demo_file/multi_people.jpg"
|
4 |
+
specific_id_image: "${hydra:runtime.cwd}/demo_file/specific1.png"
|
5 |
+
att_video: "none"
|
6 |
+
output_dir: ${hydra:runtime.cwd}/output
|
7 |
+
|
8 |
+
pipeline:
|
9 |
+
face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
|
10 |
+
face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
|
11 |
+
parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
|
12 |
+
simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
|
13 |
+
gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
|
14 |
+
blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
|
15 |
+
device: "cuda"
|
16 |
+
crop_size: 224
|
17 |
+
# it seems that the official 224 checkpoint works better with 'none' face alignment type
|
18 |
+
checkpoint_type: "official_224" #"none"
|
19 |
+
face_alignment_type: "none" #"ffhq"
|
20 |
+
smooth_mask_iter: 7
|
21 |
+
smooth_mask_kernel_size: 17
|
22 |
+
smooth_mask_threshold: 0.9
|
23 |
+
face_detector_threshold: 0.6
|
24 |
+
specific_latent_match_threshold: 0.05
|
25 |
+
enhance_output: True
|
26 |
+
|
27 |
+
defaults:
|
28 |
+
- _self_
|
29 |
+
- override hydra/hydra_logging: disabled
|
30 |
+
- override hydra/job_logging: disabled
|
31 |
+
|
32 |
+
hydra:
|
33 |
+
output_subdir: null
|
34 |
+
run:
|
35 |
+
dir: .
|
configs/run_video.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
|
3 |
+
att_image: "none"
|
4 |
+
specific_id_image: "none"
|
5 |
+
att_video: "${hydra:runtime.cwd}/demo_file/multi_people_1080p.mp4"
|
6 |
+
output_dir: ${hydra:runtime.cwd}/output
|
7 |
+
clean_work_dir: True
|
8 |
+
|
9 |
+
pipeline:
|
10 |
+
face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
|
11 |
+
face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
|
12 |
+
parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
|
13 |
+
simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
|
14 |
+
gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
|
15 |
+
blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
|
16 |
+
device: "cuda"
|
17 |
+
crop_size: 224
|
18 |
+
# it seems that the official 224 checkpoint works better with 'none' face alignment type
|
19 |
+
checkpoint_type: "official_224" #"none"
|
20 |
+
face_alignment_type: "none" #"ffhq"
|
21 |
+
smooth_mask_iter: 7
|
22 |
+
smooth_mask_kernel_size: 17
|
23 |
+
smooth_mask_threshold: 0.9
|
24 |
+
face_detector_threshold: 0.6
|
25 |
+
specific_latent_match_threshold: 0.05
|
26 |
+
enhance_output: True
|
27 |
+
|
28 |
+
defaults:
|
29 |
+
- _self_
|
30 |
+
- override hydra/hydra_logging: disabled
|
31 |
+
- override hydra/job_logging: disabled
|
32 |
+
|
33 |
+
hydra:
|
34 |
+
output_subdir: null
|
35 |
+
run:
|
36 |
+
dir: .
|
configs/run_video_specific.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
|
3 |
+
att_image: "none"
|
4 |
+
specific_id_image: "${hydra:runtime.cwd}/demo_file/specific1.png"
|
5 |
+
att_video: "${hydra:runtime.cwd}/demo_file/multi_people_1080p.mp4"
|
6 |
+
output_dir: ${hydra:runtime.cwd}/output
|
7 |
+
clean_work_dir: True
|
8 |
+
|
9 |
+
pipeline:
|
10 |
+
face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
|
11 |
+
face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
|
12 |
+
parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
|
13 |
+
simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
|
14 |
+
gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
|
15 |
+
blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
|
16 |
+
device: "cuda"
|
17 |
+
crop_size: 224
|
18 |
+
# it seems that the official 224 checkpoint works better with 'none' face alignment type
|
19 |
+
checkpoint_type: "official_224" #"none"
|
20 |
+
face_alignment_type: "none" #"ffhq"
|
21 |
+
smooth_mask_iter: 7
|
22 |
+
smooth_mask_kernel_size: 17
|
23 |
+
smooth_mask_threshold: 0.9
|
24 |
+
face_detector_threshold: 0.6
|
25 |
+
specific_latent_match_threshold: 0.05
|
26 |
+
enhance_output: True
|
27 |
+
|
28 |
+
defaults:
|
29 |
+
- _self_
|
30 |
+
- override hydra/hydra_logging: disabled
|
31 |
+
- override hydra/job_logging: disabled
|
32 |
+
|
33 |
+
hydra:
|
34 |
+
output_subdir: null
|
35 |
+
run:
|
36 |
+
dir: .
|
demo_file/Iron_man.jpg
ADDED
demo_file/multi_people.jpg
ADDED
demo_file/multi_people_1080p.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:97fe960cc03abac34509ec69a68c7b75f2ca1325aea353456411fe7569d978e1
|
3 |
+
size 8735410
|
demo_file/multispecific/DST_01.jpg
ADDED
demo_file/multispecific/DST_02.jpg
ADDED
demo_file/multispecific/DST_03.jpg
ADDED
demo_file/multispecific/SRC_01.png
ADDED
demo_file/multispecific/SRC_02.png
ADDED
demo_file/multispecific/SRC_03.png
ADDED
demo_file/specific1.png
ADDED
demo_file/specific2.png
ADDED
demo_file/specific3.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hydra-core>=1.1.0
|
2 |
+
insightface==0.2.1
|
3 |
+
kornia==0.6.5
|
4 |
+
moviepy==1.0.3
|
5 |
+
onnx==1.12.0
|
6 |
+
onnxruntime==1.11.1
|
7 |
+
opencv-python==4.6.0.66
|
8 |
+
tqdm==4.64.0
|
9 |
+
streamlit==1.14.0
|
src/Blend/blend.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class BlendModule(nn.Module):
|
6 |
+
def __init__(self, model_path, device):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
self.model = torch.jit.load(model_path).to(device)
|
10 |
+
|
11 |
+
def forward(self, swap, mask, att_img):
|
12 |
+
return self.model(swap, mask, att_img)
|
src/DataManager/ImageDataManager.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.DataManager.base import BaseDataManager
|
2 |
+
from src.DataManager.utils import imread_rgb, imwrite_rgb
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
|
8 |
+
class ImageDataManager(BaseDataManager):
|
9 |
+
def __init__(self, src_data: Path, output_dir: Path):
|
10 |
+
self.output_dir: Path = output_dir
|
11 |
+
self.output_dir.mkdir(exist_ok=True)
|
12 |
+
self.output_dir = output_dir / "img"
|
13 |
+
self.output_dir.mkdir(exist_ok=True)
|
14 |
+
|
15 |
+
self.data_paths = []
|
16 |
+
if src_data.is_file():
|
17 |
+
self.data_paths.append(src_data)
|
18 |
+
elif src_data.is_dir():
|
19 |
+
self.data_paths = (
|
20 |
+
list(src_data.glob("*.jpg"))
|
21 |
+
+ list(src_data.glob("*.jpeg"))
|
22 |
+
+ list(src_data.glob("*.png"))
|
23 |
+
)
|
24 |
+
|
25 |
+
assert len(self.data_paths), "Data must be supplied!"
|
26 |
+
|
27 |
+
self.data_paths_iter = iter(self.data_paths)
|
28 |
+
|
29 |
+
self.last_idx = -1
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.data_paths)
|
33 |
+
|
34 |
+
def get(self) -> np.ndarray:
|
35 |
+
img_path = next(self.data_paths_iter)
|
36 |
+
self.last_idx += 1
|
37 |
+
return imread_rgb(img_path)
|
38 |
+
|
39 |
+
def save(self, img: np.ndarray):
|
40 |
+
filename = "swap_" + Path(self.data_paths[self.last_idx]).name
|
41 |
+
|
42 |
+
imwrite_rgb(self.output_dir / filename, img)
|
src/DataManager/VideoDataManager.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.DataManager.base import BaseDataManager
|
2 |
+
from src.DataManager.utils import imwrite_rgb
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
import shutil
|
8 |
+
from typing import Optional, Union
|
9 |
+
|
10 |
+
from moviepy.editor import AudioFileClip, VideoFileClip
|
11 |
+
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
12 |
+
|
13 |
+
|
14 |
+
class VideoDataManager(BaseDataManager):
|
15 |
+
def __init__(self, src_data: Path, output_dir: Path, clean_work_dir: bool = False):
|
16 |
+
self.video_handle: Optional[cv2.VideoCapture] = None
|
17 |
+
self.audio_handle: Optional[AudioFileClip] = None
|
18 |
+
|
19 |
+
self.output_dir = output_dir
|
20 |
+
self.output_img_dir = output_dir / "img"
|
21 |
+
self.output_dir.mkdir(exist_ok=True)
|
22 |
+
self.output_img_dir.mkdir(exist_ok=True)
|
23 |
+
self.video_name = None
|
24 |
+
self.clean_work_dir = clean_work_dir
|
25 |
+
|
26 |
+
if src_data.is_file():
|
27 |
+
self.video_name = "swap_" + src_data.name
|
28 |
+
|
29 |
+
if VideoFileClip(str(src_data)).audio is not None:
|
30 |
+
self.audio_handle = AudioFileClip(str(src_data))
|
31 |
+
|
32 |
+
self.video_handle = cv2.VideoCapture(str(src_data))
|
33 |
+
self.video_handle.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
34 |
+
|
35 |
+
self.frame_count = int(self.video_handle.get(cv2.CAP_PROP_FRAME_COUNT))
|
36 |
+
self.fps = self.video_handle.get(cv2.CAP_PROP_FPS)
|
37 |
+
|
38 |
+
self.last_idx = -1
|
39 |
+
|
40 |
+
assert self.video_handle, "Video file must be specified!"
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return self.frame_count
|
44 |
+
|
45 |
+
def get(self) -> np.ndarray:
|
46 |
+
img: Union[None, np.ndarray] = None
|
47 |
+
|
48 |
+
while img is None and self.last_idx < self.frame_count:
|
49 |
+
status, img = self.video_handle.read()
|
50 |
+
self.last_idx += 1
|
51 |
+
|
52 |
+
if img is not None:
|
53 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
54 |
+
return img
|
55 |
+
|
56 |
+
def save(self, img: np.ndarray):
|
57 |
+
filename = "frame_{:0>7d}.jpg".format(self.last_idx)
|
58 |
+
imwrite_rgb(self.output_img_dir / filename, img)
|
59 |
+
|
60 |
+
if (self.frame_count - 1) == self.last_idx:
|
61 |
+
self._close()
|
62 |
+
|
63 |
+
def _close(self):
|
64 |
+
image_filenames = [str(x) for x in sorted(self.output_img_dir.glob("*.jpg"))]
|
65 |
+
clip = ImageSequenceClip(image_filenames, fps=self.fps)
|
66 |
+
|
67 |
+
if self.audio_handle is not None:
|
68 |
+
clip = clip.set_audio(self.audio_handle)
|
69 |
+
|
70 |
+
clip.write_videofile(str(self.output_dir / self.video_name))
|
71 |
+
|
72 |
+
if self.clean_work_dir:
|
73 |
+
shutil.rmtree(self.output_img_dir, ignore_errors=True)
|
src/DataManager/base.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class BaseDataManager(ABC):
|
6 |
+
@abstractmethod
|
7 |
+
def __len__(self) -> int:
|
8 |
+
pass
|
9 |
+
|
10 |
+
@abstractmethod
|
11 |
+
def get(self) -> np.ndarray:
|
12 |
+
pass
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def save(self, img: np.ndarray) -> None:
|
16 |
+
pass
|
src/DataManager/utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
|
7 |
+
def imread_rgb(img_path: Union[str, Path]) -> np.ndarray:
|
8 |
+
return cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
|
9 |
+
|
10 |
+
|
11 |
+
def imwrite_rgb(img_path: Union[str, Path], img):
|
12 |
+
return cv2.imwrite(str(img_path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
src/FaceAlign/face_align.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from skimage import transform as skt
|
5 |
+
from typing import Iterable, Tuple
|
6 |
+
|
7 |
+
src1 = np.array(
|
8 |
+
[
|
9 |
+
[51.642, 50.115],
|
10 |
+
[57.617, 49.990],
|
11 |
+
[35.740, 69.007],
|
12 |
+
[51.157, 89.050],
|
13 |
+
[57.025, 89.702],
|
14 |
+
],
|
15 |
+
dtype=np.float32,
|
16 |
+
)
|
17 |
+
# <--left
|
18 |
+
src2 = np.array(
|
19 |
+
[
|
20 |
+
[45.031, 50.118],
|
21 |
+
[65.568, 50.872],
|
22 |
+
[39.677, 68.111],
|
23 |
+
[45.177, 86.190],
|
24 |
+
[64.246, 86.758],
|
25 |
+
],
|
26 |
+
dtype=np.float32,
|
27 |
+
)
|
28 |
+
|
29 |
+
# ---frontal
|
30 |
+
src3 = np.array(
|
31 |
+
[
|
32 |
+
[39.730, 51.138],
|
33 |
+
[72.270, 51.138],
|
34 |
+
[56.000, 68.493],
|
35 |
+
[42.463, 87.010],
|
36 |
+
[69.537, 87.010],
|
37 |
+
],
|
38 |
+
dtype=np.float32,
|
39 |
+
)
|
40 |
+
|
41 |
+
# -->right
|
42 |
+
src4 = np.array(
|
43 |
+
[
|
44 |
+
[46.845, 50.872],
|
45 |
+
[67.382, 50.118],
|
46 |
+
[72.737, 68.111],
|
47 |
+
[48.167, 86.758],
|
48 |
+
[67.236, 86.190],
|
49 |
+
],
|
50 |
+
dtype=np.float32,
|
51 |
+
)
|
52 |
+
|
53 |
+
# -->right profile
|
54 |
+
src5 = np.array(
|
55 |
+
[
|
56 |
+
[54.796, 49.990],
|
57 |
+
[60.771, 50.115],
|
58 |
+
[76.673, 69.007],
|
59 |
+
[55.388, 89.702],
|
60 |
+
[61.257, 89.050],
|
61 |
+
],
|
62 |
+
dtype=np.float32,
|
63 |
+
)
|
64 |
+
|
65 |
+
src = np.array([src1, src2, src3, src4, src5])
|
66 |
+
src_map = src
|
67 |
+
|
68 |
+
ffhq_src = np.array(
|
69 |
+
[
|
70 |
+
[192.98138, 239.94708],
|
71 |
+
[318.90277, 240.1936],
|
72 |
+
[256.63416, 314.01935],
|
73 |
+
[201.26117, 371.41043],
|
74 |
+
[313.08905, 371.15118],
|
75 |
+
]
|
76 |
+
)
|
77 |
+
ffhq_src = np.expand_dims(ffhq_src, axis=0)
|
78 |
+
|
79 |
+
|
80 |
+
# arcface_src = np.array(
|
81 |
+
# [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
|
82 |
+
# [41.5493, 92.3655], [70.7299, 92.2041]],
|
83 |
+
# dtype=np.float32)
|
84 |
+
|
85 |
+
# arcface_src = np.expand_dims(arcface_src, axis=0)
|
86 |
+
|
87 |
+
# In[66]:
|
88 |
+
|
89 |
+
|
90 |
+
# lmk is prediction; src is template
|
91 |
+
def estimate_norm(lmk, image_size=112, mode="ffhq"):
|
92 |
+
assert lmk.shape == (5, 2)
|
93 |
+
tform = skt.SimilarityTransform()
|
94 |
+
lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
|
95 |
+
min_M = []
|
96 |
+
min_index = []
|
97 |
+
min_error = float("inf")
|
98 |
+
if mode == "ffhq":
|
99 |
+
# assert image_size == 112
|
100 |
+
src = ffhq_src * image_size / 512
|
101 |
+
else:
|
102 |
+
src = src_map * image_size / 112
|
103 |
+
for i in np.arange(src.shape[0]):
|
104 |
+
tform.estimate(lmk, src[i])
|
105 |
+
M = tform.params[0:2, :]
|
106 |
+
results = np.dot(M, lmk_tran.T)
|
107 |
+
results = results.T
|
108 |
+
error = np.sum(np.sqrt(np.sum((results - src[i]) ** 2, axis=1)))
|
109 |
+
if error < min_error:
|
110 |
+
min_error = error
|
111 |
+
min_M = M
|
112 |
+
min_index = i
|
113 |
+
return min_M, min_index
|
114 |
+
|
115 |
+
|
116 |
+
def norm_crop(img, landmark, image_size=112, mode="ffhq"):
|
117 |
+
if mode == "Both":
|
118 |
+
M_None, _ = estimate_norm(landmark, image_size, mode="newarc")
|
119 |
+
M_ffhq, _ = estimate_norm(landmark, image_size, mode="ffhq")
|
120 |
+
warped_None = cv2.warpAffine(
|
121 |
+
img, M_None, (image_size, image_size), borderValue=0.0
|
122 |
+
)
|
123 |
+
warped_ffhq = cv2.warpAffine(
|
124 |
+
img, M_ffhq, (image_size, image_size), borderValue=0.0
|
125 |
+
)
|
126 |
+
return warped_ffhq, warped_None
|
127 |
+
else:
|
128 |
+
M, pose_index = estimate_norm(landmark, image_size, mode)
|
129 |
+
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
|
130 |
+
return warped
|
131 |
+
|
132 |
+
|
133 |
+
def square_crop(im, S):
|
134 |
+
if im.shape[0] > im.shape[1]:
|
135 |
+
height = S
|
136 |
+
width = int(float(im.shape[1]) / im.shape[0] * S)
|
137 |
+
scale = float(S) / im.shape[0]
|
138 |
+
else:
|
139 |
+
width = S
|
140 |
+
height = int(float(im.shape[0]) / im.shape[1] * S)
|
141 |
+
scale = float(S) / im.shape[1]
|
142 |
+
resized_im = cv2.resize(im, (width, height))
|
143 |
+
det_im = np.zeros((S, S, 3), dtype=np.uint8)
|
144 |
+
det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im
|
145 |
+
return det_im, scale
|
146 |
+
|
147 |
+
|
148 |
+
def transform(data, center, output_size, scale, rotation):
|
149 |
+
scale_ratio = scale
|
150 |
+
rot = float(rotation) * np.pi / 180.0
|
151 |
+
# translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
|
152 |
+
t1 = skt.SimilarityTransform(scale=scale_ratio)
|
153 |
+
cx = center[0] * scale_ratio
|
154 |
+
cy = center[1] * scale_ratio
|
155 |
+
t2 = skt.SimilarityTransform(translation=(-1 * cx, -1 * cy))
|
156 |
+
t3 = skt.SimilarityTransform(rotation=rot)
|
157 |
+
t4 = skt.SimilarityTransform(translation=(output_size / 2, output_size / 2))
|
158 |
+
t = t1 + t2 + t3 + t4
|
159 |
+
M = t.params[0:2]
|
160 |
+
cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0)
|
161 |
+
return cropped, M
|
162 |
+
|
163 |
+
|
164 |
+
def trans_points2d(pts, M):
|
165 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
166 |
+
for i in range(pts.shape[0]):
|
167 |
+
pt = pts[i]
|
168 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
169 |
+
new_pt = np.dot(M, new_pt)
|
170 |
+
# print('new_pt', new_pt.shape, new_pt)
|
171 |
+
new_pts[i] = new_pt[0:2]
|
172 |
+
|
173 |
+
return new_pts
|
174 |
+
|
175 |
+
|
176 |
+
def trans_points3d(pts, M):
|
177 |
+
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
|
178 |
+
# print(scale)
|
179 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
180 |
+
for i in range(pts.shape[0]):
|
181 |
+
pt = pts[i]
|
182 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
183 |
+
new_pt = np.dot(M, new_pt)
|
184 |
+
# print('new_pt', new_pt.shape, new_pt)
|
185 |
+
new_pts[i][0:2] = new_pt[0:2]
|
186 |
+
new_pts[i][2] = pts[i][2] * scale
|
187 |
+
|
188 |
+
return new_pts
|
189 |
+
|
190 |
+
|
191 |
+
def trans_points(pts, M):
|
192 |
+
if pts.shape[1] == 2:
|
193 |
+
return trans_points2d(pts, M)
|
194 |
+
else:
|
195 |
+
return trans_points3d(pts, M)
|
196 |
+
|
197 |
+
|
198 |
+
def inverse_transform(mat: np.ndarray) -> np.ndarray:
|
199 |
+
# inverse the Affine transformation matrix
|
200 |
+
inv_mat = np.zeros([2, 3])
|
201 |
+
div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
|
202 |
+
inv_mat[0][0] = mat[1][1] / div1
|
203 |
+
inv_mat[0][1] = -mat[0][1] / div1
|
204 |
+
inv_mat[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
|
205 |
+
div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
|
206 |
+
inv_mat[1][0] = mat[1][0] / div2
|
207 |
+
inv_mat[1][1] = -mat[0][0] / div2
|
208 |
+
inv_mat[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2
|
209 |
+
return inv_mat
|
210 |
+
|
211 |
+
|
212 |
+
def inverse_transform_batch(mat: torch.Tensor) -> torch.Tensor:
|
213 |
+
# inverse the Affine transformation matrix
|
214 |
+
inv_mat = torch.zeros_like(mat)
|
215 |
+
div1 = mat[:, 0, 0] * mat[:, 1, 1] - mat[:, 0, 1] * mat[:, 1, 0]
|
216 |
+
inv_mat[:, 0, 0] = mat[:, 1, 1] / div1
|
217 |
+
inv_mat[:, 0, 1] = -mat[:, 0, 1] / div1
|
218 |
+
inv_mat[:, 0, 2] = (
|
219 |
+
-(mat[:, 0, 2] * mat[:, 1, 1] - mat[:, 0, 1] * mat[:, 1, 2]) / div1
|
220 |
+
)
|
221 |
+
div2 = mat[:, 0, 1] * mat[:, 1, 0] - mat[:, 0, 0] * mat[:, 1, 1]
|
222 |
+
inv_mat[:, 1, 0] = mat[:, 1, 0] / div2
|
223 |
+
inv_mat[:, 1, 1] = -mat[:, 0, 0] / div2
|
224 |
+
inv_mat[:, 1, 2] = (
|
225 |
+
-(mat[:, 0, 2] * mat[:, 1, 0] - mat[:, 0, 0] * mat[:, 1, 2]) / div2
|
226 |
+
)
|
227 |
+
return inv_mat
|
228 |
+
|
229 |
+
|
230 |
+
def align_face(
|
231 |
+
img: np.ndarray, key_points: np.ndarray, crop_size: int, mode: str = "ffhq"
|
232 |
+
) -> Tuple[Iterable[np.ndarray], Iterable[np.ndarray]]:
|
233 |
+
align_imgs = []
|
234 |
+
transforms = []
|
235 |
+
for i in range(key_points.shape[0]):
|
236 |
+
kps = key_points[i]
|
237 |
+
transform_matrix, _ = estimate_norm(kps, crop_size, mode=mode)
|
238 |
+
align_img = cv2.warpAffine(
|
239 |
+
img, transform_matrix, (crop_size, crop_size), borderValue=0.0
|
240 |
+
)
|
241 |
+
align_imgs.append(align_img)
|
242 |
+
transforms.append(transform_matrix)
|
243 |
+
|
244 |
+
return align_imgs, transforms
|
src/FaceDetector/face_detector.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple, Optional, Tuple
|
2 |
+
|
3 |
+
from insightface.model_zoo import model_zoo
|
4 |
+
import numpy as np
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
|
8 |
+
class Detection(NamedTuple):
|
9 |
+
bbox: Optional[np.ndarray]
|
10 |
+
score: Optional[np.ndarray]
|
11 |
+
key_points: Optional[np.ndarray]
|
12 |
+
|
13 |
+
|
14 |
+
class FaceDetector:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
model_path: Path,
|
18 |
+
det_thresh: float = 0.5,
|
19 |
+
det_size: Tuple[int, int] = (640, 640),
|
20 |
+
mode: str = "None",
|
21 |
+
device: str = "cpu",
|
22 |
+
):
|
23 |
+
self.det_thresh = det_thresh
|
24 |
+
self.mode = mode
|
25 |
+
self.device = device
|
26 |
+
self.handler = model_zoo.get_model(str(model_path))
|
27 |
+
ctx_id = -1 if device == "cpu" else 0
|
28 |
+
self.handler.prepare(ctx_id, input_size=det_size)
|
29 |
+
|
30 |
+
def __call__(self, img: np.ndarray, max_num: int = 0) -> Detection:
|
31 |
+
bboxes, kpss = self.handler.detect(
|
32 |
+
img, threshold=self.det_thresh, max_num=max_num, metric="default"
|
33 |
+
)
|
34 |
+
if bboxes.shape[0] == 0:
|
35 |
+
return Detection(None, None, None)
|
36 |
+
|
37 |
+
return Detection(bboxes[..., :-1], bboxes[..., -1], kpss)
|
src/FaceId/faceid.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision import transforms
|
6 |
+
|
7 |
+
from typing import Iterable, Union
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
|
11 |
+
class FaceId(torch.nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self, model_path: Path, device: str, input_shape: Iterable[int] = (112, 112)
|
14 |
+
):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
self.input_shape = input_shape
|
18 |
+
self.net = torch.load(model_path, map_location=torch.device("cpu"))
|
19 |
+
self.net.eval()
|
20 |
+
|
21 |
+
self.transform = transforms.Compose(
|
22 |
+
[
|
23 |
+
transforms.ToTensor(),
|
24 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
25 |
+
]
|
26 |
+
)
|
27 |
+
|
28 |
+
for n, p in self.net.named_parameters():
|
29 |
+
assert (
|
30 |
+
not p.requires_grad
|
31 |
+
), f"Parameter {n}: requires_grad: {p.requires_grad}"
|
32 |
+
|
33 |
+
self.device = torch.device(device)
|
34 |
+
self.to(self.device)
|
35 |
+
|
36 |
+
def forward(
|
37 |
+
self, img_id: Union[np.ndarray, Iterable[np.ndarray]], normalize: bool = True
|
38 |
+
) -> torch.Tensor:
|
39 |
+
if isinstance(img_id, Iterable):
|
40 |
+
img_id = [self.transform(x) for x in img_id]
|
41 |
+
img_id = torch.stack(img_id, dim=0)
|
42 |
+
else:
|
43 |
+
img_id = self.transform(img_id)
|
44 |
+
img_id = img_id.unsqueeze(0)
|
45 |
+
|
46 |
+
img_id = img_id.to(self.device)
|
47 |
+
|
48 |
+
img_id_112 = F.interpolate(img_id, size=self.input_shape)
|
49 |
+
latent_id = self.net(img_id_112)
|
50 |
+
return F.normalize(latent_id, p=2, dim=1) if normalize else latent_id
|
src/Generator/fs_networks_512.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Naiyuan liu
|
3 |
+
Github: https://github.com/NNNNAI
|
4 |
+
Date: 2021-11-23 16:55:48
|
5 |
+
LastEditors: Naiyuan liu
|
6 |
+
LastEditTime: 2021-11-24 16:58:06
|
7 |
+
Description:
|
8 |
+
"""
|
9 |
+
"""
|
10 |
+
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
11 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
12 |
+
"""
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
|
17 |
+
class InstanceNorm(nn.Module):
|
18 |
+
def __init__(self, epsilon=1e-8):
|
19 |
+
"""
|
20 |
+
@notice: avoid in-place ops.
|
21 |
+
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
22 |
+
"""
|
23 |
+
super(InstanceNorm, self).__init__()
|
24 |
+
self.epsilon = epsilon
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = x - torch.mean(x, (2, 3), True)
|
28 |
+
tmp = torch.mul(x, x) # or x ** 2
|
29 |
+
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
30 |
+
return x * tmp
|
31 |
+
|
32 |
+
|
33 |
+
class ApplyStyle(nn.Module):
|
34 |
+
"""
|
35 |
+
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, latent_size, channels):
|
39 |
+
super(ApplyStyle, self).__init__()
|
40 |
+
self.linear = nn.Linear(latent_size, channels * 2)
|
41 |
+
|
42 |
+
def forward(self, x, latent):
|
43 |
+
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
44 |
+
shape = [-1, 2, x.size(1), 1, 1]
|
45 |
+
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
46 |
+
# x = x * (style[:, 0] + 1.) + style[:, 1]
|
47 |
+
x = x * (style[:, 0] * 1 + 1.0) + style[:, 1] * 1
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class ResnetBlock_Adain(nn.Module):
|
52 |
+
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
53 |
+
super(ResnetBlock_Adain, self).__init__()
|
54 |
+
|
55 |
+
p = 0
|
56 |
+
conv1 = []
|
57 |
+
if padding_type == "reflect":
|
58 |
+
conv1 += [nn.ReflectionPad2d(1)]
|
59 |
+
elif padding_type == "replicate":
|
60 |
+
conv1 += [nn.ReplicationPad2d(1)]
|
61 |
+
elif padding_type == "zero":
|
62 |
+
p = 1
|
63 |
+
else:
|
64 |
+
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
|
65 |
+
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
66 |
+
self.conv1 = nn.Sequential(*conv1)
|
67 |
+
self.style1 = ApplyStyle(latent_size, dim)
|
68 |
+
self.act1 = activation
|
69 |
+
|
70 |
+
p = 0
|
71 |
+
conv2 = []
|
72 |
+
if padding_type == "reflect":
|
73 |
+
conv2 += [nn.ReflectionPad2d(1)]
|
74 |
+
elif padding_type == "replicate":
|
75 |
+
conv2 += [nn.ReplicationPad2d(1)]
|
76 |
+
elif padding_type == "zero":
|
77 |
+
p = 1
|
78 |
+
else:
|
79 |
+
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
|
80 |
+
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
81 |
+
self.conv2 = nn.Sequential(*conv2)
|
82 |
+
self.style2 = ApplyStyle(latent_size, dim)
|
83 |
+
|
84 |
+
def forward(self, x, dlatents_in_slice):
|
85 |
+
y = self.conv1(x)
|
86 |
+
y = self.style1(y, dlatents_in_slice)
|
87 |
+
y = self.act1(y)
|
88 |
+
y = self.conv2(y)
|
89 |
+
y = self.style2(y, dlatents_in_slice)
|
90 |
+
out = x + y
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class Generator_Adain_Upsample(nn.Module):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
input_nc,
|
98 |
+
output_nc,
|
99 |
+
latent_size,
|
100 |
+
n_blocks=6,
|
101 |
+
deep=False,
|
102 |
+
norm_layer=nn.BatchNorm2d,
|
103 |
+
padding_type="reflect",
|
104 |
+
):
|
105 |
+
assert n_blocks >= 0
|
106 |
+
super(Generator_Adain_Upsample, self).__init__()
|
107 |
+
activation = nn.ReLU(True)
|
108 |
+
self.deep = deep
|
109 |
+
|
110 |
+
self.first_layer = nn.Sequential(
|
111 |
+
nn.ReflectionPad2d(3),
|
112 |
+
nn.Conv2d(input_nc, 32, kernel_size=7, padding=0),
|
113 |
+
norm_layer(32),
|
114 |
+
activation,
|
115 |
+
)
|
116 |
+
# downsample
|
117 |
+
self.down0 = nn.Sequential(
|
118 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
|
119 |
+
norm_layer(64),
|
120 |
+
activation,
|
121 |
+
)
|
122 |
+
self.down1 = nn.Sequential(
|
123 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
124 |
+
norm_layer(128),
|
125 |
+
activation,
|
126 |
+
)
|
127 |
+
self.down2 = nn.Sequential(
|
128 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
129 |
+
norm_layer(256),
|
130 |
+
activation,
|
131 |
+
)
|
132 |
+
self.down3 = nn.Sequential(
|
133 |
+
nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
134 |
+
norm_layer(512),
|
135 |
+
activation,
|
136 |
+
)
|
137 |
+
if self.deep:
|
138 |
+
self.down4 = nn.Sequential(
|
139 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
140 |
+
norm_layer(512),
|
141 |
+
activation,
|
142 |
+
)
|
143 |
+
|
144 |
+
# resnet blocks
|
145 |
+
BN = []
|
146 |
+
for i in range(n_blocks):
|
147 |
+
BN += [
|
148 |
+
ResnetBlock_Adain(
|
149 |
+
512,
|
150 |
+
latent_size=latent_size,
|
151 |
+
padding_type=padding_type,
|
152 |
+
activation=activation,
|
153 |
+
)
|
154 |
+
]
|
155 |
+
self.BottleNeck = nn.Sequential(*BN)
|
156 |
+
|
157 |
+
if self.deep:
|
158 |
+
self.up4 = nn.Sequential(
|
159 |
+
nn.Upsample(scale_factor=2, mode="bilinear"),
|
160 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
|
161 |
+
nn.BatchNorm2d(512),
|
162 |
+
activation,
|
163 |
+
)
|
164 |
+
self.up3 = nn.Sequential(
|
165 |
+
nn.Upsample(scale_factor=2, mode="bilinear"),
|
166 |
+
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
|
167 |
+
nn.BatchNorm2d(256),
|
168 |
+
activation,
|
169 |
+
)
|
170 |
+
self.up2 = nn.Sequential(
|
171 |
+
nn.Upsample(scale_factor=2, mode="bilinear"),
|
172 |
+
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
|
173 |
+
nn.BatchNorm2d(128),
|
174 |
+
activation,
|
175 |
+
)
|
176 |
+
self.up1 = nn.Sequential(
|
177 |
+
nn.Upsample(scale_factor=2, mode="bilinear"),
|
178 |
+
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
179 |
+
nn.BatchNorm2d(64),
|
180 |
+
activation,
|
181 |
+
)
|
182 |
+
self.up0 = nn.Sequential(
|
183 |
+
nn.Upsample(scale_factor=2, mode="bilinear"),
|
184 |
+
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
|
185 |
+
nn.BatchNorm2d(32),
|
186 |
+
activation,
|
187 |
+
)
|
188 |
+
self.last_layer = nn.Sequential(
|
189 |
+
nn.ReflectionPad2d(3),
|
190 |
+
nn.Conv2d(32, output_nc, kernel_size=7, padding=0),
|
191 |
+
nn.Tanh(),
|
192 |
+
)
|
193 |
+
|
194 |
+
def forward(self, input, dlatents):
|
195 |
+
x = input # 3*224*224
|
196 |
+
|
197 |
+
skip0 = self.first_layer(x)
|
198 |
+
skip1 = self.down0(skip0)
|
199 |
+
skip2 = self.down1(skip1)
|
200 |
+
skip3 = self.down2(skip2)
|
201 |
+
if self.deep:
|
202 |
+
skip4 = self.down3(skip3)
|
203 |
+
x = self.down4(skip4)
|
204 |
+
else:
|
205 |
+
x = self.down3(skip3)
|
206 |
+
|
207 |
+
for i in range(len(self.BottleNeck)):
|
208 |
+
x = self.BottleNeck[i](x, dlatents)
|
209 |
+
|
210 |
+
if self.deep:
|
211 |
+
x = self.up4(x)
|
212 |
+
x = self.up3(x)
|
213 |
+
x = self.up2(x)
|
214 |
+
x = self.up1(x)
|
215 |
+
x = self.up0(x)
|
216 |
+
x = self.last_layer(x)
|
217 |
+
x = (x + 1) / 2
|
218 |
+
|
219 |
+
return x
|
220 |
+
|
221 |
+
|
222 |
+
class Discriminator(nn.Module):
|
223 |
+
def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
|
224 |
+
super(Discriminator, self).__init__()
|
225 |
+
|
226 |
+
kw = 4
|
227 |
+
padw = 1
|
228 |
+
self.down1 = nn.Sequential(
|
229 |
+
nn.Conv2d(input_nc, 64, kernel_size=kw, stride=2, padding=padw),
|
230 |
+
nn.LeakyReLU(0.2, True),
|
231 |
+
)
|
232 |
+
self.down2 = nn.Sequential(
|
233 |
+
nn.Conv2d(64, 128, kernel_size=kw, stride=2, padding=padw),
|
234 |
+
norm_layer(128),
|
235 |
+
nn.LeakyReLU(0.2, True),
|
236 |
+
)
|
237 |
+
self.down3 = nn.Sequential(
|
238 |
+
nn.Conv2d(128, 256, kernel_size=kw, stride=2, padding=padw),
|
239 |
+
norm_layer(256),
|
240 |
+
nn.LeakyReLU(0.2, True),
|
241 |
+
)
|
242 |
+
self.down4 = nn.Sequential(
|
243 |
+
nn.Conv2d(256, 512, kernel_size=kw, stride=2, padding=padw),
|
244 |
+
norm_layer(512),
|
245 |
+
nn.LeakyReLU(0.2, True),
|
246 |
+
)
|
247 |
+
self.conv1 = nn.Sequential(
|
248 |
+
nn.Conv2d(512, 512, kernel_size=kw, stride=1, padding=padw),
|
249 |
+
norm_layer(512),
|
250 |
+
nn.LeakyReLU(0.2, True),
|
251 |
+
)
|
252 |
+
|
253 |
+
if use_sigmoid:
|
254 |
+
self.conv2 = nn.Sequential(
|
255 |
+
nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw), nn.Sigmoid()
|
256 |
+
)
|
257 |
+
else:
|
258 |
+
self.conv2 = nn.Sequential(
|
259 |
+
nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw)
|
260 |
+
)
|
261 |
+
|
262 |
+
def forward(self, input):
|
263 |
+
out = []
|
264 |
+
x = self.down1(input)
|
265 |
+
out.append(x)
|
266 |
+
x = self.down2(x)
|
267 |
+
out.append(x)
|
268 |
+
x = self.down3(x)
|
269 |
+
out.append(x)
|
270 |
+
x = self.down4(x)
|
271 |
+
out.append(x)
|
272 |
+
x = self.conv1(x)
|
273 |
+
out.append(x)
|
274 |
+
x = self.conv2(x)
|
275 |
+
out.append(x)
|
276 |
+
|
277 |
+
return out
|
src/Generator/fs_networks_fix.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torchvision import transforms
|
9 |
+
|
10 |
+
from typing import Iterable
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
class InstanceNorm(nn.Module):
|
15 |
+
def __init__(self, epsilon=1e-8):
|
16 |
+
"""
|
17 |
+
@notice: avoid in-place ops.
|
18 |
+
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
19 |
+
"""
|
20 |
+
super(InstanceNorm, self).__init__()
|
21 |
+
self.epsilon = epsilon
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
x = x - torch.mean(x, (2, 3), True)
|
25 |
+
tmp = torch.mul(x, x) # or x ** 2
|
26 |
+
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
27 |
+
return x * tmp
|
28 |
+
|
29 |
+
|
30 |
+
class ApplyStyle(nn.Module):
|
31 |
+
"""
|
32 |
+
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, latent_size, channels):
|
36 |
+
super(ApplyStyle, self).__init__()
|
37 |
+
self.linear = nn.Linear(latent_size, channels * 2)
|
38 |
+
|
39 |
+
def forward(self, x, latent):
|
40 |
+
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
41 |
+
shape = [-1, 2, x.size(1), 1, 1]
|
42 |
+
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
43 |
+
# x = x * (style[:, 0] + 1.) + style[:, 1]
|
44 |
+
x = x * (style[:, 0] * 1 + 1.0) + style[:, 1] * 1
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class ResnetBlock_Adain(nn.Module):
|
49 |
+
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
50 |
+
super(ResnetBlock_Adain, self).__init__()
|
51 |
+
|
52 |
+
p = 0
|
53 |
+
conv1 = []
|
54 |
+
if padding_type == "reflect":
|
55 |
+
conv1 += [nn.ReflectionPad2d(1)]
|
56 |
+
elif padding_type == "replicate":
|
57 |
+
conv1 += [nn.ReplicationPad2d(1)]
|
58 |
+
elif padding_type == "zero":
|
59 |
+
p = 1
|
60 |
+
else:
|
61 |
+
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
|
62 |
+
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
63 |
+
self.conv1 = nn.Sequential(*conv1)
|
64 |
+
self.style1 = ApplyStyle(latent_size, dim)
|
65 |
+
self.act1 = activation
|
66 |
+
|
67 |
+
p = 0
|
68 |
+
conv2 = []
|
69 |
+
if padding_type == "reflect":
|
70 |
+
conv2 += [nn.ReflectionPad2d(1)]
|
71 |
+
elif padding_type == "replicate":
|
72 |
+
conv2 += [nn.ReplicationPad2d(1)]
|
73 |
+
elif padding_type == "zero":
|
74 |
+
p = 1
|
75 |
+
else:
|
76 |
+
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
|
77 |
+
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
78 |
+
self.conv2 = nn.Sequential(*conv2)
|
79 |
+
self.style2 = ApplyStyle(latent_size, dim)
|
80 |
+
|
81 |
+
def forward(self, x, dlatents_in_slice):
|
82 |
+
y = self.conv1(x)
|
83 |
+
y = self.style1(y, dlatents_in_slice)
|
84 |
+
y = self.act1(y)
|
85 |
+
y = self.conv2(y)
|
86 |
+
y = self.style2(y, dlatents_in_slice)
|
87 |
+
out = x + y
|
88 |
+
return out
|
89 |
+
|
90 |
+
|
91 |
+
class Generator_Adain_Upsample(nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
input_nc: int,
|
95 |
+
output_nc: int,
|
96 |
+
latent_size: int,
|
97 |
+
n_blocks: int = 6,
|
98 |
+
deep: bool = False,
|
99 |
+
use_last_act: bool = True,
|
100 |
+
norm_layer: torch.nn.Module = nn.BatchNorm2d,
|
101 |
+
padding_type: str = "reflect",
|
102 |
+
):
|
103 |
+
assert n_blocks >= 0
|
104 |
+
super(Generator_Adain_Upsample, self).__init__()
|
105 |
+
|
106 |
+
activation = nn.ReLU(True)
|
107 |
+
|
108 |
+
self.deep = deep
|
109 |
+
self.use_last_act = use_last_act
|
110 |
+
|
111 |
+
self.to_tensor_normalize = transforms.Compose(
|
112 |
+
[
|
113 |
+
transforms.ToTensor(),
|
114 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
115 |
+
]
|
116 |
+
)
|
117 |
+
|
118 |
+
self.to_tensor = transforms.Compose([transforms.ToTensor()])
|
119 |
+
|
120 |
+
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
121 |
+
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
122 |
+
|
123 |
+
self.first_layer = nn.Sequential(
|
124 |
+
nn.ReflectionPad2d(3),
|
125 |
+
nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
|
126 |
+
norm_layer(64),
|
127 |
+
activation,
|
128 |
+
)
|
129 |
+
# downsample
|
130 |
+
self.down1 = nn.Sequential(
|
131 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
132 |
+
norm_layer(128),
|
133 |
+
activation,
|
134 |
+
)
|
135 |
+
self.down2 = nn.Sequential(
|
136 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
137 |
+
norm_layer(256),
|
138 |
+
activation,
|
139 |
+
)
|
140 |
+
self.down3 = nn.Sequential(
|
141 |
+
nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
142 |
+
norm_layer(512),
|
143 |
+
activation,
|
144 |
+
)
|
145 |
+
|
146 |
+
if self.deep:
|
147 |
+
self.down4 = nn.Sequential(
|
148 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
149 |
+
norm_layer(512),
|
150 |
+
activation,
|
151 |
+
)
|
152 |
+
|
153 |
+
# resnet blocks
|
154 |
+
BN = []
|
155 |
+
for i in range(n_blocks):
|
156 |
+
BN += [
|
157 |
+
ResnetBlock_Adain(
|
158 |
+
512,
|
159 |
+
latent_size=latent_size,
|
160 |
+
padding_type=padding_type,
|
161 |
+
activation=activation,
|
162 |
+
)
|
163 |
+
]
|
164 |
+
self.BottleNeck = nn.Sequential(*BN)
|
165 |
+
|
166 |
+
if self.deep:
|
167 |
+
self.up4 = nn.Sequential(
|
168 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
169 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
|
170 |
+
nn.BatchNorm2d(512),
|
171 |
+
activation,
|
172 |
+
)
|
173 |
+
self.up3 = nn.Sequential(
|
174 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
175 |
+
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
|
176 |
+
nn.BatchNorm2d(256),
|
177 |
+
activation,
|
178 |
+
)
|
179 |
+
self.up2 = nn.Sequential(
|
180 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
181 |
+
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
|
182 |
+
nn.BatchNorm2d(128),
|
183 |
+
activation,
|
184 |
+
)
|
185 |
+
self.up1 = nn.Sequential(
|
186 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
187 |
+
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
188 |
+
nn.BatchNorm2d(64),
|
189 |
+
activation,
|
190 |
+
)
|
191 |
+
if self.use_last_act:
|
192 |
+
self.last_layer = nn.Sequential(
|
193 |
+
nn.ReflectionPad2d(3),
|
194 |
+
nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
|
195 |
+
torch.nn.Tanh(),
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
self.last_layer = nn.Sequential(
|
199 |
+
nn.ReflectionPad2d(3),
|
200 |
+
nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
|
201 |
+
)
|
202 |
+
|
203 |
+
def to(self, device):
|
204 |
+
super().to(device)
|
205 |
+
self.device = device
|
206 |
+
self.imagenet_mean = self.imagenet_mean.to(device)
|
207 |
+
self.imagenet_std = self.imagenet_std.to(device)
|
208 |
+
return self
|
209 |
+
|
210 |
+
def forward(self, x: Iterable[np.ndarray], dlatents: torch.Tensor):
|
211 |
+
if self.use_last_act:
|
212 |
+
x = [self.to_tensor(_) for _ in x]
|
213 |
+
else:
|
214 |
+
x = [self.to_tensor_normalize(_) for _ in x]
|
215 |
+
|
216 |
+
x = torch.stack(x, dim=0)
|
217 |
+
|
218 |
+
x = x.to(self.device)
|
219 |
+
|
220 |
+
skip1 = self.first_layer(x)
|
221 |
+
skip2 = self.down1(skip1)
|
222 |
+
skip3 = self.down2(skip2)
|
223 |
+
if self.deep:
|
224 |
+
skip4 = self.down3(skip3)
|
225 |
+
x = self.down4(skip4)
|
226 |
+
else:
|
227 |
+
x = self.down3(skip3)
|
228 |
+
|
229 |
+
for i in range(len(self.BottleNeck)):
|
230 |
+
x = self.BottleNeck[i](x, dlatents)
|
231 |
+
|
232 |
+
if self.deep:
|
233 |
+
x = self.up4(x)
|
234 |
+
|
235 |
+
x = self.up3(x)
|
236 |
+
x = self.up2(x)
|
237 |
+
x = self.up1(x)
|
238 |
+
x = self.last_layer(x)
|
239 |
+
|
240 |
+
if self.use_last_act:
|
241 |
+
x = (x + 1) / 2
|
242 |
+
else:
|
243 |
+
x = x * self.imagenet_std + self.imagenet_mean
|
244 |
+
|
245 |
+
return x
|
src/Misc/types.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class CheckpointType(Enum):
|
5 |
+
OFFICIAL_224 = "official_224"
|
6 |
+
UNOFFICIAL = "none"
|
7 |
+
|
8 |
+
|
9 |
+
class FaceAlignmentType(Enum):
|
10 |
+
FFHQ = "ffhq"
|
11 |
+
DEFAULT = "none"
|
src/Misc/utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
def tensor2img_denorm(tensor):
|
7 |
+
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
8 |
+
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
9 |
+
tensor = std * tensor.detach().cpu() + mean
|
10 |
+
img = tensor.numpy()
|
11 |
+
img = img.transpose(0, 2, 3, 1)[0]
|
12 |
+
img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
|
13 |
+
return img
|
14 |
+
|
15 |
+
|
16 |
+
def tensor2img(tensor):
|
17 |
+
tensor = tensor.detach().cpu().numpy()
|
18 |
+
img = tensor.transpose(0, 2, 3, 1)[0]
|
19 |
+
img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
|
20 |
+
return img
|
21 |
+
|
22 |
+
|
23 |
+
def show_tensor(tensor, name):
|
24 |
+
img = cv2.cvtColor(tensor2img(tensor), cv2.COLOR_RGB2BGR)
|
25 |
+
|
26 |
+
cv2.namedWindow(name, cv2.WINDOW_NORMAL)
|
27 |
+
cv2.imshow(name, img)
|
28 |
+
cv2.waitKey()
|
src/PostProcess/GFPGAN/gfpgan.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from src.PostProcess.GFPGAN.stylegan2 import StyleGAN2GeneratorClean
|
7 |
+
|
8 |
+
|
9 |
+
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
10 |
+
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
11 |
+
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
12 |
+
Args:
|
13 |
+
out_size (int): The spatial size of outputs.
|
14 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
15 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
16 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
17 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
18 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
|
22 |
+
super(StyleGAN2GeneratorCSFT, self).__init__(
|
23 |
+
out_size,
|
24 |
+
num_style_feat=num_style_feat,
|
25 |
+
num_mlp=num_mlp,
|
26 |
+
channel_multiplier=channel_multiplier,
|
27 |
+
narrow=narrow)
|
28 |
+
self.sft_half = sft_half
|
29 |
+
|
30 |
+
def forward(self,
|
31 |
+
styles,
|
32 |
+
conditions,
|
33 |
+
input_is_latent=False,
|
34 |
+
noise=None,
|
35 |
+
randomize_noise=True,
|
36 |
+
truncation=1,
|
37 |
+
truncation_latent=None,
|
38 |
+
inject_index=None,
|
39 |
+
return_latents=False):
|
40 |
+
"""Forward function for StyleGAN2GeneratorCSFT.
|
41 |
+
Args:
|
42 |
+
styles (list[Tensor]): Sample codes of styles.
|
43 |
+
conditions (list[Tensor]): SFT conditions to generators.
|
44 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
45 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
46 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
47 |
+
truncation (float): The truncation ratio. Default: 1.
|
48 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
49 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
50 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
51 |
+
"""
|
52 |
+
# style codes -> latents with Style MLP layer
|
53 |
+
if not input_is_latent:
|
54 |
+
styles = [self.style_mlp(s) for s in styles]
|
55 |
+
# noises
|
56 |
+
if noise is None:
|
57 |
+
if randomize_noise:
|
58 |
+
noise = [None] * self.num_layers # for each style conv layer
|
59 |
+
else: # use the stored noise
|
60 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
61 |
+
# style truncation
|
62 |
+
if truncation < 1:
|
63 |
+
style_truncation = []
|
64 |
+
for style in styles:
|
65 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
66 |
+
styles = style_truncation
|
67 |
+
# get style latents with injection
|
68 |
+
if len(styles) == 1:
|
69 |
+
inject_index = self.num_latent
|
70 |
+
|
71 |
+
if styles[0].ndim < 3:
|
72 |
+
# repeat latent code for all the layers
|
73 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
74 |
+
else: # used for encoder with different latent code for each layer
|
75 |
+
latent = styles[0]
|
76 |
+
elif len(styles) == 2: # mixing noises
|
77 |
+
if inject_index is None:
|
78 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
79 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
80 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
81 |
+
latent = torch.cat([latent1, latent2], 1)
|
82 |
+
|
83 |
+
# main generation
|
84 |
+
out = self.constant_input(latent.shape[0])
|
85 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
86 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
87 |
+
|
88 |
+
i = 1
|
89 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
90 |
+
noise[2::2], self.to_rgbs):
|
91 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
92 |
+
|
93 |
+
# the conditions may have fewer levels
|
94 |
+
if i < len(conditions):
|
95 |
+
# SFT part to combine the conditions
|
96 |
+
if self.sft_half: # only apply SFT to half of the channels
|
97 |
+
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
98 |
+
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
99 |
+
out = torch.cat([out_same, out_sft], dim=1)
|
100 |
+
else: # apply SFT to all the channels
|
101 |
+
out = out * conditions[i - 1] + conditions[i]
|
102 |
+
|
103 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
104 |
+
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
105 |
+
i += 2
|
106 |
+
|
107 |
+
image = skip
|
108 |
+
|
109 |
+
if return_latents:
|
110 |
+
return image, latent
|
111 |
+
else:
|
112 |
+
return image, None
|
113 |
+
|
114 |
+
|
115 |
+
class ResBlock(torch.nn.Module):
|
116 |
+
"""Residual block with bilinear upsampling/downsampling.
|
117 |
+
Args:
|
118 |
+
in_channels (int): Channel number of the input.
|
119 |
+
out_channels (int): Channel number of the output.
|
120 |
+
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
|
121 |
+
"""
|
122 |
+
|
123 |
+
def __init__(self, in_channels, out_channels, mode='down'):
|
124 |
+
super(ResBlock, self).__init__()
|
125 |
+
|
126 |
+
self.conv1 = torch.nn.Conv2d(in_channels, in_channels, 3, 1, 1)
|
127 |
+
self.conv2 = torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
128 |
+
self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, bias=False)
|
129 |
+
if mode == 'down':
|
130 |
+
self.scale_factor = 0.5
|
131 |
+
elif mode == 'up':
|
132 |
+
self.scale_factor = 2
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
|
136 |
+
# upsample/downsample
|
137 |
+
out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
138 |
+
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
|
139 |
+
# skip
|
140 |
+
x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
141 |
+
skip = self.skip(x)
|
142 |
+
out = out + skip
|
143 |
+
return out
|
144 |
+
|
145 |
+
|
146 |
+
class GFPGANv1Clean(torch.nn.Module):
|
147 |
+
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
148 |
+
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
149 |
+
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
150 |
+
Args:
|
151 |
+
out_size (int): The spatial size of outputs.
|
152 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
153 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
154 |
+
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
155 |
+
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
156 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
157 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
158 |
+
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
159 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
160 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
out_size,
|
166 |
+
num_style_feat=512,
|
167 |
+
channel_multiplier=1,
|
168 |
+
decoder_load_path=None,
|
169 |
+
fix_decoder=True,
|
170 |
+
# for stylegan decoder
|
171 |
+
num_mlp=8,
|
172 |
+
input_is_latent=False,
|
173 |
+
different_w=False,
|
174 |
+
narrow=1,
|
175 |
+
sft_half=False):
|
176 |
+
|
177 |
+
super(GFPGANv1Clean, self).__init__()
|
178 |
+
self.input_is_latent = input_is_latent
|
179 |
+
self.different_w = different_w
|
180 |
+
self.num_style_feat = num_style_feat
|
181 |
+
|
182 |
+
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
183 |
+
channels = {
|
184 |
+
'4': int(512 * unet_narrow),
|
185 |
+
'8': int(512 * unet_narrow),
|
186 |
+
'16': int(512 * unet_narrow),
|
187 |
+
'32': int(512 * unet_narrow),
|
188 |
+
'64': int(256 * channel_multiplier * unet_narrow),
|
189 |
+
'128': int(128 * channel_multiplier * unet_narrow),
|
190 |
+
'256': int(64 * channel_multiplier * unet_narrow),
|
191 |
+
'512': int(32 * channel_multiplier * unet_narrow),
|
192 |
+
'1024': int(16 * channel_multiplier * unet_narrow)
|
193 |
+
}
|
194 |
+
|
195 |
+
self.log_size = int(math.log(out_size, 2))
|
196 |
+
first_out_size = 2**(int(math.log(out_size, 2)))
|
197 |
+
|
198 |
+
self.conv_body_first = torch.nn.Conv2d(3, channels[f'{first_out_size}'], 1)
|
199 |
+
|
200 |
+
# downsample
|
201 |
+
in_channels = channels[f'{first_out_size}']
|
202 |
+
self.conv_body_down = torch.nn.ModuleList()
|
203 |
+
for i in range(self.log_size, 2, -1):
|
204 |
+
out_channels = channels[f'{2**(i - 1)}']
|
205 |
+
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
|
206 |
+
in_channels = out_channels
|
207 |
+
|
208 |
+
self.final_conv = torch.nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
|
209 |
+
|
210 |
+
# upsample
|
211 |
+
in_channels = channels['4']
|
212 |
+
self.conv_body_up = torch.nn.ModuleList()
|
213 |
+
for i in range(3, self.log_size + 1):
|
214 |
+
out_channels = channels[f'{2**i}']
|
215 |
+
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
|
216 |
+
in_channels = out_channels
|
217 |
+
|
218 |
+
# to RGB
|
219 |
+
self.toRGB = torch.nn.ModuleList()
|
220 |
+
for i in range(3, self.log_size + 1):
|
221 |
+
self.toRGB.append(torch.nn.Conv2d(channels[f'{2**i}'], 3, 1))
|
222 |
+
|
223 |
+
if different_w:
|
224 |
+
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
225 |
+
else:
|
226 |
+
linear_out_channel = num_style_feat
|
227 |
+
|
228 |
+
self.final_linear = torch.nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
|
229 |
+
|
230 |
+
# the decoder: stylegan2 generator with SFT modulations
|
231 |
+
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
232 |
+
out_size=out_size,
|
233 |
+
num_style_feat=num_style_feat,
|
234 |
+
num_mlp=num_mlp,
|
235 |
+
channel_multiplier=channel_multiplier,
|
236 |
+
narrow=narrow,
|
237 |
+
sft_half=sft_half)
|
238 |
+
|
239 |
+
# load pre-trained stylegan2 model if necessary
|
240 |
+
if decoder_load_path:
|
241 |
+
self.stylegan_decoder.load_state_dict(
|
242 |
+
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
243 |
+
# fix decoder without updating params
|
244 |
+
if fix_decoder:
|
245 |
+
for _, param in self.stylegan_decoder.named_parameters():
|
246 |
+
param.requires_grad = False
|
247 |
+
|
248 |
+
# for SFT modulations (scale and shift)
|
249 |
+
self.condition_scale = torch.nn.ModuleList()
|
250 |
+
self.condition_shift = torch.nn.ModuleList()
|
251 |
+
for i in range(3, self.log_size + 1):
|
252 |
+
out_channels = channels[f'{2**i}']
|
253 |
+
if sft_half:
|
254 |
+
sft_out_channels = out_channels
|
255 |
+
else:
|
256 |
+
sft_out_channels = out_channels * 2
|
257 |
+
self.condition_scale.append(
|
258 |
+
torch.nn.Sequential(
|
259 |
+
torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1), torch.nn.LeakyReLU(0.2, True),
|
260 |
+
torch.nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
261 |
+
self.condition_shift.append(
|
262 |
+
torch.nn.Sequential(
|
263 |
+
torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1), torch.nn.LeakyReLU(0.2, True),
|
264 |
+
torch.nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
265 |
+
|
266 |
+
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
|
267 |
+
"""Forward function for GFPGANv1Clean.
|
268 |
+
Args:
|
269 |
+
x (Tensor): Input images.
|
270 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
271 |
+
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
272 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
273 |
+
"""
|
274 |
+
conditions = []
|
275 |
+
unet_skips = []
|
276 |
+
out_rgbs = []
|
277 |
+
|
278 |
+
# encoder
|
279 |
+
feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
|
280 |
+
for i in range(self.log_size - 2):
|
281 |
+
feat = self.conv_body_down[i](feat)
|
282 |
+
unet_skips.insert(0, feat)
|
283 |
+
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
|
284 |
+
|
285 |
+
# style code
|
286 |
+
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
287 |
+
if self.different_w:
|
288 |
+
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
289 |
+
|
290 |
+
# decode
|
291 |
+
for i in range(self.log_size - 2):
|
292 |
+
# add unet skip
|
293 |
+
feat = feat + unet_skips[i]
|
294 |
+
# ResUpLayer
|
295 |
+
feat = self.conv_body_up[i](feat)
|
296 |
+
# generate scale and shift for SFT layers
|
297 |
+
scale = self.condition_scale[i](feat)
|
298 |
+
conditions.append(scale.clone())
|
299 |
+
shift = self.condition_shift[i](feat)
|
300 |
+
conditions.append(shift.clone())
|
301 |
+
# generate rgb images
|
302 |
+
if return_rgb:
|
303 |
+
out_rgbs.append(self.toRGB[i](feat))
|
304 |
+
|
305 |
+
# decoder
|
306 |
+
image, _ = self.stylegan_decoder([style_code],
|
307 |
+
conditions,
|
308 |
+
return_latents=return_latents,
|
309 |
+
input_is_latent=self.input_is_latent,
|
310 |
+
randomize_noise=randomize_noise)
|
311 |
+
|
312 |
+
return image, out_rgbs
|
313 |
+
|
314 |
+
|
315 |
+
class GFPGANer(GFPGANv1Clean):
|
316 |
+
"""Helper for restoration with GFPGAN."""
|
317 |
+
|
318 |
+
def __init__(self):
|
319 |
+
super().__init__(out_size=512, num_style_feat=512, channel_multiplier=2,
|
320 |
+
decoder_load_path=None, fix_decoder=False, num_mlp=8, input_is_latent=True,
|
321 |
+
different_w=True, narrow=1, sft_half=True)
|
322 |
+
|
323 |
+
self.min_max = (-1, 1)
|
324 |
+
|
325 |
+
@torch.no_grad()
|
326 |
+
def enhance(self, img, weight=0.5):
|
327 |
+
n, c, h, w = img.shape
|
328 |
+
img = F.interpolate(img, size=(512, 512), mode="bilinear")
|
329 |
+
|
330 |
+
img = (img - 0.5) / 0.5
|
331 |
+
|
332 |
+
try:
|
333 |
+
restored_faces = self.forward(img, return_rgb=False, weight=weight)[0]
|
334 |
+
except RuntimeError as error:
|
335 |
+
print(f'\tFailed inference for GFPGAN: {error}.')
|
336 |
+
restored_faces = img
|
337 |
+
|
338 |
+
restored_faces.clamp_(*self.min_max)
|
339 |
+
restored_faces = (restored_faces - self.min_max[0]) / (self.min_max[1] - self.min_max[0])
|
340 |
+
|
341 |
+
return F.interpolate(restored_faces, size=(h, w), mode="bilinear")
|
src/PostProcess/GFPGAN/stylegan2.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class NormStyleCode(torch.nn.Module):
|
8 |
+
|
9 |
+
def forward(self, x):
|
10 |
+
"""Normalize the style codes.
|
11 |
+
Args:
|
12 |
+
x (Tensor): Style codes with shape (b, c).
|
13 |
+
Returns:
|
14 |
+
Tensor: Normalized tensor.
|
15 |
+
"""
|
16 |
+
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
17 |
+
|
18 |
+
|
19 |
+
class ModulatedConv2d(torch.nn.Module):
|
20 |
+
"""Modulated Conv2d used in StyleGAN2.
|
21 |
+
There is no bias in ModulatedConv2d.
|
22 |
+
Args:
|
23 |
+
in_channels (int): Channel number of the input.
|
24 |
+
out_channels (int): Channel number of the output.
|
25 |
+
kernel_size (int): Size of the convolving kernel.
|
26 |
+
num_style_feat (int): Channel number of style features.
|
27 |
+
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
|
28 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
29 |
+
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self,
|
33 |
+
in_channels,
|
34 |
+
out_channels,
|
35 |
+
kernel_size,
|
36 |
+
num_style_feat,
|
37 |
+
demodulate=True,
|
38 |
+
sample_mode=None,
|
39 |
+
eps=1e-8):
|
40 |
+
super(ModulatedConv2d, self).__init__()
|
41 |
+
self.in_channels = in_channels
|
42 |
+
self.out_channels = out_channels
|
43 |
+
self.kernel_size = kernel_size
|
44 |
+
self.demodulate = demodulate
|
45 |
+
self.sample_mode = sample_mode
|
46 |
+
self.eps = eps
|
47 |
+
|
48 |
+
# modulation inside each modulated conv
|
49 |
+
self.modulation = torch.nn.Linear(num_style_feat, in_channels, bias=True)
|
50 |
+
# initialization
|
51 |
+
# default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
|
52 |
+
|
53 |
+
self.weight = torch.nn.Parameter(
|
54 |
+
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
|
55 |
+
math.sqrt(in_channels * kernel_size**2))
|
56 |
+
self.padding = kernel_size // 2
|
57 |
+
|
58 |
+
def forward(self, x, style):
|
59 |
+
"""Forward function.
|
60 |
+
Args:
|
61 |
+
x (Tensor): Tensor with shape (b, c, h, w).
|
62 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
63 |
+
Returns:
|
64 |
+
Tensor: Modulated tensor after convolution.
|
65 |
+
"""
|
66 |
+
b, c, h, w = x.shape # c = c_in
|
67 |
+
# weight modulation
|
68 |
+
style = self.modulation(style).view(b, 1, c, 1, 1)
|
69 |
+
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
70 |
+
weight = self.weight * style # (b, c_out, c_in, k, k)
|
71 |
+
|
72 |
+
if self.demodulate:
|
73 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
74 |
+
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
75 |
+
|
76 |
+
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
77 |
+
|
78 |
+
# upsample or downsample if necessary
|
79 |
+
if self.sample_mode == 'upsample':
|
80 |
+
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
81 |
+
elif self.sample_mode == 'downsample':
|
82 |
+
x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
|
83 |
+
|
84 |
+
b, c, h, w = x.shape
|
85 |
+
x = x.view(1, b * c, h, w)
|
86 |
+
# weight: (b*c_out, c_in, k, k), groups=b
|
87 |
+
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
88 |
+
out = out.view(b, self.out_channels, *out.shape[2:4])
|
89 |
+
|
90 |
+
return out
|
91 |
+
|
92 |
+
def __repr__(self):
|
93 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
|
94 |
+
f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
95 |
+
|
96 |
+
|
97 |
+
class StyleConv(torch.nn.Module):
|
98 |
+
"""Style conv used in StyleGAN2.
|
99 |
+
Args:
|
100 |
+
in_channels (int): Channel number of the input.
|
101 |
+
out_channels (int): Channel number of the output.
|
102 |
+
kernel_size (int): Size of the convolving kernel.
|
103 |
+
num_style_feat (int): Channel number of style features.
|
104 |
+
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
105 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
|
109 |
+
super(StyleConv, self).__init__()
|
110 |
+
self.modulated_conv = ModulatedConv2d(
|
111 |
+
in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
|
112 |
+
self.weight = torch.nn.Parameter(torch.zeros(1)) # for noise injection
|
113 |
+
self.bias = torch.nn.Parameter(torch.zeros(1, out_channels, 1, 1))
|
114 |
+
self.activate = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
115 |
+
|
116 |
+
def forward(self, x, style, noise=None):
|
117 |
+
# modulate
|
118 |
+
out = self.modulated_conv(x, style) * 2**0.5 # for conversion
|
119 |
+
# noise injection
|
120 |
+
if noise is None:
|
121 |
+
b, _, h, w = out.shape
|
122 |
+
noise = out.new_empty(b, 1, h, w).normal_()
|
123 |
+
out = out + self.weight * noise
|
124 |
+
# add bias
|
125 |
+
out = out + self.bias
|
126 |
+
# activation
|
127 |
+
out = self.activate(out)
|
128 |
+
return out
|
129 |
+
|
130 |
+
|
131 |
+
class ToRGB(torch.nn.Module):
|
132 |
+
"""To RGB (image space) from features.
|
133 |
+
Args:
|
134 |
+
in_channels (int): Channel number of input.
|
135 |
+
num_style_feat (int): Channel number of style features.
|
136 |
+
upsample (bool): Whether to upsample. Default: True.
|
137 |
+
"""
|
138 |
+
|
139 |
+
def __init__(self, in_channels, num_style_feat, upsample=True):
|
140 |
+
super(ToRGB, self).__init__()
|
141 |
+
self.upsample = upsample
|
142 |
+
self.modulated_conv = ModulatedConv2d(
|
143 |
+
in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
|
144 |
+
self.bias = torch.nn.Parameter(torch.zeros(1, 3, 1, 1))
|
145 |
+
|
146 |
+
def forward(self, x, style, skip=None):
|
147 |
+
"""Forward function.
|
148 |
+
Args:
|
149 |
+
x (Tensor): Feature tensor with shape (b, c, h, w).
|
150 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
151 |
+
skip (Tensor): Base/skip tensor. Default: None.
|
152 |
+
Returns:
|
153 |
+
Tensor: RGB images.
|
154 |
+
"""
|
155 |
+
out = self.modulated_conv(x, style)
|
156 |
+
out = out + self.bias
|
157 |
+
if skip is not None:
|
158 |
+
if self.upsample:
|
159 |
+
skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
|
160 |
+
out = out + skip
|
161 |
+
return out
|
162 |
+
|
163 |
+
|
164 |
+
class ConstantInput(torch.nn.Module):
|
165 |
+
"""Constant input.
|
166 |
+
Args:
|
167 |
+
num_channel (int): Channel number of constant input.
|
168 |
+
size (int): Spatial size of constant input.
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, num_channel, size):
|
172 |
+
super(ConstantInput, self).__init__()
|
173 |
+
self.weight = torch.nn.Parameter(torch.randn(1, num_channel, size, size))
|
174 |
+
|
175 |
+
def forward(self, batch):
|
176 |
+
out = self.weight.repeat(batch, 1, 1, 1)
|
177 |
+
return out
|
178 |
+
|
179 |
+
|
180 |
+
class StyleGAN2GeneratorClean(torch.nn.Module):
|
181 |
+
"""Clean version of StyleGAN2 Generator.
|
182 |
+
Args:
|
183 |
+
out_size (int): The spatial size of outputs.
|
184 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
185 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
186 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
187 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
188 |
+
"""
|
189 |
+
|
190 |
+
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
|
191 |
+
super(StyleGAN2GeneratorClean, self).__init__()
|
192 |
+
# Style MLP layers
|
193 |
+
self.num_style_feat = num_style_feat
|
194 |
+
style_mlp_layers = [NormStyleCode()]
|
195 |
+
for i in range(num_mlp):
|
196 |
+
style_mlp_layers.extend(
|
197 |
+
[torch.nn.Linear(num_style_feat, num_style_feat, bias=True),
|
198 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)])
|
199 |
+
self.style_mlp = torch.nn.Sequential(*style_mlp_layers)
|
200 |
+
# initialization
|
201 |
+
# default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
|
202 |
+
|
203 |
+
# channel list
|
204 |
+
channels = {
|
205 |
+
'4': int(512 * narrow),
|
206 |
+
'8': int(512 * narrow),
|
207 |
+
'16': int(512 * narrow),
|
208 |
+
'32': int(512 * narrow),
|
209 |
+
'64': int(256 * channel_multiplier * narrow),
|
210 |
+
'128': int(128 * channel_multiplier * narrow),
|
211 |
+
'256': int(64 * channel_multiplier * narrow),
|
212 |
+
'512': int(32 * channel_multiplier * narrow),
|
213 |
+
'1024': int(16 * channel_multiplier * narrow)
|
214 |
+
}
|
215 |
+
self.channels = channels
|
216 |
+
|
217 |
+
self.constant_input = ConstantInput(channels['4'], size=4)
|
218 |
+
self.style_conv1 = StyleConv(
|
219 |
+
channels['4'],
|
220 |
+
channels['4'],
|
221 |
+
kernel_size=3,
|
222 |
+
num_style_feat=num_style_feat,
|
223 |
+
demodulate=True,
|
224 |
+
sample_mode=None)
|
225 |
+
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
|
226 |
+
|
227 |
+
self.log_size = int(math.log(out_size, 2))
|
228 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
229 |
+
self.num_latent = self.log_size * 2 - 2
|
230 |
+
|
231 |
+
self.style_convs = torch.nn.ModuleList()
|
232 |
+
self.to_rgbs = torch.nn.ModuleList()
|
233 |
+
self.noises = torch.nn.Module()
|
234 |
+
|
235 |
+
in_channels = channels['4']
|
236 |
+
# noise
|
237 |
+
for layer_idx in range(self.num_layers):
|
238 |
+
resolution = 2**((layer_idx + 5) // 2)
|
239 |
+
shape = [1, 1, resolution, resolution]
|
240 |
+
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
241 |
+
# style convs and to_rgbs
|
242 |
+
for i in range(3, self.log_size + 1):
|
243 |
+
out_channels = channels[f'{2**i}']
|
244 |
+
self.style_convs.append(
|
245 |
+
StyleConv(
|
246 |
+
in_channels,
|
247 |
+
out_channels,
|
248 |
+
kernel_size=3,
|
249 |
+
num_style_feat=num_style_feat,
|
250 |
+
demodulate=True,
|
251 |
+
sample_mode='upsample'))
|
252 |
+
self.style_convs.append(
|
253 |
+
StyleConv(
|
254 |
+
out_channels,
|
255 |
+
out_channels,
|
256 |
+
kernel_size=3,
|
257 |
+
num_style_feat=num_style_feat,
|
258 |
+
demodulate=True,
|
259 |
+
sample_mode=None))
|
260 |
+
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
|
261 |
+
in_channels = out_channels
|
262 |
+
|
263 |
+
def make_noise(self):
|
264 |
+
"""Make noise for noise injection."""
|
265 |
+
device = self.constant_input.weight.device
|
266 |
+
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
267 |
+
|
268 |
+
for i in range(3, self.log_size + 1):
|
269 |
+
for _ in range(2):
|
270 |
+
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
271 |
+
|
272 |
+
return noises
|
273 |
+
|
274 |
+
def get_latent(self, x):
|
275 |
+
return self.style_mlp(x)
|
276 |
+
|
277 |
+
def mean_latent(self, num_latent):
|
278 |
+
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
279 |
+
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
280 |
+
return latent
|
281 |
+
|
282 |
+
def forward(self,
|
283 |
+
styles,
|
284 |
+
input_is_latent=False,
|
285 |
+
noise=None,
|
286 |
+
randomize_noise=True,
|
287 |
+
truncation=1,
|
288 |
+
truncation_latent=None,
|
289 |
+
inject_index=None,
|
290 |
+
return_latents=False):
|
291 |
+
"""Forward function for StyleGAN2GeneratorClean.
|
292 |
+
Args:
|
293 |
+
styles (list[Tensor]): Sample codes of styles.
|
294 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
295 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
296 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
297 |
+
truncation (float): The truncation ratio. Default: 1.
|
298 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
299 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
300 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
301 |
+
"""
|
302 |
+
# style codes -> latents with Style MLP layer
|
303 |
+
if not input_is_latent:
|
304 |
+
styles = [self.style_mlp(s) for s in styles]
|
305 |
+
# noises
|
306 |
+
if noise is None:
|
307 |
+
if randomize_noise:
|
308 |
+
noise = [None] * self.num_layers # for each style conv layer
|
309 |
+
else: # use the stored noise
|
310 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
311 |
+
# style truncation
|
312 |
+
if truncation < 1:
|
313 |
+
style_truncation = []
|
314 |
+
for style in styles:
|
315 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
316 |
+
styles = style_truncation
|
317 |
+
# get style latents with injection
|
318 |
+
if len(styles) == 1:
|
319 |
+
inject_index = self.num_latent
|
320 |
+
|
321 |
+
if styles[0].ndim < 3:
|
322 |
+
# repeat latent code for all the layers
|
323 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
324 |
+
else: # used for encoder with different latent code for each layer
|
325 |
+
latent = styles[0]
|
326 |
+
elif len(styles) == 2: # mixing noises
|
327 |
+
if inject_index is None:
|
328 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
329 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
330 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
331 |
+
latent = torch.cat([latent1, latent2], 1)
|
332 |
+
|
333 |
+
# main generation
|
334 |
+
out = self.constant_input(latent.shape[0])
|
335 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
336 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
337 |
+
|
338 |
+
i = 1
|
339 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
340 |
+
noise[2::2], self.to_rgbs):
|
341 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
342 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
343 |
+
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
344 |
+
i += 2
|
345 |
+
|
346 |
+
image = skip
|
347 |
+
|
348 |
+
if return_latents:
|
349 |
+
return image, latent
|
350 |
+
else:
|
351 |
+
return image, None
|
src/PostProcess/ParsingModel/model.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from src.PostProcess.ParsingModel.resnet import Resnet18
|
10 |
+
|
11 |
+
from src.PostProcess.utils import encode_segmentation_rgb_batch
|
12 |
+
from typing import Tuple
|
13 |
+
|
14 |
+
|
15 |
+
class ConvBNReLU(nn.Module):
|
16 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
17 |
+
super(ConvBNReLU, self).__init__()
|
18 |
+
self.conv = nn.Conv2d(
|
19 |
+
in_chan,
|
20 |
+
out_chan,
|
21 |
+
kernel_size=ks,
|
22 |
+
stride=stride,
|
23 |
+
padding=padding,
|
24 |
+
bias=False,
|
25 |
+
)
|
26 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
27 |
+
self.init_weight()
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.conv(x)
|
31 |
+
x = F.relu(self.bn(x))
|
32 |
+
return x
|
33 |
+
|
34 |
+
def init_weight(self):
|
35 |
+
for ly in self.children():
|
36 |
+
if isinstance(ly, nn.Conv2d):
|
37 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
38 |
+
if ly.bias is not None:
|
39 |
+
nn.init.constant_(ly.bias, 0)
|
40 |
+
|
41 |
+
|
42 |
+
class BiSeNetOutput(nn.Module):
|
43 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
44 |
+
super(BiSeNetOutput, self).__init__()
|
45 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
46 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
47 |
+
self.init_weight()
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = self.conv(x)
|
51 |
+
x = self.conv_out(x)
|
52 |
+
return x
|
53 |
+
|
54 |
+
def init_weight(self):
|
55 |
+
for ly in self.children():
|
56 |
+
if isinstance(ly, nn.Conv2d):
|
57 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
58 |
+
if ly.bias is not None:
|
59 |
+
nn.init.constant_(ly.bias, 0)
|
60 |
+
|
61 |
+
def get_params(self):
|
62 |
+
wd_params, nowd_params = [], []
|
63 |
+
for name, module in self.named_modules():
|
64 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
65 |
+
wd_params.append(module.weight)
|
66 |
+
if module.bias is not None:
|
67 |
+
nowd_params.append(module.bias)
|
68 |
+
elif isinstance(module, nn.BatchNorm2d):
|
69 |
+
nowd_params += list(module.parameters())
|
70 |
+
return wd_params, nowd_params
|
71 |
+
|
72 |
+
|
73 |
+
class AttentionRefinementModule(nn.Module):
|
74 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
75 |
+
super(AttentionRefinementModule, self).__init__()
|
76 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
77 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
|
78 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
79 |
+
self.sigmoid_atten = nn.Sigmoid()
|
80 |
+
self.init_weight()
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
feat = self.conv(x)
|
84 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
85 |
+
atten = self.conv_atten(atten)
|
86 |
+
atten = self.bn_atten(atten)
|
87 |
+
atten = self.sigmoid_atten(atten)
|
88 |
+
out = torch.mul(feat, atten)
|
89 |
+
return out
|
90 |
+
|
91 |
+
def init_weight(self):
|
92 |
+
for ly in self.children():
|
93 |
+
if isinstance(ly, nn.Conv2d):
|
94 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
95 |
+
if ly.bias is not None:
|
96 |
+
nn.init.constant_(ly.bias, 0)
|
97 |
+
|
98 |
+
|
99 |
+
class ContextPath(nn.Module):
|
100 |
+
def __init__(self, *args, **kwargs):
|
101 |
+
super(ContextPath, self).__init__()
|
102 |
+
self.resnet = Resnet18()
|
103 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
104 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
105 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
106 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
107 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
108 |
+
|
109 |
+
self.init_weight()
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
H0, W0 = x.size()[2:]
|
113 |
+
feat8, feat16, feat32 = self.resnet(x)
|
114 |
+
H8, W8 = feat8.size()[2:]
|
115 |
+
H16, W16 = feat16.size()[2:]
|
116 |
+
H32, W32 = feat32.size()[2:]
|
117 |
+
|
118 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
119 |
+
avg = self.conv_avg(avg)
|
120 |
+
avg_up = F.interpolate(avg, (H32, W32), mode="nearest")
|
121 |
+
|
122 |
+
feat32_arm = self.arm32(feat32)
|
123 |
+
feat32_sum = feat32_arm + avg_up
|
124 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode="nearest")
|
125 |
+
feat32_up = self.conv_head32(feat32_up)
|
126 |
+
|
127 |
+
feat16_arm = self.arm16(feat16)
|
128 |
+
feat16_sum = feat16_arm + feat32_up
|
129 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode="nearest")
|
130 |
+
feat16_up = self.conv_head16(feat16_up)
|
131 |
+
|
132 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
133 |
+
|
134 |
+
def init_weight(self):
|
135 |
+
for ly in self.children():
|
136 |
+
if isinstance(ly, nn.Conv2d):
|
137 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
138 |
+
if ly.bias is not None:
|
139 |
+
nn.init.constant_(ly.bias, 0)
|
140 |
+
|
141 |
+
def get_params(self):
|
142 |
+
wd_params, nowd_params = [], []
|
143 |
+
for name, module in self.named_modules():
|
144 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
145 |
+
wd_params.append(module.weight)
|
146 |
+
if module.bias is not None:
|
147 |
+
nowd_params.append(module.bias)
|
148 |
+
elif isinstance(module, nn.BatchNorm2d):
|
149 |
+
nowd_params += list(module.parameters())
|
150 |
+
return wd_params, nowd_params
|
151 |
+
|
152 |
+
|
153 |
+
# This is not used, since I replace this with the resnet feature with the same size
|
154 |
+
class SpatialPath(nn.Module):
|
155 |
+
def __init__(self, *args, **kwargs):
|
156 |
+
super(SpatialPath, self).__init__()
|
157 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
158 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
159 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
160 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
161 |
+
self.init_weight()
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
feat = self.conv1(x)
|
165 |
+
feat = self.conv2(feat)
|
166 |
+
feat = self.conv3(feat)
|
167 |
+
feat = self.conv_out(feat)
|
168 |
+
return feat
|
169 |
+
|
170 |
+
def init_weight(self):
|
171 |
+
for ly in self.children():
|
172 |
+
if isinstance(ly, nn.Conv2d):
|
173 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
174 |
+
if ly.bias is not None:
|
175 |
+
nn.init.constant_(ly.bias, 0)
|
176 |
+
|
177 |
+
def get_params(self):
|
178 |
+
wd_params, nowd_params = [], []
|
179 |
+
for name, module in self.named_modules():
|
180 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
181 |
+
wd_params.append(module.weight)
|
182 |
+
if module.bias is not None:
|
183 |
+
nowd_params.append(module.bias)
|
184 |
+
elif isinstance(module, nn.BatchNorm2d):
|
185 |
+
nowd_params += list(module.parameters())
|
186 |
+
return wd_params, nowd_params
|
187 |
+
|
188 |
+
|
189 |
+
class FeatureFusionModule(nn.Module):
|
190 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
191 |
+
super(FeatureFusionModule, self).__init__()
|
192 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
193 |
+
self.conv1 = nn.Conv2d(
|
194 |
+
out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False
|
195 |
+
)
|
196 |
+
self.conv2 = nn.Conv2d(
|
197 |
+
out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False
|
198 |
+
)
|
199 |
+
self.relu = nn.ReLU(inplace=True)
|
200 |
+
self.sigmoid = nn.Sigmoid()
|
201 |
+
self.init_weight()
|
202 |
+
|
203 |
+
def forward(self, fsp, fcp):
|
204 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
205 |
+
feat = self.convblk(fcat)
|
206 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
207 |
+
atten = self.conv1(atten)
|
208 |
+
atten = self.relu(atten)
|
209 |
+
atten = self.conv2(atten)
|
210 |
+
atten = self.sigmoid(atten)
|
211 |
+
feat_atten = torch.mul(feat, atten)
|
212 |
+
feat_out = feat_atten + feat
|
213 |
+
return feat_out
|
214 |
+
|
215 |
+
def init_weight(self):
|
216 |
+
for ly in self.children():
|
217 |
+
if isinstance(ly, nn.Conv2d):
|
218 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
219 |
+
if ly.bias is not None:
|
220 |
+
nn.init.constant_(ly.bias, 0)
|
221 |
+
|
222 |
+
def get_params(self):
|
223 |
+
wd_params, nowd_params = [], []
|
224 |
+
for name, module in self.named_modules():
|
225 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
226 |
+
wd_params.append(module.weight)
|
227 |
+
if module.bias is not None:
|
228 |
+
nowd_params.append(module.bias)
|
229 |
+
elif isinstance(module, nn.BatchNorm2d):
|
230 |
+
nowd_params += list(module.parameters())
|
231 |
+
return wd_params, nowd_params
|
232 |
+
|
233 |
+
|
234 |
+
class BiSeNet(nn.Module):
|
235 |
+
def __init__(self, n_classes, *args, **kwargs):
|
236 |
+
super(BiSeNet, self).__init__()
|
237 |
+
self.cp = ContextPath()
|
238 |
+
# here self.sp is deleted
|
239 |
+
self.ffm = FeatureFusionModule(256, 256)
|
240 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
241 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
242 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
243 |
+
self.init_weight()
|
244 |
+
|
245 |
+
def get_mask(
|
246 |
+
self, x: torch.Tensor, crop_size: int
|
247 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
248 |
+
x = F.interpolate(x, size=(512, 512))
|
249 |
+
|
250 |
+
parsed_face = self.forward(x)[0]
|
251 |
+
|
252 |
+
parsed_face = torch.argmax(parsed_face, dim=1, keepdim=True)
|
253 |
+
|
254 |
+
parsed_face = encode_segmentation_rgb_batch(parsed_face)
|
255 |
+
|
256 |
+
parsed_face = torch.where(
|
257 |
+
torch.sum(parsed_face, dim=[1, 2, 3], keepdim=True) > 5000,
|
258 |
+
parsed_face,
|
259 |
+
torch.zeros_like(parsed_face),
|
260 |
+
)
|
261 |
+
|
262 |
+
ignore_mask_ids = torch.sum(parsed_face, dim=[1, 2, 3]) == 0
|
263 |
+
|
264 |
+
parsed_face = parsed_face.float().mul_(1 / 255.0)
|
265 |
+
|
266 |
+
parsed_face = F.interpolate(
|
267 |
+
parsed_face, size=(crop_size, crop_size), mode="bilinear"
|
268 |
+
)
|
269 |
+
|
270 |
+
parsed_face = torch.sum(parsed_face, dim=1, keepdim=True)
|
271 |
+
|
272 |
+
return parsed_face, ignore_mask_ids
|
273 |
+
|
274 |
+
def forward(self, x):
|
275 |
+
H, W = x.size()[2:]
|
276 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
277 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
278 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
279 |
+
|
280 |
+
feat_out = self.conv_out(feat_fuse)
|
281 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
282 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
283 |
+
|
284 |
+
feat_out = F.interpolate(feat_out, (H, W), mode="bilinear", align_corners=True)
|
285 |
+
feat_out16 = F.interpolate(
|
286 |
+
feat_out16, (H, W), mode="bilinear", align_corners=True
|
287 |
+
)
|
288 |
+
feat_out32 = F.interpolate(
|
289 |
+
feat_out32, (H, W), mode="bilinear", align_corners=True
|
290 |
+
)
|
291 |
+
return feat_out, feat_out16, feat_out32
|
292 |
+
|
293 |
+
def init_weight(self):
|
294 |
+
for ly in self.children():
|
295 |
+
if isinstance(ly, nn.Conv2d):
|
296 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
297 |
+
if ly.bias is not None:
|
298 |
+
nn.init.constant_(ly.bias, 0)
|
299 |
+
|
300 |
+
def get_params(self):
|
301 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
302 |
+
for name, child in self.named_children():
|
303 |
+
child_wd_params, child_nowd_params = child.get_params()
|
304 |
+
if isinstance(child, FeatureFusionModule) or isinstance(
|
305 |
+
child, BiSeNetOutput
|
306 |
+
):
|
307 |
+
lr_mul_wd_params += child_wd_params
|
308 |
+
lr_mul_nowd_params += child_nowd_params
|
309 |
+
else:
|
310 |
+
wd_params += child_wd_params
|
311 |
+
nowd_params += child_nowd_params
|
312 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
313 |
+
|
314 |
+
|
315 |
+
if __name__ == "__main__":
|
316 |
+
net = BiSeNet(19)
|
317 |
+
net.cuda()
|
318 |
+
net.eval()
|
319 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
320 |
+
out, out16, out32 = net(in_ten)
|
321 |
+
print(out.shape)
|
322 |
+
|
323 |
+
net.get_params()
|
src/PostProcess/ParsingModel/resnet.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.model_zoo as modelzoo
|
8 |
+
|
9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
10 |
+
|
11 |
+
resnet18_url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(
|
17 |
+
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class BasicBlock(nn.Module):
|
22 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
23 |
+
super(BasicBlock, self).__init__()
|
24 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
25 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
26 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
27 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
28 |
+
self.relu = nn.ReLU(inplace=True)
|
29 |
+
self.downsample = None
|
30 |
+
if in_chan != out_chan or stride != 1:
|
31 |
+
self.downsample = nn.Sequential(
|
32 |
+
nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
|
33 |
+
nn.BatchNorm2d(out_chan),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = self.conv1(x)
|
38 |
+
residual = F.relu(self.bn1(residual))
|
39 |
+
residual = self.conv2(residual)
|
40 |
+
residual = self.bn2(residual)
|
41 |
+
|
42 |
+
shortcut = x
|
43 |
+
if self.downsample is not None:
|
44 |
+
shortcut = self.downsample(x)
|
45 |
+
|
46 |
+
out = shortcut + residual
|
47 |
+
out = self.relu(out)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
53 |
+
for i in range(bnum - 1):
|
54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
|
58 |
+
class Resnet18(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(Resnet18, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
62 |
+
self.bn1 = nn.BatchNorm2d(64)
|
63 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
64 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
65 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
66 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
67 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
68 |
+
self.init_weight()
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
x = self.conv1(x)
|
72 |
+
x = F.relu(self.bn1(x))
|
73 |
+
x = self.maxpool(x)
|
74 |
+
|
75 |
+
x = self.layer1(x)
|
76 |
+
feat8 = self.layer2(x) # 1/8
|
77 |
+
feat16 = self.layer3(feat8) # 1/16
|
78 |
+
feat32 = self.layer4(feat16) # 1/32
|
79 |
+
return feat8, feat16, feat32
|
80 |
+
|
81 |
+
def init_weight(self):
|
82 |
+
state_dict = modelzoo.load_url(resnet18_url)
|
83 |
+
self_state_dict = self.state_dict()
|
84 |
+
for k, v in state_dict.items():
|
85 |
+
if "fc" in k:
|
86 |
+
continue
|
87 |
+
self_state_dict.update({k: v})
|
88 |
+
self.load_state_dict(self_state_dict)
|
89 |
+
|
90 |
+
def get_params(self):
|
91 |
+
wd_params, nowd_params = [], []
|
92 |
+
for name, module in self.named_modules():
|
93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
94 |
+
wd_params.append(module.weight)
|
95 |
+
if module.bias is not None:
|
96 |
+
nowd_params.append(module.bias)
|
97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
98 |
+
nowd_params += list(module.parameters())
|
99 |
+
return wd_params, nowd_params
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
net = Resnet18()
|
104 |
+
x = torch.randn(16, 3, 224, 224)
|
105 |
+
out = net(x)
|
106 |
+
print(out[0].size())
|
107 |
+
print(out[1].size())
|
108 |
+
print(out[2].size())
|
109 |
+
net.get_params()
|
src/PostProcess/utils.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
|
8 |
+
class SoftErosion(torch.nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1
|
11 |
+
):
|
12 |
+
super(SoftErosion, self).__init__()
|
13 |
+
r = kernel_size // 2
|
14 |
+
self.padding = r
|
15 |
+
self.iterations = iterations
|
16 |
+
self.threshold = threshold
|
17 |
+
|
18 |
+
# Create kernel
|
19 |
+
y_indices, x_indices = torch.meshgrid(
|
20 |
+
torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size)
|
21 |
+
)
|
22 |
+
dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
|
23 |
+
kernel = dist.max() - dist
|
24 |
+
kernel /= kernel.sum()
|
25 |
+
kernel = kernel.view(1, 1, *kernel.shape)
|
26 |
+
self.register_buffer("weight", kernel)
|
27 |
+
|
28 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
29 |
+
for i in range(self.iterations - 1):
|
30 |
+
x = torch.min(
|
31 |
+
x,
|
32 |
+
F.conv2d(
|
33 |
+
x, weight=self.weight, groups=x.shape[1], padding=self.padding
|
34 |
+
),
|
35 |
+
)
|
36 |
+
x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
|
37 |
+
|
38 |
+
mask = x >= self.threshold
|
39 |
+
|
40 |
+
x[mask] = 1.0
|
41 |
+
# add small epsilon to avoid Nans
|
42 |
+
x[~mask] /= (x[~mask].max() + 1e-7)
|
43 |
+
|
44 |
+
return x, mask
|
45 |
+
|
46 |
+
|
47 |
+
def encode_segmentation_rgb(
|
48 |
+
segmentation: np.ndarray, no_neck: bool = True
|
49 |
+
) -> np.ndarray:
|
50 |
+
parse = segmentation
|
51 |
+
# https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py
|
52 |
+
face_part_ids = (
|
53 |
+
[1, 2, 3, 4, 5, 6, 10, 12, 13]
|
54 |
+
if no_neck
|
55 |
+
else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
|
56 |
+
)
|
57 |
+
mouth_id = 11
|
58 |
+
# hair_id = 17
|
59 |
+
face_map = np.zeros([parse.shape[0], parse.shape[1]])
|
60 |
+
mouth_map = np.zeros([parse.shape[0], parse.shape[1]])
|
61 |
+
# hair_map = np.zeros([parse.shape[0], parse.shape[1]])
|
62 |
+
|
63 |
+
for valid_id in face_part_ids:
|
64 |
+
valid_index = np.where(parse == valid_id)
|
65 |
+
face_map[valid_index] = 255
|
66 |
+
valid_index = np.where(parse == mouth_id)
|
67 |
+
mouth_map[valid_index] = 255
|
68 |
+
# valid_index = np.where(parse==hair_id)
|
69 |
+
# hair_map[valid_index] = 255
|
70 |
+
# return np.stack([face_map, mouth_map,hair_map], axis=2)
|
71 |
+
return np.stack([face_map, mouth_map], axis=2)
|
72 |
+
|
73 |
+
|
74 |
+
def encode_segmentation_rgb_batch(
|
75 |
+
segmentation: torch.Tensor, no_neck: bool = True
|
76 |
+
) -> torch.Tensor:
|
77 |
+
# https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py
|
78 |
+
face_part_ids = (
|
79 |
+
[1, 2, 3, 4, 5, 6, 10, 12, 13]
|
80 |
+
if no_neck
|
81 |
+
else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
|
82 |
+
)
|
83 |
+
mouth_id = 11
|
84 |
+
# hair_id = 17
|
85 |
+
segmentation = segmentation.int()
|
86 |
+
face_map = torch.zeros_like(segmentation)
|
87 |
+
mouth_map = torch.zeros_like(segmentation)
|
88 |
+
# hair_map = np.zeros([parse.shape[0], parse.shape[1]])
|
89 |
+
|
90 |
+
white_tensor = face_map + 255
|
91 |
+
for valid_id in face_part_ids:
|
92 |
+
face_map = torch.where(segmentation == valid_id, white_tensor, face_map)
|
93 |
+
mouth_map = torch.where(segmentation == mouth_id, white_tensor, mouth_map)
|
94 |
+
|
95 |
+
return torch.cat([face_map, mouth_map], dim=1)
|
96 |
+
|
97 |
+
|
98 |
+
def postprocess(
|
99 |
+
swapped_face: np.ndarray,
|
100 |
+
target: np.ndarray,
|
101 |
+
target_mask: np.ndarray,
|
102 |
+
smooth_mask: torch.nn.Module,
|
103 |
+
) -> np.ndarray:
|
104 |
+
# target_mask = cv2.resize(target_mask, (self.size, self.size))
|
105 |
+
|
106 |
+
mask_tensor = (
|
107 |
+
torch.from_numpy(target_mask.copy().transpose((2, 0, 1)))
|
108 |
+
.float()
|
109 |
+
.mul_(1 / 255.0)
|
110 |
+
.cuda()
|
111 |
+
)
|
112 |
+
face_mask_tensor = mask_tensor[0] + mask_tensor[1]
|
113 |
+
|
114 |
+
soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
|
115 |
+
soft_face_mask_tensor.squeeze_()
|
116 |
+
|
117 |
+
soft_face_mask = soft_face_mask_tensor.cpu().numpy()
|
118 |
+
soft_face_mask = soft_face_mask[:, :, np.newaxis]
|
119 |
+
|
120 |
+
result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
|
121 |
+
result = result[:, :, ::-1] # .astype(np.uint8)
|
122 |
+
return result
|
src/model_loader.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torch.utils import model_zoo
|
4 |
+
import requests
|
5 |
+
from tqdm import tqdm
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
from src.FaceDetector.face_detector import FaceDetector
|
9 |
+
from src.FaceId.faceid import FaceId
|
10 |
+
from src.Generator.fs_networks_fix import Generator_Adain_Upsample
|
11 |
+
from src.PostProcess.ParsingModel.model import BiSeNet
|
12 |
+
from src.PostProcess.GFPGAN.gfpgan import GFPGANer
|
13 |
+
from src.Blend.blend import BlendModule
|
14 |
+
|
15 |
+
|
16 |
+
model = namedtuple("model", ["url", "model"])
|
17 |
+
|
18 |
+
models = {
|
19 |
+
"face_detector": model(
|
20 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx",
|
21 |
+
model=FaceDetector,
|
22 |
+
),
|
23 |
+
"arcface": model(
|
24 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit",
|
25 |
+
model=FaceId,
|
26 |
+
),
|
27 |
+
"generator_224": model(
|
28 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth",
|
29 |
+
model=Generator_Adain_Upsample,
|
30 |
+
),
|
31 |
+
"generator_512": model(
|
32 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth",
|
33 |
+
model=Generator_Adain_Upsample,
|
34 |
+
),
|
35 |
+
"parsing_model": model(
|
36 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth",
|
37 |
+
model=BiSeNet,
|
38 |
+
),
|
39 |
+
"gfpgan": model(
|
40 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth",
|
41 |
+
model=GFPGANer,
|
42 |
+
),
|
43 |
+
"blend_module": model(
|
44 |
+
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit",
|
45 |
+
model=BlendModule
|
46 |
+
)
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
def get_model(
|
51 |
+
model_name: str,
|
52 |
+
device: torch.device,
|
53 |
+
load_state_dice: bool,
|
54 |
+
model_path: Path,
|
55 |
+
**kwargs,
|
56 |
+
):
|
57 |
+
dst_dir = Path.cwd() / "weights"
|
58 |
+
dst_dir.mkdir(exist_ok=True)
|
59 |
+
|
60 |
+
url = models[model_name].url if not model_path.is_file() else str(model_path)
|
61 |
+
|
62 |
+
if load_state_dice:
|
63 |
+
model = models[model_name].model(**kwargs)
|
64 |
+
|
65 |
+
if Path(url).is_file():
|
66 |
+
state_dict = torch.load(url)
|
67 |
+
else:
|
68 |
+
state_dict = model_zoo.load_url(
|
69 |
+
url,
|
70 |
+
model_dir=str(dst_dir),
|
71 |
+
progress=True,
|
72 |
+
map_location="cpu",
|
73 |
+
)
|
74 |
+
|
75 |
+
model.load_state_dict(state_dict)
|
76 |
+
|
77 |
+
model.to(device)
|
78 |
+
model.eval()
|
79 |
+
else:
|
80 |
+
dst_path = Path(url)
|
81 |
+
|
82 |
+
if not dst_path.is_file():
|
83 |
+
dst_path = dst_dir / Path(url).name
|
84 |
+
|
85 |
+
if not dst_path.is_file():
|
86 |
+
print(f"Downloading: '{url}' to {dst_path}")
|
87 |
+
response = requests.get(url, stream=True)
|
88 |
+
if int(response.status_code) == 200:
|
89 |
+
file_size = int(response.headers["Content-Length"]) / (2 ** 20)
|
90 |
+
chunk_size = 1024
|
91 |
+
bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n:3.1f}M/{total:3.1f}M [{elapsed}<{remaining}]"
|
92 |
+
with open(dst_path, "wb") as handle:
|
93 |
+
with tqdm(total=file_size, bar_format=bar_format) as pbar:
|
94 |
+
for data in response.iter_content(chunk_size=chunk_size):
|
95 |
+
handle.write(data)
|
96 |
+
pbar.update(len(data) / (2 ** 20))
|
97 |
+
else:
|
98 |
+
raise ValueError(
|
99 |
+
f"Couldn't download weights {url}. Specify weights for the '{model_name}' model manually."
|
100 |
+
)
|
101 |
+
|
102 |
+
kwargs.update({"model_path": str(dst_path), "device": device})
|
103 |
+
|
104 |
+
model = models[model_name].model(**kwargs)
|
105 |
+
|
106 |
+
return model
|
src/simswap.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from typing import Iterable, Tuple, Union
|
5 |
+
from pathlib import Path
|
6 |
+
from torchvision import transforms
|
7 |
+
import kornia
|
8 |
+
from omegaconf import DictConfig
|
9 |
+
|
10 |
+
from src.FaceDetector.face_detector import Detection
|
11 |
+
from src.FaceAlign.face_align import align_face, inverse_transform_batch
|
12 |
+
from src.PostProcess.utils import SoftErosion
|
13 |
+
from src.model_loader import get_model
|
14 |
+
from src.Misc.types import CheckpointType, FaceAlignmentType
|
15 |
+
from src.Misc.utils import tensor2img
|
16 |
+
|
17 |
+
|
18 |
+
class SimSwap:
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
config: DictConfig,
|
22 |
+
id_image: Union[np.ndarray, None] = None,
|
23 |
+
specific_image: Union[np.ndarray, None] = None,
|
24 |
+
):
|
25 |
+
|
26 |
+
self.id_image: Union[np.ndarray, None] = id_image
|
27 |
+
self.id_latent: Union[torch.Tensor, None] = None
|
28 |
+
self.specific_id_image: Union[np.ndarray, None] = specific_image
|
29 |
+
self.specific_latent: Union[torch.Tensor, None] = None
|
30 |
+
|
31 |
+
self.use_mask: Union[bool, None] = True
|
32 |
+
self.crop_size: Union[int, None] = None
|
33 |
+
self.checkpoint_type: Union[CheckpointType, None] = None
|
34 |
+
self.face_alignment_type: Union[FaceAlignmentType, None] = None
|
35 |
+
self.smooth_mask_iter: Union[int, None] = None
|
36 |
+
self.smooth_mask_kernel_size: Union[int, None] = None
|
37 |
+
self.smooth_mask_threshold: Union[float, None] = None
|
38 |
+
self.face_detector_threshold: Union[float, None] = None
|
39 |
+
self.specific_latent_match_threshold: Union[float, None] = None
|
40 |
+
self.device = torch.device(config.device)
|
41 |
+
|
42 |
+
self.set_parameters(config)
|
43 |
+
|
44 |
+
# For BiSeNet and for official_224 SimSwap
|
45 |
+
self.to_tensor_normalize = transforms.Compose(
|
46 |
+
[
|
47 |
+
transforms.ToTensor(),
|
48 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
49 |
+
]
|
50 |
+
)
|
51 |
+
|
52 |
+
# For SimSwap models trained with the updated code
|
53 |
+
self.to_tensor = transforms.ToTensor()
|
54 |
+
|
55 |
+
self.face_detector = get_model(
|
56 |
+
"face_detector",
|
57 |
+
device=self.device,
|
58 |
+
load_state_dice=False,
|
59 |
+
model_path=Path(config.face_detector_weights),
|
60 |
+
det_thresh=self.face_detector_threshold,
|
61 |
+
det_size=(640, 640),
|
62 |
+
mode="ffhq",
|
63 |
+
)
|
64 |
+
|
65 |
+
self.face_id_net = get_model(
|
66 |
+
"arcface",
|
67 |
+
device=self.device,
|
68 |
+
load_state_dice=False,
|
69 |
+
model_path=Path(config.face_id_weights),
|
70 |
+
)
|
71 |
+
|
72 |
+
self.bise_net = get_model(
|
73 |
+
"parsing_model",
|
74 |
+
device=self.device,
|
75 |
+
load_state_dice=True,
|
76 |
+
model_path=Path(config.parsing_model_weights),
|
77 |
+
n_classes=19,
|
78 |
+
)
|
79 |
+
|
80 |
+
gen_model = "generator_512" if self.crop_size == 512 else "generator_224"
|
81 |
+
self.simswap_net = get_model(
|
82 |
+
gen_model,
|
83 |
+
device=self.device,
|
84 |
+
load_state_dice=True,
|
85 |
+
model_path=Path(config.simswap_weights),
|
86 |
+
input_nc=3,
|
87 |
+
output_nc=3,
|
88 |
+
latent_size=512,
|
89 |
+
n_blocks=9,
|
90 |
+
deep=True if self.crop_size == 512 else False,
|
91 |
+
use_last_act=True
|
92 |
+
if self.checkpoint_type == CheckpointType.OFFICIAL_224
|
93 |
+
else False,
|
94 |
+
)
|
95 |
+
|
96 |
+
self.blend = get_model(
|
97 |
+
"blend_module",
|
98 |
+
device=self.device,
|
99 |
+
load_state_dice=False,
|
100 |
+
model_path=Path(config.blend_module_weights)
|
101 |
+
)
|
102 |
+
|
103 |
+
self.enhance_output = config.enhance_output
|
104 |
+
if config.enhance_output:
|
105 |
+
self.gfpgan_net = get_model(
|
106 |
+
"gfpgan",
|
107 |
+
device=self.device,
|
108 |
+
load_state_dice=True,
|
109 |
+
model_path=Path(config.gfpgan_weights)
|
110 |
+
)
|
111 |
+
|
112 |
+
def set_parameters(self, config) -> None:
|
113 |
+
self.set_crop_size(config.crop_size)
|
114 |
+
self.set_checkpoint_type(config.checkpoint_type)
|
115 |
+
self.set_face_alignment_type(config.face_alignment_type)
|
116 |
+
self.set_face_detector_threshold(config.face_detector_threshold)
|
117 |
+
self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
|
118 |
+
self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size)
|
119 |
+
self.set_smooth_mask_threshold(config.smooth_mask_threshold)
|
120 |
+
self.set_smooth_mask_iter(config.smooth_mask_iter)
|
121 |
+
|
122 |
+
def set_crop_size(self, crop_size: int) -> None:
|
123 |
+
if crop_size < 0:
|
124 |
+
raise "Invalid crop_size! Must be a positive value."
|
125 |
+
|
126 |
+
self.crop_size = crop_size
|
127 |
+
|
128 |
+
def set_checkpoint_type(self, checkpoint_type: str) -> None:
|
129 |
+
type = CheckpointType(checkpoint_type)
|
130 |
+
if type not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
|
131 |
+
raise "Invalid checkpoint_type! Must be one of the predefined values."
|
132 |
+
|
133 |
+
self.checkpoint_type = type
|
134 |
+
|
135 |
+
def set_face_alignment_type(self, face_alignment_type: str) -> None:
|
136 |
+
type = FaceAlignmentType(face_alignment_type)
|
137 |
+
if type not in (
|
138 |
+
FaceAlignmentType.FFHQ,
|
139 |
+
FaceAlignmentType.DEFAULT,
|
140 |
+
):
|
141 |
+
raise "Invalid face_alignment_type! Must be one of the predefined values."
|
142 |
+
|
143 |
+
self.face_alignment_type = type
|
144 |
+
|
145 |
+
def set_face_detector_threshold(self, face_detector_threshold: float) -> None:
|
146 |
+
if face_detector_threshold < 0.0 or face_detector_threshold > 1.0:
|
147 |
+
raise "Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
|
148 |
+
|
149 |
+
self.face_detector_threshold = face_detector_threshold
|
150 |
+
|
151 |
+
def set_specific_latent_match_threshold(
|
152 |
+
self, specific_latent_match_threshold: float
|
153 |
+
) -> None:
|
154 |
+
if specific_latent_match_threshold < 0.0:
|
155 |
+
raise "Invalid specific_latent_match_th! Must be a positive value."
|
156 |
+
|
157 |
+
self.specific_latent_match_threshold = specific_latent_match_threshold
|
158 |
+
|
159 |
+
def re_initialize_soft_mask(self):
|
160 |
+
self.smooth_mask = SoftErosion(kernel_size=self.smooth_mask_kernel_size,
|
161 |
+
threshold=self.smooth_mask_threshold,
|
162 |
+
iterations=self.smooth_mask_iter).to(self.device)
|
163 |
+
|
164 |
+
def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None:
|
165 |
+
if smooth_mask_kernel_size < 0:
|
166 |
+
raise "Invalid smooth_mask_kernel_size! Must be a positive value."
|
167 |
+
smooth_mask_kernel_size += 1 if smooth_mask_kernel_size % 2 == 0 else 0
|
168 |
+
self.smooth_mask_kernel_size = smooth_mask_kernel_size
|
169 |
+
self.re_initialize_soft_mask()
|
170 |
+
|
171 |
+
def set_smooth_mask_threshold(self, smooth_mask_threshold: int) -> None:
|
172 |
+
if smooth_mask_threshold < 0 or smooth_mask_threshold > 1.0:
|
173 |
+
raise "Invalid smooth_mask_threshold! Must be within 0...1 range."
|
174 |
+
self.smooth_mask_threshold = smooth_mask_threshold
|
175 |
+
self.re_initialize_soft_mask()
|
176 |
+
|
177 |
+
def set_smooth_mask_iter(self, smooth_mask_iter: float) -> None:
|
178 |
+
if smooth_mask_iter < 0:
|
179 |
+
raise "Invalid smooth_mask_iter! Must be a positive value.."
|
180 |
+
self.smooth_mask_iter = smooth_mask_iter
|
181 |
+
self.re_initialize_soft_mask()
|
182 |
+
|
183 |
+
def run_detect_align(self, image: np.ndarray, for_id: bool = False) -> Tuple[Union[Iterable[np.ndarray], None],
|
184 |
+
Union[Iterable[np.ndarray], None],
|
185 |
+
np.ndarray]:
|
186 |
+
detection: Detection = self.face_detector(image)
|
187 |
+
|
188 |
+
if detection.bbox is None:
|
189 |
+
if for_id:
|
190 |
+
raise "Can't detect a face! Please change the ID image!"
|
191 |
+
return None, None, detection.score
|
192 |
+
|
193 |
+
kps = detection.key_points
|
194 |
+
|
195 |
+
if for_id:
|
196 |
+
max_score_ind = np.argmax(detection.score, axis=0)
|
197 |
+
kps = detection.key_points[max_score_ind]
|
198 |
+
kps = kps[None, ...]
|
199 |
+
|
200 |
+
align_imgs, transforms = align_face(
|
201 |
+
image,
|
202 |
+
kps,
|
203 |
+
crop_size=self.crop_size,
|
204 |
+
mode="ffhq"
|
205 |
+
if self.face_alignment_type == FaceAlignmentType.FFHQ
|
206 |
+
else "none",
|
207 |
+
)
|
208 |
+
|
209 |
+
return align_imgs, transforms, detection.score
|
210 |
+
|
211 |
+
def __call__(self, att_image: np.ndarray) -> np.ndarray:
|
212 |
+
if self.id_latent is None:
|
213 |
+
align_id_imgs, id_transforms, _ = self.run_detect_align(
|
214 |
+
self.id_image, for_id=True
|
215 |
+
)
|
216 |
+
# normalize=True, because official SimSwap model trained with normalized id_lattent
|
217 |
+
self.id_latent: torch.Tensor = self.face_id_net(
|
218 |
+
align_id_imgs, normalize=True
|
219 |
+
)
|
220 |
+
|
221 |
+
if self.specific_id_image is not None and self.specific_latent is None:
|
222 |
+
align_specific_imgs, specific_transforms, _ = self.run_detect_align(
|
223 |
+
self.specific_id_image, for_id=True
|
224 |
+
)
|
225 |
+
self.specific_latent: torch.Tensor = self.face_id_net(
|
226 |
+
align_specific_imgs, normalize=False
|
227 |
+
)
|
228 |
+
|
229 |
+
# for_id=False, because we want to get all faces
|
230 |
+
align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
|
231 |
+
att_image, for_id=False
|
232 |
+
)
|
233 |
+
|
234 |
+
if align_att_imgs is None and att_transforms is None:
|
235 |
+
return att_image
|
236 |
+
|
237 |
+
# Select specific crop from the target image
|
238 |
+
if self.specific_latent is not None:
|
239 |
+
att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
|
240 |
+
latent_dist = torch.mean(
|
241 |
+
F.mse_loss(
|
242 |
+
att_latent,
|
243 |
+
self.specific_latent.repeat(att_latent.shape[0], 1),
|
244 |
+
reduction="none",
|
245 |
+
),
|
246 |
+
dim=-1,
|
247 |
+
)
|
248 |
+
|
249 |
+
att_detection_score = torch.tensor(
|
250 |
+
att_detection_score, device=latent_dist.device
|
251 |
+
)
|
252 |
+
|
253 |
+
min_index = torch.argmin(latent_dist * att_detection_score)
|
254 |
+
min_value = latent_dist[min_index]
|
255 |
+
|
256 |
+
if min_value < self.specific_latent_match_threshold:
|
257 |
+
align_att_imgs = [align_att_imgs[min_index]]
|
258 |
+
att_transforms = [att_transforms[min_index]]
|
259 |
+
else:
|
260 |
+
return att_image
|
261 |
+
|
262 |
+
swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)
|
263 |
+
|
264 |
+
if self.enhance_output:
|
265 |
+
swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)
|
266 |
+
|
267 |
+
# Put all crops/transformations into a batch
|
268 |
+
align_att_img_batch_for_parsing_model: torch.Tensor = torch.stack(
|
269 |
+
[self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
|
270 |
+
)
|
271 |
+
align_att_img_batch_for_parsing_model = (
|
272 |
+
align_att_img_batch_for_parsing_model.to(self.device)
|
273 |
+
)
|
274 |
+
|
275 |
+
att_transforms: torch.Tensor = torch.stack(
|
276 |
+
[torch.tensor(x).float() for x in att_transforms], dim=0
|
277 |
+
)
|
278 |
+
att_transforms = att_transforms.to(self.device, non_blocking=True)
|
279 |
+
|
280 |
+
align_att_img_batch: torch.Tensor = torch.stack(
|
281 |
+
[self.to_tensor(x) for x in align_att_imgs], dim=0
|
282 |
+
)
|
283 |
+
align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)
|
284 |
+
|
285 |
+
# Get face masks for the attribute image
|
286 |
+
face_mask, ignore_mask_ids = self.bise_net.get_mask(
|
287 |
+
align_att_img_batch_for_parsing_model, self.crop_size
|
288 |
+
)
|
289 |
+
|
290 |
+
inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)
|
291 |
+
|
292 |
+
soft_face_mask, _ = self.smooth_mask(face_mask)
|
293 |
+
|
294 |
+
swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]
|
295 |
+
|
296 |
+
frame_size = (att_image.shape[0], att_image.shape[1])
|
297 |
+
|
298 |
+
att_image = self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)
|
299 |
+
|
300 |
+
target_image = kornia.geometry.transform.warp_affine(
|
301 |
+
swapped_img,
|
302 |
+
inv_att_transforms,
|
303 |
+
frame_size,
|
304 |
+
mode="bilinear",
|
305 |
+
padding_mode="border",
|
306 |
+
align_corners=True,
|
307 |
+
fill_value=torch.zeros(3),
|
308 |
+
)
|
309 |
+
|
310 |
+
soft_face_mask = kornia.geometry.transform.warp_affine(
|
311 |
+
soft_face_mask,
|
312 |
+
inv_att_transforms,
|
313 |
+
frame_size,
|
314 |
+
mode="bilinear",
|
315 |
+
padding_mode="zeros",
|
316 |
+
align_corners=True,
|
317 |
+
fill_value=torch.zeros(3),
|
318 |
+
)
|
319 |
+
|
320 |
+
result = self.blend(target_image, soft_face_mask, att_image)
|
321 |
+
|
322 |
+
return tensor2img(result)
|