commit
- Time-Travel-Rephotography/.gitignore +141 -0
- Time-Travel-Rephotography/.gitmodules +9 -0
- Time-Travel-Rephotography/LICENSE +21 -0
- Time-Travel-Rephotography/LICENSE-NVIDIA +101 -0
- Time-Travel-Rephotography/LICENSE-STYLEGAN2 +21 -0
- Time-Travel-Rephotography/README.md +119 -0
- Time-Travel-Rephotography/app.py +172 -0
- Time-Travel-Rephotography/losses/color_transfer_loss.py +60 -0
- Time-Travel-Rephotography/losses/joint_loss.py +167 -0
- Time-Travel-Rephotography/losses/perceptual_loss.py +111 -0
- Time-Travel-Rephotography/losses/reconstruction.py +119 -0
- Time-Travel-Rephotography/losses/regularize_noise.py +37 -0
- Time-Travel-Rephotography/model.py +697 -0
- Time-Travel-Rephotography/models/__init__.py +0 -0
- Time-Travel-Rephotography/models/degrade.py +122 -0
- Time-Travel-Rephotography/models/encoder.py +66 -0
- Time-Travel-Rephotography/models/gaussian_smoothing.py +74 -0
- Time-Travel-Rephotography/models/resnet.py +99 -0
- Time-Travel-Rephotography/models/vggface.py +150 -0
- Time-Travel-Rephotography/op/__init__.py +2 -0
- Time-Travel-Rephotography/op/fused_act.py +86 -0
- Time-Travel-Rephotography/op/fused_bias_act.cpp +21 -0
- Time-Travel-Rephotography/op/fused_bias_act_kernel.cu +99 -0
- Time-Travel-Rephotography/op/upfirdn2d.cpp +23 -0
- Time-Travel-Rephotography/op/upfirdn2d.py +187 -0
- Time-Travel-Rephotography/op/upfirdn2d_kernel.cu +272 -0
- Time-Travel-Rephotography/optim/__init__.py +15 -0
- Time-Travel-Rephotography/optim/radam.py +250 -0
- Time-Travel-Rephotography/requirements.txt +25 -0
- Time-Travel-Rephotography/scripts/download_checkpoints.sh +14 -0
- Time-Travel-Rephotography/scripts/install.sh +6 -0
- Time-Travel-Rephotography/scripts/run.sh +34 -0
- Time-Travel-Rephotography/tools/__init__.py +0 -0
- Time-Travel-Rephotography/tools/data/__init__.py +0 -0
- Time-Travel-Rephotography/tools/data/align_images.py +117 -0
- Time-Travel-Rephotography/tools/initialize.py +160 -0
- Time-Travel-Rephotography/tools/match_histogram.py +167 -0
- Time-Travel-Rephotography/tools/match_skin_histogram.py +67 -0
- Time-Travel-Rephotography/tools/parse_face.py +55 -0
- Time-Travel-Rephotography/utils/__init__.py +0 -0
- Time-Travel-Rephotography/utils/ffhq_dataset/__init__.py +0 -0
- Time-Travel-Rephotography/utils/ffhq_dataset/face_alignment.py +99 -0
- Time-Travel-Rephotography/utils/ffhq_dataset/landmarks_detector.py +71 -0
- Time-Travel-Rephotography/utils/misc.py +18 -0
- Time-Travel-Rephotography/utils/optimize.py +230 -0
- Time-Travel-Rephotography/utils/projector_arguments.py +76 -0
- Time-Travel-Rephotography/utils/torch_helpers.py +36 -0
Time-Travel-Rephotography/.gitignore
ADDED
@@ -0,0 +1,141 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

wandb/
*.lmdb/
*.pkl

# results
results
results_old
log
checkpoint
*.pt
*.old
Time-Travel-Rephotography/.gitmodules
ADDED
@@ -0,0 +1,9 @@
[submodule "third_party/face_parsing"]
    path = third_party/face_parsing
    url = https://github.com/Time-Travel-Rephotography/face-parsing.PyTorch.git
[submodule "models/encoder4editing"]
    path = models/encoder4editing
    url = https://github.com/Time-Travel-Rephotography/encoder4editing.git
[submodule "losses/contextual_loss"]
    path = losses/contextual_loss
    url = https://github.com/Time-Travel-Rephotography/contextual_loss_pytorch.git
Time-Travel-Rephotography/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Time-Travel-Rephotography

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Time-Travel-Rephotography/LICENSE-NVIDIA
ADDED
@@ -0,0 +1,101 @@
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.


Nvidia Source Code License-NC

=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

"Nvidia Processors" means any central processing unit (CPU), graphics
processing unit (GPU), field-programmable gate array (FPGA),
application-specific integrated circuit (ASIC) or any combination
thereof designed, made, sold, or provided by Nvidia or its affiliates.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. The Work or
derivative works thereof may be used or intended for use by Nvidia
or its affiliates commercially or non-commercially. As used herein,
"non-commercially" means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grants in Sections 2.1 and 2.2) will
terminate immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grants in Sections 2.1 and
2.2) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================
Time-Travel-Rephotography/LICENSE-STYLEGAN2
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Kim Seonghyeon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Time-Travel-Rephotography/README.md
ADDED
@@ -0,0 +1,119 @@
# [SIGGRAPH Asia 2021] Time-Travel Rephotography
<a href="https://arxiv.org/abs/2012.12261"><img src="https://img.shields.io/badge/arXiv-2012.12261-b31b1b.svg"></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15D2WIF_vE2l48ddxEx45cM3RykZwQXM8?usp=sharing)
### [[Project Website](https://time-travel-rephotography.github.io/)]

<p align='center'>
<img src="time-travel-rephotography.gif" width='100%'/>
</p>

Many historical people were only ever captured by old, faded, black and white photos that are distorted due to the limitations of early cameras and the passage of time. This paper simulates traveling back in time with a modern camera to rephotograph famous subjects. Unlike conventional image restoration filters which apply independent operations like denoising, colorization, and superresolution, we leverage the StyleGAN2 framework to project old photos into the space of modern high-resolution photos, achieving all of these effects in a unified framework. A unique challenge with this approach is retaining the identity and pose of the subject in the original photo, while discarding the many artifacts frequently seen in low-quality antique photos. Our comparisons to current state-of-the-art restoration filters show significant improvements and compelling results for a variety of important historical people.
<br/>

**Time-Travel Rephotography**
<br/>
[Xuan Luo](https://roxanneluo.github.io),
[Xuaner Zhang](https://people.eecs.berkeley.edu/~cecilia77/),
[Paul Yoo](https://www.linkedin.com/in/paul-yoo-768a3715b),
[Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/),
[Jason Lawrence](http://jasonlawrence.info/), and
[Steven M. Seitz](https://homes.cs.washington.edu/~seitz/)
<br/>
In SIGGRAPH Asia 2021.

## Demo
We provide an easy-to-get-started demo using Google Colab!
The notebook lets you try our method on the sample Abraham Lincoln photo or **your own photos** using cloud GPUs.

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15D2WIF_vE2l48ddxEx45cM3RykZwQXM8?usp=sharing)

Or you can run our method on your own machine following the instructions below.

## Prerequisite
- Pull third-party packages.
```
git submodule update --init --recursive
```
- Install python packages.
```
conda create --name rephotography python=3.8.5
conda activate rephotography
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
pip install -r requirements.txt
```

## Quick Start
Run our method on the example photo of Abraham Lincoln.
- Download models:
```
./scripts/download_checkpoints.sh
```
- Run:
```
./scripts/run.sh b "dataset/Abraham Lincoln_01.png" 0.75
```
- You can inspect the optimization process by running
```
tensorboard --logdir "log/Abraham Lincoln_01"
```
- You will find the results organized as below.
```
results/
  Abraham Lincoln_01/ # intermediate outputs for histogram matching and face parsing
  Abraham Lincoln_01_b.png # the input after matching the histogram of the sibling image
  Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-init.png # the sibling image
  Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-init.pt # the sibling latent codes and initialized noise maps
  Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750).png # the output result
  Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750).pt # the final optimized latent codes and noise maps
  Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-rand.png # the result with the final latent codes but random noise maps
```

## Run on Your Own Image
- Crop and align the head regions of your images:
```
python -m tools.data.align_images <input_raw_image_dir> <aligned_image_dir>
```
- Run:
```
./scripts/run.sh <spectral_sensitivity> <input_image_path> <blur_radius>
```
The `spectral_sensitivity` can be `b` (blue-sensitive), `gb` (orthochromatic), or `g` (panchromatic). You can roughly estimate the `spectral_sensitivity` of your photo as follows: use the *blue-sensitive* model for photos taken before 1873, select manually between *blue-sensitive* and *orthochromatic* for photos from 1873 to 1906, and select among all three models for photos taken afterwards.

The `blur_radius` is the estimated Gaussian blur radius in pixels if the input photo is resized to 1024x1024.

## Historical Wiki Face Dataset
| Path | Size | Description |
|----------- | ----------- | ----------- |
| [Historical Wiki Face Dataset.zip](https://drive.google.com/open?id=1mgC2U7quhKSz_lTL97M-0cPrIILTiUCE&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 148 MB | Images|
| [spectral_sensitivity.json](https://drive.google.com/open?id=1n3Bqd8G0g-wNpshlgoZiOMXxLlOycAXr&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 6 KB | Spectral sensitivity (`b`, `gb`, or `g`). |
| [blur_radius.json](https://drive.google.com/open?id=1n4vUsbQo2BcxtKVMGfD1wFHaINzEmAVP&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 6 KB | Blur radius in pixels|

The `json`s are dictionaries that map input names to the corresponding spectral sensitivity or blur radius.
Due to copyright constraints, `Historical Wiki Face Dataset.zip` contains all images in the *Historical Wiki Face Dataset* that were used in our user study except the photo of [Mao Zedong](https://en.wikipedia.org/wiki/File:Mao_Zedong_in_1959_%28cropped%29.jpg). You can download it separately and crop it as [above](#run-on-your-own-image).

## Citation
If you find our code useful, please consider citing our paper:
```
@article{Luo-Rephotography-2021,
  author    = {Luo, Xuan and Zhang, Xuaner and Yoo, Paul and Martin-Brualla, Ricardo and Lawrence, Jason and Seitz, Steven M.},
  title     = {Time-Travel Rephotography},
  journal   = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH Asia 2021)},
  publisher = {ACM New York, NY, USA},
  volume    = {40},
  number    = {6},
  articleno = {213},
  doi       = {https://doi.org/10.1145/3478513.3480485},
  year      = {2021},
  month     = {12}
}
```

## License
This work is licensed under the MIT License. See [LICENSE](LICENSE) for details.

Code for the StyleGAN2 model comes from [https://github.com/rosinality/stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).

## Acknowledgments
We thank [Nick Brandreth](https://www.nickbrandreth.com/) for capturing the dry plate photos. We thank Bo Zhang, Qingnan Fan, Roy Or-El, Aleksander Holynski and Keunhong Park for insightful advice.
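For batch runs, the two `json` files above can be used to drive `scripts/run.sh` over a whole folder. The following is a minimal sketch, not part of the release: the dataset directory name and the expectation that both JSONs sit next to the script are assumptions; the key/value layout is as described above.

```
import json
import subprocess
from pathlib import Path

# Hypothetical locations: adjust to wherever the dataset and JSONs were extracted.
image_dir = Path("dataset/HistoricalWikiFaces")
spectral = json.load(open("spectral_sensitivity.json"))  # image name -> "b" | "gb" | "g"
blur = json.load(open("blur_radius.json"))               # image name -> radius in pixels

for image_path in sorted(image_dir.glob("*.png")):
    name = image_path.stem
    if name not in spectral or name not in blur:
        continue  # skip images without annotations
    subprocess.run(
        ["./scripts/run.sh", spectral[name], str(image_path), str(blur[name])],
        check=True,
    )
```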
Time-Travel-Rephotography/app.py
ADDED
@@ -0,0 +1,172 @@
from argparse import Namespace
import os
from os.path import join as pjoin
import random
import sys
from typing import (
    Iterable,
    Optional,
)

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import (
    Compose,
    Grayscale,
    Resize,
    ToTensor,
    Normalize,
)

from losses.joint_loss import JointLoss
from model import Generator
from tools.initialize import Initializer
from tools.match_skin_histogram import match_skin_histogram
from utils.projector_arguments import ProjectorArguments
from utils import torch_helpers as th
from utils.torch_helpers import make_image
from utils.misc import stem
from utils.optimize import Optimizer
from models.degrade import (
    Degrade,
    Downsample,
)


def set_random_seed(seed: int):
    # FIXME (xuanluo): this setup still allows randomness somehow
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def read_images(paths: str, max_size: Optional[int] = None):
    transform = Compose(
        [
            Grayscale(),
            ToTensor(),
        ]
    )

    imgs = []
    for path in paths:
        img = Image.open(path)
        if max_size is not None and img.width > max_size:
            img = img.resize((max_size, max_size))
        img = transform(img)
        imgs.append(img)
    imgs = torch.stack(imgs, 0)
    return imgs


def normalize(img: torch.Tensor, mean=0.5, std=0.5):
    """[0, 1] -> [-1, 1]"""
    return (img - mean) / std


def create_generator(args: Namespace, device: torch.device):
    generator = Generator(args.generator_size, 512, 8)
    generator.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False)
    generator.eval()
    generator = generator.to(device)
    return generator


def save(
    path_prefixes: Iterable[str],
    imgs: torch.Tensor,  # BCHW
    latents: torch.Tensor,
    noises: torch.Tensor,
    imgs_rand: Optional[torch.Tensor] = None,
):
    assert len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)
    if imgs_rand is not None:
        assert len(imgs) == len(imgs_rand)
    imgs_arr = make_image(imgs)
    for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
        os.makedirs(os.path.dirname(path_prefix), exist_ok=True)
        cv2.imwrite(path_prefix + ".png", img[..., ::-1])
        torch.save({"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
                   path_prefix + ".pt")

    if imgs_rand is not None:
        imgs_arr = make_image(imgs_rand)
        for path_prefix, img in zip(path_prefixes, imgs_arr):
            cv2.imwrite(path_prefix + "-rand.png", img[..., ::-1])


def main(args):
    opt_str = ProjectorArguments.to_string(args)
    print(opt_str)

    if args.rand_seed is not None:
        set_random_seed(args.rand_seed)
    device = th.device()

    # read inputs. TODO imgs_orig has channel 1
    imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
    imgs = normalize(imgs_orig)  # actually this will be overwritten by the histogram matching result

    # initialize
    with torch.no_grad():
        init = Initializer(args).to(device)
        latent_init = init(imgs_orig)

    # create generator
    generator = create_generator(args, device)

    # init noises
    with torch.no_grad():
        noises_init = generator.make_noise()

    # create a new input by matching the input's histogram to the sibling image
    with torch.no_grad():
        sibling, _, sibling_rgbs = generator([latent_init], input_is_latent=True, noise=noises_init)
        mh_dir = pjoin(args.results_dir, stem(args.input))
        imgs = match_skin_histogram(
            imgs, sibling,
            args.spectral_sensitivity,
            pjoin(mh_dir, "input_sibling"),
            pjoin(mh_dir, "skin_mask"),
            matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
            normalize=normalize,
        ).to(device)
        torch.cuda.empty_cache()
    # TODO imgs has channel 3

    degrade = Degrade(args).to(device)

    rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
    criterion = JointLoss(
        args, imgs,
        sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels]).to(device)

    # save initialization
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
        sibling, latent_init, noises_init,
    )

    writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))
    # start optimize
    latent, noises = Optimizer.optimize(generator, criterion, degrade, imgs, latent_init, noises_init, args, writer=writer)

    # generate output
    img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
    img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)
    # save output
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
        img_out, latent, noises,
        imgs_rand=img_out_rand_noise
    )


def parse_args():
    return ProjectorArguments().parse()


if __name__ == "__main__":
    sys.exit(main(parse_args()))
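A small, self-contained check of the intensity convention used above: `read_images` yields grayscale tensors in [0, 1], and `normalize` maps them to the [-1, 1] range the generator works in, with the inverse applied before images are written to disk. This sketch only exercises that mapping; the tensors are random stand-ins rather than real photos, and `denormalize` is a hypothetical helper added for illustration.

```
import torch

def normalize(img: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """[0, 1] -> [-1, 1], same convention as app.py."""
    return (img - mean) / std

def denormalize(img: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """[-1, 1] -> [0, 1]; illustrative inverse, not defined in app.py."""
    return img * std + mean

x = torch.rand(1, 1, 256, 256)           # stand-in for a grayscale input in [0, 1]
y = normalize(x)
assert y.min() >= -1 and y.max() <= 1    # generator-range tensor
assert torch.allclose(denormalize(y), x)
```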
Time-Travel-Rephotography/losses/color_transfer_loss.py
ADDED
@@ -0,0 +1,60 @@
from typing import List, Optional

import torch
from torch import nn
from torch.nn.functional import (
    smooth_l1_loss,
)


def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
    """
    (B, C, H, W) -> (B, -1)
    """
    B = im.shape[0]
    return im.reshape(B, -1)


def stddev(x: torch.Tensor) -> torch.Tensor:
    """
    x: (B, -1), assume with mean normalized
    Returns:
        stddev: (B)
    """
    return torch.sqrt(torch.mean(x * x, dim=-1))


def gram_matrix(input_):
    B, C = input_.shape[:2]
    features = input_.view(B, C, -1)
    N = features.shape[-1]
    G = torch.bmm(features, features.transpose(1, 2))  # C x C
    return G.div(C * N)


class ColorTransferLoss(nn.Module):
    """Penalize the gram matrix difference between StyleGAN2's ToRGB outputs"""
    def __init__(
        self,
        init_rgbs,
        scale_rgb: bool = False
    ):
        super().__init__()

        with torch.no_grad():
            init_feats = [x.detach() for x in init_rgbs]
            self.stds = [stddev(flatten_CHW(rgb)) if scale_rgb else 1 for rgb in init_feats]  # (B, 1, 1, 1) or scalar
            self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]

    def forward(self, rgbs: List[torch.Tensor], level: int = None):
        if level is None:
            level = len(self.grams)

        feats = rgbs
        loss = 0
        for i, (rgb, std) in enumerate(zip(feats[:level], self.stds[:level])):
            G = gram_matrix(rgb / std)
            loss = loss + smooth_l1_loss(G, self.grams[i])

        return loss
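To make the normalization in `gram_matrix` concrete: for a (B, C, H, W) feature map it returns a (B, C, C) matrix of channel correlations divided by C·H·W, so the loss compares color statistics rather than pixel positions. A minimal numerical check, with random tensors standing in for ToRGB outputs:

```
import torch

def gram_matrix(feat: torch.Tensor) -> torch.Tensor:
    # Same computation as above: (B, C, HW) x (B, HW, C) / (C * HW)
    B, C = feat.shape[:2]
    f = feat.view(B, C, -1)
    return torch.bmm(f, f.transpose(1, 2)) / (C * f.shape[-1])

rgb = torch.randn(2, 3, 64, 64)           # stand-in for one ToRGB output
G = gram_matrix(rgb)
assert G.shape == (2, 3, 3)

# Spatially permuting the pixels leaves the Gram matrix (numerically) unchanged,
# which is why this loss transfers color/tone statistics but not structure.
perm = torch.randperm(64 * 64)
rgb_shuffled = rgb.view(2, 3, -1)[..., perm].view(2, 3, 64, 64)
assert torch.allclose(G, gram_matrix(rgb_shuffled), atol=1e-5)
```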
Time-Travel-Rephotography/losses/joint_loss.py
ADDED
@@ -0,0 +1,167 @@
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
import torch
from torch import nn

from utils.misc import (
    optional_string,
    iterable_to_str,
)

from .contextual_loss import ContextualLoss
from .color_transfer_loss import ColorTransferLoss
from .regularize_noise import NoiseRegularizer
from .reconstruction import (
    EyeLoss,
    FaceLoss,
    create_perceptual_loss,
    ReconstructionArguments,
)


class LossArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        ReconstructionArguments.add_arguments(parser)

        parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
        parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
        parser.add_argument('--noise_regularize', type=float, default=5e4)
        # contextual loss
        parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
        parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
                            choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
                            default=['relu3_4', 'relu2_2', 'relu1_2'])

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            ReconstructionArguments.to_string(args)
            + optional_string(args.eye > 0, f"-eye{args.eye}")
            + optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
            + optional_string(
                args.contextual,
                f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
            )
            #+ optional_string(args.mse, f"-mse{args.mse}")
            + optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
        )


class BakedMultiContextualLoss(nn.Module):
    """Random sample different image patches for different vgg layers."""
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__()

        self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
                                  for layer in args.cx_layers])
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        cx_loss = 0
        for cx in self.cxs:
            h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
            cx_loss = cx(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size]) + cx_loss
        return cx_loss


class BakedContextualLoss(ContextualLoss):
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
        return super().forward(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size])


class JointLoss(nn.Module):
    def __init__(
        self,
        args: Namespace,
        target: torch.Tensor,
        sibling: Optional[torch.Tensor],
        sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
    ):
        super().__init__()

        self.weights = {
            "face": 1., "eye": args.eye,
            "contextual": args.contextual, "color_transfer": args.color_transfer,
            "noise": args.noise_regularize,
        }

        reconstruction = {}
        if args.vgg > 0 or args.vggface > 0:
            percept = create_perceptual_loss(args)
            reconstruction.update(
                {"face": FaceLoss(target, input_size=args.generator_size, size=args.recon_size, percept=percept)}
            )
            if args.eye > 0:
                reconstruction.update(
                    {"eye": EyeLoss(target, input_size=args.generator_size, percept=percept)}
                )
        self.reconstruction = nn.ModuleDict(reconstruction)

        exemplar = {}
        if args.contextual > 0 and len(args.cx_layers) > 0:
            assert sibling is not None
            exemplar.update(
                {"contextual": BakedContextualLoss(sibling, args)}
            )
        if args.color_transfer > 0:
            assert sibling_rgbs is not None
            self.sibling_rgbs = sibling_rgbs
            exemplar.update(
                {"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
            )
        self.exemplar = nn.ModuleDict(exemplar)

        if args.noise_regularize > 0:
            self.noise_criterion = NoiseRegularizer()

    def forward(
        self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            rgbs: results from the ToRGB layers
        """
        # TODO: add current optimization resolution for noises

        losses = {}

        # reconstruction losses
        for name, criterion in self.reconstruction.items():
            losses[name] = criterion(img, degrade=degrade)

        # exemplar losses
        if 'contextual' in self.exemplar:
            losses["contextual"] = self.exemplar["contextual"](img)
        if "color_transfer" in self.exemplar:
            assert rgbs is not None
            losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)

        # noise regularizer
        if self.weights["noise"] > 0:
            losses["noise"] = self.noise_criterion(noises)

        total_loss = 0
        for name, loss in losses.items():
            total_loss = total_loss + self.weights[name] * loss
        return total_loss, losses

    def update_sibling(self, sibling: torch.Tensor):
        assert "contextual" in self.exemplar
        self.exemplar["contextual"].sibling = sibling.detach()
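The weighting scheme in `JointLoss.forward` is a dictionary-driven weighted sum; each term is kept separately so it can also be logged per name. A stripped-down sketch of that pattern with dummy loss values (the weights mirror the CLI defaults above; everything else is a stand-in, not the repo's actual optimizer loop):

```
import torch

weights = {"face": 1.0, "eye": 0.1, "contextual": 0.1,
           "color_transfer": 1e10, "noise": 5e4}

# Stand-ins for the individual criteria evaluated on the current iterate.
losses = {name: torch.rand(()) * 1e-3 for name in weights}

total = sum(weights[name] * loss for name, loss in losses.items())
# `total` drives the optimizer step; the per-name `losses` dict is what a
# TensorBoard writer would log, one scalar per term.
print({name: float(loss) for name, loss in losses.items()}, float(total))
```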
Time-Travel-Rephotography/losses/perceptual_loss.py
ADDED
@@ -0,0 +1,111 @@
"""
Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5
"""
import torch
import torchvision
from models.vggface import VGGFaceFeats


def cos_loss(fi, ft):
    return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=False):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl:
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input, target, max_layer=4, cos_dist: bool = False):
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        x = input
        y = target
        loss = 0.0
        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        for bi, block in enumerate(self.blocks[:max_layer]):
            x = block(x)
            y = block(y)
            loss += loss_func(x, y.detach())
        return loss


class VGGFacePerceptualLoss(torch.nn.Module):
    def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
        super().__init__()
        self.vgg = VGGFaceFeats()
        self.vgg.load_state_dict(torch.load(weight_path))

        mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
        self.register_buffer("mean", mean)

        self.transform = torch.nn.functional.interpolate
        self.resize = resize

    def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        # preprocessing
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = input - self.mean
        target = target - self.mean
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        input_feats = self.vgg(input)
        target_feats = self.vgg(target)

        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        # calc perceptual loss
        loss = 0.0
        for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
            loss = loss + loss_func(fi, ft.detach())
        return loss


class PerceptualLoss(torch.nn.Module):
    def __init__(
        self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1, eps: float = 1e-8, cos_dist: bool = False
    ):
        super().__init__()
        self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
        self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
        self.cos_dist = cos_dist

        if lambda_vgg > eps:
            self.vgg = VGGPerceptualLoss()
        if lambda_vggface > eps:
            self.vggface = VGGFacePerceptualLoss()

    def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg=True, max_vgg_layer=4):
        loss = 0.0
        if self.lambda_vgg > eps and use_vgg:
            loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
        if self.lambda_vggface > eps and use_vggface:
            loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
        return loss
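Usage-wise, both branches expect images in [-1, 1] (undone in the first two lines of each `forward`) and accept 1-channel inputs, which are tiled to 3 channels. Below is a minimal smoke test, not part of the repo: it assumes the torchvision VGG16 weights can be downloaded, and it disables the VGGFace branch so that `checkpoint/vgg_face_dag.pt` is not needed.

```
import torch
from losses.perceptual_loss import PerceptualLoss

# lambda_vggface=0 skips construction of the VGGFace branch entirely.
percept = PerceptualLoss(lambda_vggface=0.0, lambda_vgg=1.0).eval()

with torch.no_grad():
    pred = torch.rand(1, 1, 256, 256) * 2 - 1     # grayscale, in [-1, 1]
    target = torch.rand(1, 1, 256, 256) * 2 - 1
    loss = percept(pred, target, use_vggface=False)
print(float(loss))
```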
Time-Travel-Rephotography/losses/reconstruction.py
ADDED
@@ -0,0 +1,119 @@
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import Optional

import numpy as np
import torch
from torch import nn

from losses.perceptual_loss import PerceptualLoss
from models.degrade import Downsample
from utils.misc import optional_string


class ReconstructionArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument("--vggface", type=float, default=0.3, help="vggface")
        parser.add_argument("--vgg", type=float, default=1, help="vgg")
        parser.add_argument('--recon_size', type=int, default=256, help="size for face reconstruction loss")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"s{args.recon_size}"
            + optional_string(args.vgg > 0, f"-vgg{args.vgg}")
            + optional_string(args.vggface > 0, f"-vggface{args.vggface}")
        )


def create_perceptual_loss(args: Namespace):
    return PerceptualLoss(lambda_vgg=args.vgg, lambda_vggface=args.vggface, cos_dist=False)


class EyeLoss(nn.Module):
    def __init__(
        self,
        target: torch.Tensor,
        input_size: int = 1024,
        input_channels: int = 3,
        percept: Optional[nn.Module] = None,
        args: Optional[Namespace] = None
    ):
        """
        target: target image
        """
        assert not (percept is None and args is None)

        super().__init__()

        self.target = target

        target_size = target.shape[-1]
        self.downsample = Downsample(input_size, target_size, input_channels) \
            if target_size != input_size else (lambda x: x)

        self.percept = percept if percept is not None else create_perceptual_loss(args)

        eye_size = np.array((224, 224))
        btlrs = []
        for sgn in [1, -1]:
            center = np.array((480, 384 * sgn))  # (y, x)
            b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
            l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
            btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))
        self.btlrs = np.stack(btlrs, axis=0)

    def forward(self, img: torch.Tensor, degrade: nn.Module = None):
        """
        img: it should be the degraded version of the generated image
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)

        loss = 0
        for (b, t, l, r) in self.btlrs:
            loss = loss + self.percept(
                img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
                use_vggface=False, max_vgg_layer=4,
                # use_vgg=False,
            )
        return loss


class FaceLoss(nn.Module):
    def __init__(
        self,
        target: torch.Tensor,
        input_size: int = 1024,
        input_channels: int = 3,
        size: int = 256,
        percept: Optional[nn.Module] = None,
        args: Optional[Namespace] = None
    ):
        """
        target: target image
        """
        assert not (percept is None and args is None)

        super().__init__()

        target_size = target.shape[-1]
        self.target = target if target_size == size \
            else Downsample(target_size, size, target.shape[1]).to(target.device)(target)

        self.downsample = Downsample(input_size, size, input_channels) \
            if size != input_size else (lambda x: x)

        self.percept = percept if percept is not None else create_perceptual_loss(args)

    def forward(self, img: torch.Tensor, degrade: nn.Module = None):
        """
        img: it should be the degraded version of the generated image
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)
        loss = self.percept(img, self.target)
        return loss
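`EyeLoss` hard-codes two 224x224 crops centered at (y, x) = (480, ±384) in the 1024x1024 aligned-face frame and rescales the boxes to the target resolution; the negative columns of the second box wrap around via Python's negative indexing, giving the mirrored eye region. A small worked example of just that box arithmetic (no model weights involved; `target_size=256` is an assumed example resolution):

```
import numpy as np

target_size = 256                  # resolution of the target photo fed to EyeLoss
eye_size = np.array((224, 224))

for sgn in [1, -1]:
    center = np.array((480, 384 * sgn))                       # (y, x), same constants as above
    b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
    l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
    btlr = (np.array((b, t, l, r)) / 1024 * target_size).astype(int)
    print(sgn, btlr)               # [ 92 148  68 124] and [ 92 148 -124  -68]
```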
Time-Travel-Rephotography/losses/regularize_noise.py
ADDED
@@ -0,0 +1,37 @@
from typing import Iterable

import torch
from torch import nn


class NoiseRegularizer(nn.Module):
    def forward(self, noises: Iterable[torch.Tensor]):
        loss = 0

        for noise in noises:
            size = noise.shape[2]

            while True:
                loss = (
                    loss
                    + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                    + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
                )

                if size <= 8:
                    break

                noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2])
                noise = noise.mean([3, 5])
                size //= 2

        return loss

    @staticmethod
    def normalize(noises: Iterable[torch.Tensor]):
        for noise in noises:
            mean = noise.mean()
            std = noise.std()

            noise.data.add_(-mean).div_(std)
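The regularizer computes, at every dyadic downscale of each noise map down to 8x8, the squared mean correlation between the map and its one-pixel shifts; white noise drives this toward zero while spatially structured noise does not. A quick check, assuming (1, 1, H, W) maps like those produced for StyleGAN2's per-layer noise:

```
import torch
from losses.regularize_noise import NoiseRegularizer

reg = NoiseRegularizer()

white = [torch.randn(1, 1, 64, 64)]                  # spatially uncorrelated
smooth = [torch.randn(1, 1, 64, 64).cumsum(dim=3)]   # strongly correlated along x

print(float(reg(white)), float(reg(smooth)))         # the second value is much larger

# normalize() then re-standardizes the maps in place to zero mean, unit std.
NoiseRegularizer.normalize(white)
```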
Time-Travel-Rephotography/model.py
ADDED
@@ -0,0 +1,697 @@
import math
import random
import functools
import operator
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)

        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
|
209 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
210 |
+
|
211 |
+
if downsample:
|
212 |
+
factor = 2
|
213 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
214 |
+
pad0 = (p + 1) // 2
|
215 |
+
pad1 = p // 2
|
216 |
+
|
217 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
218 |
+
|
219 |
+
fan_in = in_channel * kernel_size ** 2
|
220 |
+
self.scale = 1 / math.sqrt(fan_in)
|
221 |
+
self.padding = kernel_size // 2
|
222 |
+
|
223 |
+
self.weight = nn.Parameter(
|
224 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
225 |
+
)
|
226 |
+
|
227 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
228 |
+
|
229 |
+
self.demodulate = demodulate
|
230 |
+
|
231 |
+
def __repr__(self):
|
232 |
+
return (
|
233 |
+
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
234 |
+
f'upsample={self.upsample}, downsample={self.downsample})'
|
235 |
+
)
|
236 |
+
|
237 |
+
def forward(self, input, style):
|
238 |
+
batch, in_channel, height, width = input.shape
|
239 |
+
|
240 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
241 |
+
weight = self.scale * self.weight * style
|
242 |
+
|
243 |
+
if self.demodulate:
|
244 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
245 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
246 |
+
|
247 |
+
weight = weight.view(
|
248 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
249 |
+
)
|
250 |
+
|
251 |
+
if self.upsample:
|
252 |
+
input = input.view(1, batch * in_channel, height, width)
|
253 |
+
weight = weight.view(
|
254 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
255 |
+
)
|
256 |
+
weight = weight.transpose(1, 2).reshape(
|
257 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
258 |
+
)
|
259 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
260 |
+
_, _, height, width = out.shape
|
261 |
+
out = out.view(batch, self.out_channel, height, width)
|
262 |
+
out = self.blur(out)
|
263 |
+
|
264 |
+
elif self.downsample:
|
265 |
+
input = self.blur(input)
|
266 |
+
_, _, height, width = input.shape
|
267 |
+
input = input.view(1, batch * in_channel, height, width)
|
268 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
269 |
+
_, _, height, width = out.shape
|
270 |
+
out = out.view(batch, self.out_channel, height, width)
|
271 |
+
|
272 |
+
else:
|
273 |
+
input = input.view(1, batch * in_channel, height, width)
|
274 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
275 |
+
_, _, height, width = out.shape
|
276 |
+
out = out.view(batch, self.out_channel, height, width)
|
277 |
+
|
278 |
+
return out
|
279 |
+
|
280 |
+
|
281 |
+
class NoiseInjection(nn.Module):
|
282 |
+
def __init__(self):
|
283 |
+
super().__init__()
|
284 |
+
|
285 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
286 |
+
|
287 |
+
def forward(self, image, noise=None):
|
288 |
+
if noise is None:
|
289 |
+
batch, _, height, width = image.shape
|
290 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
291 |
+
|
292 |
+
return image + self.weight * noise
|
293 |
+
|
294 |
+
|
295 |
+
class ConstantInput(nn.Module):
|
296 |
+
def __init__(self, channel, size=4):
|
297 |
+
super().__init__()
|
298 |
+
|
299 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
300 |
+
|
301 |
+
def forward(self, input):
|
302 |
+
batch = input.shape[0]
|
303 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
304 |
+
|
305 |
+
return out
|
306 |
+
|
307 |
+
|
308 |
+
class StyledConv(nn.Module):
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
in_channel,
|
312 |
+
out_channel,
|
313 |
+
kernel_size,
|
314 |
+
style_dim,
|
315 |
+
upsample=False,
|
316 |
+
blur_kernel=[1, 3, 3, 1],
|
317 |
+
demodulate=True,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
|
321 |
+
self.conv = ModulatedConv2d(
|
322 |
+
in_channel,
|
323 |
+
out_channel,
|
324 |
+
kernel_size,
|
325 |
+
style_dim,
|
326 |
+
upsample=upsample,
|
327 |
+
blur_kernel=blur_kernel,
|
328 |
+
demodulate=demodulate,
|
329 |
+
)
|
330 |
+
|
331 |
+
self.noise = NoiseInjection()
|
332 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
333 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
334 |
+
self.activate = FusedLeakyReLU(out_channel)
|
335 |
+
|
336 |
+
def forward(self, input, style, noise=None):
|
337 |
+
out = self.conv(input, style)
|
338 |
+
out = self.noise(out, noise=noise)
|
339 |
+
# out = out + self.bias
|
340 |
+
out = self.activate(out)
|
341 |
+
|
342 |
+
return out
|
343 |
+
|
344 |
+
|
345 |
+
class ToRGB(nn.Module):
|
346 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
347 |
+
super().__init__()
|
348 |
+
|
349 |
+
if upsample:
|
350 |
+
self.upsample = Upsample(blur_kernel)
|
351 |
+
|
352 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
353 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
354 |
+
|
355 |
+
def forward(self, input, style, skip=None):
|
356 |
+
out = self.conv(input, style)
|
357 |
+
style_modulated = out
|
358 |
+
out = out + self.bias
|
359 |
+
|
360 |
+
if skip is not None:
|
361 |
+
skip = self.upsample(skip)
|
362 |
+
|
363 |
+
out = out + skip
|
364 |
+
|
365 |
+
return out, style_modulated
|
366 |
+
|
367 |
+
|
368 |
+
class Generator(nn.Module):
|
369 |
+
def __init__(
|
370 |
+
self,
|
371 |
+
size,
|
372 |
+
style_dim,
|
373 |
+
n_mlp,
|
374 |
+
channel_multiplier=2,
|
375 |
+
blur_kernel=[1, 3, 3, 1],
|
376 |
+
lr_mlp=0.01,
|
377 |
+
):
|
378 |
+
super().__init__()
|
379 |
+
|
380 |
+
self.size = size
|
381 |
+
|
382 |
+
self.style_dim = style_dim
|
383 |
+
|
384 |
+
layers = [PixelNorm()]
|
385 |
+
|
386 |
+
for i in range(n_mlp):
|
387 |
+
layers.append(
|
388 |
+
EqualLinear(
|
389 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
390 |
+
)
|
391 |
+
)
|
392 |
+
|
393 |
+
self.style = nn.Sequential(*layers)
|
394 |
+
|
395 |
+
self.channels = {
|
396 |
+
4: 512,
|
397 |
+
8: 512,
|
398 |
+
16: 512,
|
399 |
+
32: 512,
|
400 |
+
64: 256 * channel_multiplier,
|
401 |
+
128: 128 * channel_multiplier,
|
402 |
+
256: 64 * channel_multiplier,
|
403 |
+
512: 32 * channel_multiplier,
|
404 |
+
1024: 16 * channel_multiplier,
|
405 |
+
}
|
406 |
+
|
407 |
+
self.input = ConstantInput(self.channels[4])
|
408 |
+
self.conv1 = StyledConv(
|
409 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
410 |
+
)
|
411 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
412 |
+
|
413 |
+
self.log_size = int(math.log(size, 2))
|
414 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
415 |
+
|
416 |
+
self.convs = nn.ModuleList()
|
417 |
+
self.upsamples = nn.ModuleList()
|
418 |
+
self.to_rgbs = nn.ModuleList()
|
419 |
+
self.noises = nn.Module()
|
420 |
+
|
421 |
+
in_channel = self.channels[4]
|
422 |
+
|
423 |
+
for layer_idx in range(self.num_layers):
|
424 |
+
res = (layer_idx + 5) // 2
|
425 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
426 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
427 |
+
|
428 |
+
for i in range(3, self.log_size + 1):
|
429 |
+
out_channel = self.channels[2 ** i]
|
430 |
+
|
431 |
+
self.convs.append(
|
432 |
+
StyledConv(
|
433 |
+
in_channel,
|
434 |
+
out_channel,
|
435 |
+
3,
|
436 |
+
style_dim,
|
437 |
+
upsample=True,
|
438 |
+
blur_kernel=blur_kernel,
|
439 |
+
)
|
440 |
+
)
|
441 |
+
|
442 |
+
self.convs.append(
|
443 |
+
StyledConv(
|
444 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
445 |
+
)
|
446 |
+
)
|
447 |
+
|
448 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
449 |
+
|
450 |
+
in_channel = out_channel
|
451 |
+
|
452 |
+
self.n_latent = self.log_size * 2 - 2
|
453 |
+
|
454 |
+
@property
|
455 |
+
def device(self):
|
456 |
+
# TODO if multi-gpu is expected, could use the following more expensive version
|
457 |
+
#device, = list(set(p.device for p in self.parameters()))
|
458 |
+
return next(self.parameters()).device
|
459 |
+
|
460 |
+
@staticmethod
|
461 |
+
def get_latent_size(size):
|
462 |
+
log_size = int(math.log(size, 2))
|
463 |
+
return log_size * 2 - 2
|
464 |
+
|
465 |
+
@staticmethod
|
466 |
+
def make_noise_by_size(size: int, device: torch.device):
|
467 |
+
log_size = int(math.log(size, 2))
|
468 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
469 |
+
|
470 |
+
for i in range(3, log_size + 1):
|
471 |
+
for _ in range(2):
|
472 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
473 |
+
|
474 |
+
return noises
|
475 |
+
|
476 |
+
|
477 |
+
def make_noise(self):
|
478 |
+
return self.make_noise_by_size(self.size, self.input.input.device)
|
479 |
+
|
480 |
+
def mean_latent(self, n_latent):
|
481 |
+
latent_in = torch.randn(
|
482 |
+
n_latent, self.style_dim, device=self.input.input.device
|
483 |
+
)
|
484 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
485 |
+
|
486 |
+
return latent
|
487 |
+
|
488 |
+
def get_latent(self, input):
|
489 |
+
return self.style(input)
|
490 |
+
|
491 |
+
def forward(
|
492 |
+
self,
|
493 |
+
styles,
|
494 |
+
return_latents=False,
|
495 |
+
inject_index=None,
|
496 |
+
truncation=1,
|
497 |
+
truncation_latent=None,
|
498 |
+
input_is_latent=False,
|
499 |
+
noise=None,
|
500 |
+
randomize_noise=True,
|
501 |
+
):
|
502 |
+
if not input_is_latent:
|
503 |
+
styles = [self.style(s) for s in styles]
|
504 |
+
|
505 |
+
if noise is None:
|
506 |
+
if randomize_noise:
|
507 |
+
noise = [None] * self.num_layers
|
508 |
+
else:
|
509 |
+
noise = [
|
510 |
+
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
511 |
+
]
|
512 |
+
|
513 |
+
if truncation < 1:
|
514 |
+
style_t = []
|
515 |
+
|
516 |
+
for style in styles:
|
517 |
+
style_t.append(
|
518 |
+
truncation_latent + truncation * (style - truncation_latent)
|
519 |
+
)
|
520 |
+
|
521 |
+
styles = style_t
|
522 |
+
|
523 |
+
if len(styles) < 2:
|
524 |
+
inject_index = self.n_latent
|
525 |
+
|
526 |
+
if styles[0].ndim < 3:
|
527 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
528 |
+
|
529 |
+
else:
|
530 |
+
latent = styles[0]
|
531 |
+
|
532 |
+
else:
|
533 |
+
if inject_index is None:
|
534 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
535 |
+
|
536 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
537 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
538 |
+
|
539 |
+
latent = torch.cat([latent, latent2], 1)
|
540 |
+
|
541 |
+
out = self.input(latent)
|
542 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
543 |
+
|
544 |
+
skip, rgb_mod = self.to_rgb1(out, latent[:, 1])
|
545 |
+
|
546 |
+
|
547 |
+
rgbs = [rgb_mod] # all but the last skip
|
548 |
+
i = 1
|
549 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
550 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
551 |
+
):
|
552 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
553 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
554 |
+
skip, rgb_mod = to_rgb(out, latent[:, i + 2], skip)
|
555 |
+
rgbs.append(rgb_mod)
|
556 |
+
|
557 |
+
i += 2
|
558 |
+
|
559 |
+
image = skip
|
560 |
+
|
561 |
+
if return_latents:
|
562 |
+
return image, latent, rgbs
|
563 |
+
|
564 |
+
else:
|
565 |
+
return image, None, rgbs
|
566 |
+
|
567 |
+
|
568 |
+
class ConvLayer(nn.Sequential):
|
569 |
+
def __init__(
|
570 |
+
self,
|
571 |
+
in_channel,
|
572 |
+
out_channel,
|
573 |
+
kernel_size,
|
574 |
+
downsample=False,
|
575 |
+
blur_kernel=[1, 3, 3, 1],
|
576 |
+
bias=True,
|
577 |
+
activate=True,
|
578 |
+
):
|
579 |
+
layers = []
|
580 |
+
|
581 |
+
if downsample:
|
582 |
+
factor = 2
|
583 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
584 |
+
pad0 = (p + 1) // 2
|
585 |
+
pad1 = p // 2
|
586 |
+
|
587 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
588 |
+
|
589 |
+
stride = 2
|
590 |
+
self.padding = 0
|
591 |
+
|
592 |
+
else:
|
593 |
+
stride = 1
|
594 |
+
self.padding = kernel_size // 2
|
595 |
+
|
596 |
+
layers.append(
|
597 |
+
EqualConv2d(
|
598 |
+
in_channel,
|
599 |
+
out_channel,
|
600 |
+
kernel_size,
|
601 |
+
padding=self.padding,
|
602 |
+
stride=stride,
|
603 |
+
bias=bias and not activate,
|
604 |
+
)
|
605 |
+
)
|
606 |
+
|
607 |
+
if activate:
|
608 |
+
if bias:
|
609 |
+
layers.append(FusedLeakyReLU(out_channel))
|
610 |
+
|
611 |
+
else:
|
612 |
+
layers.append(ScaledLeakyReLU(0.2))
|
613 |
+
|
614 |
+
super().__init__(*layers)
|
615 |
+
|
616 |
+
|
617 |
+
class ResBlock(nn.Module):
|
618 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
619 |
+
super().__init__()
|
620 |
+
|
621 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
622 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
623 |
+
|
624 |
+
self.skip = ConvLayer(
|
625 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
626 |
+
)
|
627 |
+
|
628 |
+
def forward(self, input):
|
629 |
+
out = self.conv1(input)
|
630 |
+
out = self.conv2(out)
|
631 |
+
|
632 |
+
skip = self.skip(input)
|
633 |
+
out = (out + skip) / math.sqrt(2)
|
634 |
+
|
635 |
+
return out
|
636 |
+
|
637 |
+
|
638 |
+
class Discriminator(nn.Module):
|
639 |
+
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
640 |
+
super().__init__()
|
641 |
+
|
642 |
+
channels = {
|
643 |
+
4: 512,
|
644 |
+
8: 512,
|
645 |
+
16: 512,
|
646 |
+
32: 512,
|
647 |
+
64: 256 * channel_multiplier,
|
648 |
+
128: 128 * channel_multiplier,
|
649 |
+
256: 64 * channel_multiplier,
|
650 |
+
512: 32 * channel_multiplier,
|
651 |
+
1024: 16 * channel_multiplier,
|
652 |
+
}
|
653 |
+
|
654 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
655 |
+
|
656 |
+
log_size = int(math.log(size, 2))
|
657 |
+
|
658 |
+
in_channel = channels[size]
|
659 |
+
|
660 |
+
for i in range(log_size, 2, -1):
|
661 |
+
out_channel = channels[2 ** (i - 1)]
|
662 |
+
|
663 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
664 |
+
|
665 |
+
in_channel = out_channel
|
666 |
+
|
667 |
+
self.convs = nn.Sequential(*convs)
|
668 |
+
|
669 |
+
self.stddev_group = 4
|
670 |
+
self.stddev_feat = 1
|
671 |
+
|
672 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
673 |
+
self.final_linear = nn.Sequential(
|
674 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
675 |
+
EqualLinear(channels[4], 1),
|
676 |
+
)
|
677 |
+
|
678 |
+
def forward(self, input):
|
679 |
+
out = self.convs(input)
|
680 |
+
|
681 |
+
batch, channel, height, width = out.shape
|
682 |
+
group = min(batch, self.stddev_group)
|
683 |
+
stddev = out.view(
|
684 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
685 |
+
)
|
686 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
687 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
688 |
+
stddev = stddev.repeat(group, 1, height, width)
|
689 |
+
out = torch.cat([out, stddev], 1)
|
690 |
+
|
691 |
+
out = self.final_conv(out)
|
692 |
+
|
693 |
+
out = out.view(batch, -1)
|
694 |
+
out = self.final_linear(out)
|
695 |
+
|
696 |
+
return out
|
697 |
+
|
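The file above is the StyleGAN2 generator and discriminator used throughout this project. For orientation, here is a minimal, hedged usage sketch of the Generator: it assumes a CUDA device (the `op` extensions are JIT-compiled CUDA modules), a 256px configuration, and a hypothetical checkpoint path.

# Minimal sketch of driving the Generator defined above. Assumes a CUDA
# device, since op.fused_act / op.upfirdn2d are compiled CUDA extensions.
import torch
from model import Generator

device = "cuda"
g = Generator(size=256, style_dim=512, n_mlp=8).to(device)

# Hypothetical checkpoint path; substitute your own downloaded weights.
# ckpt = torch.load("checkpoint/stylegan2-256.pt")
# g.load_state_dict(ckpt["g_ema"], strict=False)

z = torch.randn(4, 512, device=device)        # latent codes in Z space
mean_w = g.mean_latent(4096)                  # average W latent for truncation
image, latent, rgbs = g(
    [z], truncation=0.7, truncation_latent=mean_w, return_latents=True
)
# image: (4, 3, 256, 256); rgbs: per-resolution modulated ToRGB outputs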
Time-Travel-Rephotography/models/__init__.py
ADDED
File without changes
|
Time-Travel-Rephotography/models/degrade.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
from argparse import (
|
2 |
+
ArgumentParser,
|
3 |
+
Namespace,
|
4 |
+
)
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from utils.misc import optional_string
|
11 |
+
|
12 |
+
from .gaussian_smoothing import GaussianSmoothing
|
13 |
+
|
14 |
+
|
15 |
+
class DegradeArguments:
|
16 |
+
@staticmethod
|
17 |
+
def add_arguments(parser: ArgumentParser):
|
18 |
+
parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
|
19 |
+
help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
|
20 |
+
parser.add_argument('--gaussian', type=float, default=0,
|
21 |
+
help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def to_string(args: Namespace) -> str:
|
25 |
+
return (
|
26 |
+
f"{args.spectral_sensitivity}"
|
27 |
+
+ optional_string(args.gaussian > 0, f"-G{args.gaussian}")
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
class CameraResponse(nn.Module):
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.register_parameter("gamma", nn.Parameter(torch.ones(1)))
|
36 |
+
self.register_parameter("offset", nn.Parameter(torch.zeros(1)))
|
37 |
+
self.register_parameter("gain", nn.Parameter(torch.ones(1)))
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
+
x = torch.clamp(x, max=1, min=-1+1e-2)
|
41 |
+
x = (1 + x) * 0.5
|
42 |
+
x = self.offset + self.gain * torch.pow(x, self.gamma)
|
43 |
+
x = (x - 0.5) * 2
|
44 |
+
# b = torch.clamp(b, max=1, min=-1)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class SpectralResponse(nn.Module):
|
49 |
+
# TODO: use enum instead for color mode
|
50 |
+
def __init__(self, spectral_sensitivity: str = 'b'):
|
51 |
+
assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
|
52 |
+
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
self.spectral_sensitivity = spectral_sensitivity
|
56 |
+
|
57 |
+
if self.spectral_sensitivity == "g":
|
58 |
+
self.register_buffer("to_gray", torch.tensor([0.299, 0.587, 0.114]).reshape(1, -1, 1, 1))
|
59 |
+
|
60 |
+
def forward(self, rgb: torch.Tensor) -> torch.Tensor:
|
61 |
+
if self.spectral_sensitivity == "b":
|
62 |
+
x = rgb[:, -1:]
|
63 |
+
elif self.spectral_sensitivity == "gb":
|
64 |
+
x = (rgb[:, 1:2] + rgb[:, -1:]) * 0.5
|
65 |
+
else:
|
66 |
+
assert self.spectral_sensitivity == "g"
|
67 |
+
x = (rgb * self.to_gray).sum(dim=1, keepdim=True)
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
class Downsample(nn.Module):
|
72 |
+
"""Antialiasing downsampling"""
|
73 |
+
def __init__(self, input_size: int, output_size: int, channels: int):
|
74 |
+
super().__init__()
|
75 |
+
if input_size % output_size == 0:
|
76 |
+
self.stride = input_size // output_size
|
77 |
+
self.grid = None
|
78 |
+
else:
|
79 |
+
self.stride = 1
|
80 |
+
step = input_size / output_size
|
81 |
+
x = torch.arange(output_size) * step
|
82 |
+
Y, X = torch.meshgrid(x, x)
|
83 |
+
grid = torch.stack((X, Y), dim=-1)
|
84 |
+
grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
|
85 |
+
grid = grid * 2 - 1
|
86 |
+
self.register_buffer("grid", grid)
|
87 |
+
sigma = 0.5 * input_size / output_size
|
88 |
+
#print(f"{input_size} -> {output_size}: sigma={sigma}")
|
89 |
+
self.blur = GaussianSmoothing(channels, int(2 * (sigma * 2) + 1 + 0.5), sigma)
|
90 |
+
|
91 |
+
def forward(self, im: torch.Tensor):
|
92 |
+
out = self.blur(im, stride=self.stride)
|
93 |
+
if self.grid is not None:
|
94 |
+
out = F.grid_sample(out, self.grid[None].expand(im.shape[0], -1, -1, -1))
|
95 |
+
return out
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
class Degrade(nn.Module):
|
100 |
+
"""
|
101 |
+
Simulate the degradation of antique film
|
102 |
+
"""
|
103 |
+
def __init__(self, args:Namespace):
|
104 |
+
super().__init__()
|
105 |
+
self.srf = SpectralResponse(args.spectral_sensitivity)
|
106 |
+
self.crf = CameraResponse()
|
107 |
+
self.gaussian = None
|
108 |
+
if args.gaussian is not None and args.gaussian > 0:
|
109 |
+
self.gaussian = GaussianSmoothing(3, 2 * int(args.gaussian * 2 + 0.5) + 1, args.gaussian)
|
110 |
+
|
111 |
+
def forward(self, img: torch.Tensor, downsample: nn.Module = None):
|
112 |
+
if self.gaussian is not None:
|
113 |
+
img = self.gaussian(img)
|
114 |
+
if downsample is not None:
|
115 |
+
img = downsample(img)
|
116 |
+
img = self.srf(img)
|
117 |
+
img = self.crf(img)
|
118 |
+
# Note that I changed it back to 3 channels
|
119 |
+
return img.repeat((1, 3, 1, 1)) if img.shape[1] == 1 else img
|
120 |
+
|
121 |
+
|
122 |
+
|
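For orientation, a short sketch of how the degradation module above might be driven; the spectral sensitivity and blur radius are illustrative values, and `Downsample` here refers to the antialiasing class defined in this file (not the one in model.py).

# Sketch: simulate antique-film degradation of an RGB image in [-1, 1].
# The argument values are illustrative assumptions, not repo defaults.
from argparse import Namespace
import torch
from models.degrade import Degrade, Downsample

args = Namespace(spectral_sensitivity="gb", gaussian=1.5)
degrade = Degrade(args)

rgb = torch.rand(1, 3, 1024, 1024) * 2 - 1                  # input in [-1, 1]
down = Downsample(input_size=1024, output_size=256, channels=3)

aged = degrade(rgb, downsample=down)   # blur -> resample -> spectral + camera response
print(aged.shape)                      # torch.Size([1, 3, 256, 256])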
Time-Travel-Rephotography/models/encoder.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
from argparse import Namespace, ArgumentParser
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto
|
7 |
+
|
8 |
+
|
9 |
+
def add_arguments(parser: ArgumentParser) -> ArgumentParser:
|
10 |
+
parser.add_argument("--latent_size", type=int, default=512, help="latent size")
|
11 |
+
return parser
|
12 |
+
|
13 |
+
|
14 |
+
def create_model(args) -> nn.Module:
|
15 |
+
in_channels = 3 if "rgb" in args and args.rgb else 1
|
16 |
+
return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
|
17 |
+
|
18 |
+
|
19 |
+
class Flatten(nn.Module):
|
20 |
+
def forward(self, input_):
|
21 |
+
return input_.view(input_.size(0), -1)
|
22 |
+
|
23 |
+
|
24 |
+
class Encoder(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self, in_channels: int, size: int, latent_size: int = 512,
|
27 |
+
activation: str = 'leaky_relu', norm: str = "instance"
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
out_channels0 = 64
|
32 |
+
norm_m = norm_module(norm)
|
33 |
+
self.conv0 = nn.Sequential(
|
34 |
+
Conv2dAuto(in_channels, out_channels0, kernel_size=5),
|
35 |
+
norm_m(out_channels0),
|
36 |
+
activation_func(activation),
|
37 |
+
)
|
38 |
+
|
39 |
+
pool_kernel = 2
|
40 |
+
self.pool = nn.AvgPool2d(pool_kernel)
|
41 |
+
|
42 |
+
num_channels = [128, 256, 512, 512]
|
43 |
+
# FIXME: this is a hack
|
44 |
+
if size >= 256:
|
45 |
+
num_channels.append(512)
|
46 |
+
|
47 |
+
residual = partial(ResNetBasicBlock, activation=activation, norm=norm, bias=True)
|
48 |
+
residual_blocks = nn.ModuleList()
|
49 |
+
for in_channel, out_channel in zip([out_channels0] + num_channels[:-1], num_channels):
|
50 |
+
residual_blocks.append(residual(in_channel, out_channel))
|
51 |
+
residual_blocks.append(nn.AvgPool2d(pool_kernel))
|
52 |
+
self.residual_blocks = nn.Sequential(*residual_blocks)
|
53 |
+
|
54 |
+
self.last = nn.Sequential(
|
55 |
+
nn.ReLU(),
|
56 |
+
nn.AvgPool2d(4), # TODO: not sure whether this would cause problems
|
57 |
+
Flatten(),
|
58 |
+
nn.Linear(num_channels[-1], latent_size, bias=True)
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, input_):
|
62 |
+
out = self.conv0(input_)
|
63 |
+
out = self.pool(out)
|
64 |
+
out = self.residual_blocks(out)
|
65 |
+
out = self.last(out)
|
66 |
+
return out
|
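A minimal sketch of constructing the encoder via `create_model` above; the Namespace fields mirror what `create_model` reads (`latent_size`, `encoder_size`, `rgb`), and the values are examples only.

# Sketch: map a single-channel 256x256 image to a 512-d latent code.
from argparse import Namespace
import torch
from models.encoder import create_model

args = Namespace(latent_size=512, encoder_size=256, rgb=False)
encoder = create_model(args)           # in_channels=1 because rgb=False

gray = torch.rand(2, 1, 256, 256)
w = encoder(gray)
print(w.shape)                         # torch.Size([2, 512])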
Time-Travel-Rephotography/models/gaussian_smoothing.py
ADDED
@@ -0,0 +1,74 @@
1 |
+
import math
|
2 |
+
import numbers
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class GaussianSmoothing(nn.Module):
|
9 |
+
"""
|
10 |
+
Apply gaussian smoothing on a
|
11 |
+
1d, 2d or 3d tensor. Filtering is performed separately for each channel
|
12 |
+
in the input using a depthwise convolution.
|
13 |
+
Arguments:
|
14 |
+
channels (int, sequence): Number of channels of the input tensors. Output will
|
15 |
+
have this number of channels as well.
|
16 |
+
kernel_size (int, sequence): Size of the gaussian kernel.
|
17 |
+
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
18 |
+
dim (int, optional): The number of dimensions of the data.
|
19 |
+
Default value is 2 (spatial).
|
20 |
+
"""
|
21 |
+
def __init__(self, channels, kernel_size, sigma, dim=2):
|
22 |
+
super(GaussianSmoothing, self).__init__()
|
23 |
+
if isinstance(kernel_size, numbers.Number):
|
24 |
+
kernel_size = [kernel_size] * dim
|
25 |
+
if isinstance(sigma, numbers.Number):
|
26 |
+
sigma = [sigma] * dim
|
27 |
+
|
28 |
+
# The gaussian kernel is the product of the
|
29 |
+
# gaussian function of each dimension.
|
30 |
+
kernel = 1
|
31 |
+
meshgrids = torch.meshgrid(
|
32 |
+
[
|
33 |
+
torch.arange(size, dtype=torch.float32)
|
34 |
+
for size in kernel_size
|
35 |
+
]
|
36 |
+
)
|
37 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
38 |
+
mean = (size - 1) / 2
|
39 |
+
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
40 |
+
torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
|
41 |
+
|
42 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
43 |
+
kernel = kernel / torch.sum(kernel)
|
44 |
+
|
45 |
+
# Reshape to depthwise convolutional weight
|
46 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
47 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
48 |
+
|
49 |
+
self.register_buffer('weight', kernel)
|
50 |
+
self.groups = channels
|
51 |
+
|
52 |
+
if dim == 1:
|
53 |
+
self.conv = F.conv1d
|
54 |
+
elif dim == 2:
|
55 |
+
self.conv = F.conv2d
|
56 |
+
elif dim == 3:
|
57 |
+
self.conv = F.conv3d
|
58 |
+
else:
|
59 |
+
raise RuntimeError(
|
60 |
+
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, input, stride: int = 1):
|
64 |
+
"""
|
65 |
+
Apply gaussian filter to input.
|
66 |
+
Arguments:
|
67 |
+
input (torch.Tensor): Input to apply gaussian filter on.
|
68 |
+
stride (int): stride used when applying the convolution.
|
69 |
+
Returns:
|
70 |
+
filtered (torch.Tensor): Filtered output.
|
71 |
+
"""
|
72 |
+
padding = (self.weight.shape[-1] - 1) // 2
|
73 |
+
return self.conv(input, weight=self.weight, groups=self.groups, padding=padding, stride=stride)
|
74 |
+
|
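The module above builds a fixed (non-learned) depthwise Gaussian kernel and registers it as a buffer; the optional `stride` argument lets callers such as `Downsample` in degrade.py blur and subsample in a single convolution. A small sketch with example parameter values:

# Sketch: depthwise Gaussian blur of a 3-channel image.
# kernel_size=9 and sigma=2.0 are example values; padding keeps the spatial
# size at stride=1, while stride>1 additionally downsamples.
import torch
from models.gaussian_smoothing import GaussianSmoothing

blur = GaussianSmoothing(channels=3, kernel_size=9, sigma=2.0)
img = torch.rand(1, 3, 128, 128)

same_size = blur(img)               # (1, 3, 128, 128)
half_size = blur(img, stride=2)     # (1, 3, 64, 64)
print(same_size.shape, half_size.shape)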
Time-Travel-Rephotography/models/resnet.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
def activation_func(activation: str):
|
7 |
+
return nn.ModuleDict([
|
8 |
+
['relu', nn.ReLU(inplace=True)],
|
9 |
+
['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
|
10 |
+
['selu', nn.SELU(inplace=True)],
|
11 |
+
['none', nn.Identity()]
|
12 |
+
])[activation]
|
13 |
+
|
14 |
+
|
15 |
+
def norm_module(norm: str):
|
16 |
+
return {
|
17 |
+
'batch': nn.BatchNorm2d,
|
18 |
+
'instance': nn.InstanceNorm2d,
|
19 |
+
}[norm]
|
20 |
+
|
21 |
+
|
22 |
+
class Conv2dAuto(nn.Conv2d):
|
23 |
+
def __init__(self, *args, **kwargs):
|
24 |
+
super().__init__(*args, **kwargs)
|
25 |
+
# dynamically add padding based on the kernel_size
|
26 |
+
self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
|
27 |
+
|
28 |
+
|
29 |
+
conv3x3 = partial(Conv2dAuto, kernel_size=3)
|
30 |
+
|
31 |
+
|
32 |
+
class ResidualBlock(nn.Module):
|
33 |
+
def __init__(self, in_channels: int, out_channels: int, activation: str = 'relu'):
|
34 |
+
super().__init__()
|
35 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
36 |
+
self.blocks = nn.Identity()
|
37 |
+
self.activate = activation_func(activation)
|
38 |
+
self.shortcut = nn.Identity()
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
residual = x
|
42 |
+
if self.should_apply_shortcut:
|
43 |
+
residual = self.shortcut(x)
|
44 |
+
x = self.blocks(x)
|
45 |
+
x += residual
|
46 |
+
x = self.activate(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
@property
|
50 |
+
def should_apply_shortcut(self):
|
51 |
+
return self.in_channels != self.out_channels
|
52 |
+
|
53 |
+
|
54 |
+
class ResNetResidualBlock(ResidualBlock):
|
55 |
+
def __init__(
|
56 |
+
self, in_channels: int, out_channels: int,
|
57 |
+
expansion: int = 1, downsampling: int = 1,
|
58 |
+
conv=conv3x3, norm: str = 'batch', *args, **kwargs
|
59 |
+
):
|
60 |
+
super().__init__(in_channels, out_channels, *args, **kwargs)
|
61 |
+
self.expansion, self.downsampling = expansion, downsampling
|
62 |
+
self.conv, self.norm = conv, norm_module(norm)
|
63 |
+
self.shortcut = nn.Sequential(
|
64 |
+
nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
|
65 |
+
stride=self.downsampling, bias=False),
|
66 |
+
self.norm(self.expanded_channels)) if self.should_apply_shortcut else None
|
67 |
+
|
68 |
+
@property
|
69 |
+
def expanded_channels(self):
|
70 |
+
return self.out_channels * self.expansion
|
71 |
+
|
72 |
+
@property
|
73 |
+
def should_apply_shortcut(self):
|
74 |
+
return self.in_channels != self.expanded_channels
|
75 |
+
|
76 |
+
|
77 |
+
def conv_norm(in_channels: int, out_channels: int, conv, norm, *args, **kwargs):
|
78 |
+
return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs), norm(out_channels))
|
79 |
+
|
80 |
+
|
81 |
+
class ResNetBasicBlock(ResNetResidualBlock):
|
82 |
+
"""
|
83 |
+
Basic ResNet block composed by two layers of 3x3conv/batchnorm/activation
|
84 |
+
"""
|
85 |
+
expansion = 1
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self, in_channels: int, out_channels: int, bias: bool = False, *args, **kwargs
|
89 |
+
):
|
90 |
+
super().__init__(in_channels, out_channels, *args, **kwargs)
|
91 |
+
self.blocks = nn.Sequential(
|
92 |
+
conv_norm(
|
93 |
+
self.in_channels, self.out_channels, conv=self.conv, norm=self.norm,
|
94 |
+
bias=bias, stride=self.downsampling
|
95 |
+
),
|
96 |
+
self.activate,
|
97 |
+
conv_norm(self.out_channels, self.expanded_channels, conv=self.conv, norm=self.norm, bias=bias),
|
98 |
+
)
|
99 |
+
|
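These building blocks are consumed by the encoder above; `ResNetBasicBlock` adds a 1x1 projection shortcut whenever the channel count changes. A minimal sketch with example arguments:

# Sketch: one ResNet basic block with instance norm and leaky ReLU,
# matching how models/encoder.py instantiates it via functools.partial.
import torch
from models.resnet import ResNetBasicBlock

block = ResNetBasicBlock(64, 128, bias=True, activation="leaky_relu", norm="instance")
x = torch.rand(2, 64, 32, 32)
y = block(x)
print(y.shape)    # torch.Size([2, 128, 32, 32]) -- spatial size kept (downsampling=1)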
Time-Travel-Rephotography/models/vggface.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class Vgg_face_dag(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self):
|
9 |
+
super(Vgg_face_dag, self).__init__()
|
10 |
+
self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
|
11 |
+
'std': [1, 1, 1],
|
12 |
+
'imageSize': [224, 224, 3]}
|
13 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
14 |
+
self.relu1_1 = nn.ReLU(inplace=True)
|
15 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
16 |
+
self.relu1_2 = nn.ReLU(inplace=True)
|
17 |
+
self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
18 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
19 |
+
self.relu2_1 = nn.ReLU(inplace=True)
|
20 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
21 |
+
self.relu2_2 = nn.ReLU(inplace=True)
|
22 |
+
self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
23 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
24 |
+
self.relu3_1 = nn.ReLU(inplace=True)
|
25 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
26 |
+
self.relu3_2 = nn.ReLU(inplace=True)
|
27 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
28 |
+
self.relu3_3 = nn.ReLU(inplace=True)
|
29 |
+
self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
30 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
31 |
+
self.relu4_1 = nn.ReLU(inplace=True)
|
32 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
33 |
+
self.relu4_2 = nn.ReLU(inplace=True)
|
34 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
35 |
+
self.relu4_3 = nn.ReLU(inplace=True)
|
36 |
+
self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
37 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
38 |
+
self.relu5_1 = nn.ReLU(inplace=True)
|
39 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
40 |
+
self.relu5_2 = nn.ReLU(inplace=True)
|
41 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
42 |
+
self.relu5_3 = nn.ReLU(inplace=True)
|
43 |
+
self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
44 |
+
self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
|
45 |
+
self.relu6 = nn.ReLU(inplace=True)
|
46 |
+
self.dropout6 = nn.Dropout(p=0.5)
|
47 |
+
self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
|
48 |
+
self.relu7 = nn.ReLU(inplace=True)
|
49 |
+
self.dropout7 = nn.Dropout(p=0.5)
|
50 |
+
self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True)
|
51 |
+
|
52 |
+
def forward(self, x0):
|
53 |
+
x1 = self.conv1_1(x0)
|
54 |
+
x2 = self.relu1_1(x1)
|
55 |
+
x3 = self.conv1_2(x2)
|
56 |
+
x4 = self.relu1_2(x3)
|
57 |
+
x5 = self.pool1(x4)
|
58 |
+
x6 = self.conv2_1(x5)
|
59 |
+
x7 = self.relu2_1(x6)
|
60 |
+
x8 = self.conv2_2(x7)
|
61 |
+
x9 = self.relu2_2(x8)
|
62 |
+
x10 = self.pool2(x9)
|
63 |
+
x11 = self.conv3_1(x10)
|
64 |
+
x12 = self.relu3_1(x11)
|
65 |
+
x13 = self.conv3_2(x12)
|
66 |
+
x14 = self.relu3_2(x13)
|
67 |
+
x15 = self.conv3_3(x14)
|
68 |
+
x16 = self.relu3_3(x15)
|
69 |
+
x17 = self.pool3(x16)
|
70 |
+
x18 = self.conv4_1(x17)
|
71 |
+
x19 = self.relu4_1(x18)
|
72 |
+
x20 = self.conv4_2(x19)
|
73 |
+
x21 = self.relu4_2(x20)
|
74 |
+
x22 = self.conv4_3(x21)
|
75 |
+
x23 = self.relu4_3(x22)
|
76 |
+
x24 = self.pool4(x23)
|
77 |
+
x25 = self.conv5_1(x24)
|
78 |
+
x26 = self.relu5_1(x25)
|
79 |
+
x27 = self.conv5_2(x26)
|
80 |
+
x28 = self.relu5_2(x27)
|
81 |
+
x29 = self.conv5_3(x28)
|
82 |
+
x30 = self.relu5_3(x29)
|
83 |
+
x31_preflatten = self.pool5(x30)
|
84 |
+
x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
|
85 |
+
x32 = self.fc6(x31)
|
86 |
+
x33 = self.relu6(x32)
|
87 |
+
x34 = self.dropout6(x33)
|
88 |
+
x35 = self.fc7(x34)
|
89 |
+
x36 = self.relu7(x35)
|
90 |
+
x37 = self.dropout7(x36)
|
91 |
+
x38 = self.fc8(x37)
|
92 |
+
return x38
|
93 |
+
|
94 |
+
|
95 |
+
def vgg_face_dag(weights_path=None, **kwargs):
|
96 |
+
"""
|
97 |
+
load imported model instance
|
98 |
+
|
99 |
+
Args:
|
100 |
+
weights_path (str): If set, loads model weights from the given path
|
101 |
+
"""
|
102 |
+
model = Vgg_face_dag()
|
103 |
+
if weights_path:
|
104 |
+
state_dict = torch.load(weights_path)
|
105 |
+
model.load_state_dict(state_dict)
|
106 |
+
return model
|
107 |
+
|
108 |
+
|
109 |
+
class VGGFaceFeats(Vgg_face_dag):
|
110 |
+
def forward(self, x0):
|
111 |
+
x1 = self.conv1_1(x0)
|
112 |
+
x2 = self.relu1_1(x1)
|
113 |
+
x3 = self.conv1_2(x2)
|
114 |
+
x4 = self.relu1_2(x3)
|
115 |
+
x5 = self.pool1(x4)
|
116 |
+
x6 = self.conv2_1(x5)
|
117 |
+
x7 = self.relu2_1(x6)
|
118 |
+
x8 = self.conv2_2(x7)
|
119 |
+
x9 = self.relu2_2(x8)
|
120 |
+
x10 = self.pool2(x9)
|
121 |
+
x11 = self.conv3_1(x10)
|
122 |
+
x12 = self.relu3_1(x11)
|
123 |
+
x13 = self.conv3_2(x12)
|
124 |
+
x14 = self.relu3_2(x13)
|
125 |
+
x15 = self.conv3_3(x14)
|
126 |
+
x16 = self.relu3_3(x15)
|
127 |
+
x17 = self.pool3(x16)
|
128 |
+
x18 = self.conv4_1(x17)
|
129 |
+
x19 = self.relu4_1(x18)
|
130 |
+
x20 = self.conv4_2(x19)
|
131 |
+
x21 = self.relu4_2(x20)
|
132 |
+
x22 = self.conv4_3(x21)
|
133 |
+
x23 = self.relu4_3(x22)
|
134 |
+
x24 = self.pool4(x23)
|
135 |
+
x25 = self.conv5_1(x24)
|
136 |
+
# x26 = self.relu5_1(x25)
|
137 |
+
# x27 = self.conv5_2(x26)
|
138 |
+
# x28 = self.relu5_2(x27)
|
139 |
+
# x29 = self.conv5_3(x28)
|
140 |
+
# x30 = self.relu5_3(x29)
|
141 |
+
# x31_preflatten = self.pool5(x30)
|
142 |
+
# x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
|
143 |
+
# x32 = self.fc6(x31)
|
144 |
+
# x33 = self.relu6(x32)
|
145 |
+
# x34 = self.dropout6(x33)
|
146 |
+
# x35 = self.fc7(x34)
|
147 |
+
# x36 = self.relu7(x35)
|
148 |
+
# x37 = self.dropout7(x36)
|
149 |
+
# x38 = self.fc8(x37)
|
150 |
+
return x1, x6, x11, x18, x25
|
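`VGGFaceFeats` reuses the VGG-Face architecture but stops after conv5_1, returning intermediate activations suitable for a perceptual or identity loss. A hedged sketch follows; the weights path is hypothetical, and the mean subtraction shown simply follows the `meta` dict (the actual preprocessing is defined by the loss code elsewhere in this repo).

# Sketch: extract multi-scale VGG-Face features from a 224x224 face crop.
import torch
from models.vggface import VGGFaceFeats

vgg = VGGFaceFeats().eval()
# state_dict = torch.load("checkpoint/vgg_face_dag.pt")   # hypothetical path
# vgg.load_state_dict(state_dict)

face = torch.rand(1, 3, 224, 224) * 255.0                  # 0-255 range inputs
mean = torch.tensor(vgg.meta["mean"]).view(1, 3, 1, 1)
feats = vgg(face - mean)                                   # conv1_1 ... conv5_1
print([f.shape[1] for f in feats])                         # [64, 128, 256, 512, 512]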
Time-Travel-Rephotography/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
2 |
+
from .upfirdn2d import upfirdn2d
|
Time-Travel-Rephotography/op/fused_act.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.utils.cpp_extension import load
|
7 |
+
|
8 |
+
|
9 |
+
module_path = os.path.dirname(__file__)
|
10 |
+
fused = load(
|
11 |
+
'fused',
|
12 |
+
sources=[
|
13 |
+
os.path.join(module_path, 'fused_bias_act.cpp'),
|
14 |
+
os.path.join(module_path, 'fused_bias_act_kernel.cu'),
|
15 |
+
],
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class FusedLeakyReLUFunctionBackward(Function):
|
20 |
+
@staticmethod
|
21 |
+
def forward(ctx, grad_output, out, negative_slope, scale):
|
22 |
+
ctx.save_for_backward(out)
|
23 |
+
ctx.negative_slope = negative_slope
|
24 |
+
ctx.scale = scale
|
25 |
+
|
26 |
+
empty = grad_output.new_empty(0)
|
27 |
+
|
28 |
+
grad_input = fused.fused_bias_act(
|
29 |
+
grad_output, empty, out, 3, 1, negative_slope, scale
|
30 |
+
)
|
31 |
+
|
32 |
+
dim = [0]
|
33 |
+
|
34 |
+
if grad_input.ndim > 2:
|
35 |
+
dim += list(range(2, grad_input.ndim))
|
36 |
+
|
37 |
+
grad_bias = grad_input.sum(dim).detach()
|
38 |
+
|
39 |
+
return grad_input, grad_bias
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def backward(ctx, gradgrad_input, gradgrad_bias):
|
43 |
+
out, = ctx.saved_tensors
|
44 |
+
gradgrad_out = fused.fused_bias_act(
|
45 |
+
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
46 |
+
)
|
47 |
+
|
48 |
+
return gradgrad_out, None, None, None
|
49 |
+
|
50 |
+
|
51 |
+
class FusedLeakyReLUFunction(Function):
|
52 |
+
@staticmethod
|
53 |
+
def forward(ctx, input, bias, negative_slope, scale):
|
54 |
+
empty = input.new_empty(0)
|
55 |
+
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
56 |
+
ctx.save_for_backward(out)
|
57 |
+
ctx.negative_slope = negative_slope
|
58 |
+
ctx.scale = scale
|
59 |
+
|
60 |
+
return out
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def backward(ctx, grad_output):
|
64 |
+
out, = ctx.saved_tensors
|
65 |
+
|
66 |
+
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
67 |
+
grad_output, out, ctx.negative_slope, ctx.scale
|
68 |
+
)
|
69 |
+
|
70 |
+
return grad_input, grad_bias, None, None
|
71 |
+
|
72 |
+
|
73 |
+
class FusedLeakyReLU(nn.Module):
|
74 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
78 |
+
self.negative_slope = negative_slope
|
79 |
+
self.scale = scale
|
80 |
+
|
81 |
+
def forward(self, input):
|
82 |
+
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
83 |
+
|
84 |
+
|
85 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
86 |
+
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
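`fused_leaky_relu` is a single CUDA kernel that adds a per-channel bias, applies a leaky ReLU, and rescales by sqrt(2); the extension is JIT-compiled from the .cpp/.cu sources when this module is imported. For reference only, a rough pure-PyTorch equivalent of the forward computation (an illustrative approximation, not part of the extension):

# Rough CPU reference for fused_leaky_relu's forward pass.
import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # The bias is broadcast along the channel dimension (dim 1).
    shape = [1, -1] + [1] * (x.ndim - 2)
    return F.leaky_relu(x + bias.view(*shape), negative_slope) * scale

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)
print(fused_leaky_relu_reference(x, b).shape)   # torch.Size([2, 8, 4, 4])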
Time-Travel-Rephotography/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
5 |
+
int act, int grad, float alpha, float scale);
|
6 |
+
|
7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
9 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
10 |
+
|
11 |
+
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
12 |
+
int act, int grad, float alpha, float scale) {
|
13 |
+
CHECK_CUDA(input);
|
14 |
+
CHECK_CUDA(bias);
|
15 |
+
|
16 |
+
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
|
17 |
+
}
|
18 |
+
|
19 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
20 |
+
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
21 |
+
}
|
Time-Travel-Rephotography/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
template <typename scalar_t>
|
19 |
+
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
|
20 |
+
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
|
21 |
+
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
22 |
+
|
23 |
+
scalar_t zero = 0.0;
|
24 |
+
|
25 |
+
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
|
26 |
+
scalar_t x = p_x[xi];
|
27 |
+
|
28 |
+
if (use_bias) {
|
29 |
+
x += p_b[(xi / step_b) % size_b];
|
30 |
+
}
|
31 |
+
|
32 |
+
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
33 |
+
|
34 |
+
scalar_t y;
|
35 |
+
|
36 |
+
switch (act * 10 + grad) {
|
37 |
+
default:
|
38 |
+
case 10: y = x; break;
|
39 |
+
case 11: y = x; break;
|
40 |
+
case 12: y = 0.0; break;
|
41 |
+
|
42 |
+
case 30: y = (x > 0.0) ? x : x * alpha; break;
|
43 |
+
case 31: y = (ref > 0.0) ? x : x * alpha; break;
|
44 |
+
case 32: y = 0.0; break;
|
45 |
+
}
|
46 |
+
|
47 |
+
out[xi] = y * scale;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
53 |
+
int act, int grad, float alpha, float scale) {
|
54 |
+
int curDevice = -1;
|
55 |
+
cudaGetDevice(&curDevice);
|
56 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
57 |
+
|
58 |
+
auto x = input.contiguous();
|
59 |
+
auto b = bias.contiguous();
|
60 |
+
auto ref = refer.contiguous();
|
61 |
+
|
62 |
+
int use_bias = b.numel() ? 1 : 0;
|
63 |
+
int use_ref = ref.numel() ? 1 : 0;
|
64 |
+
|
65 |
+
int size_x = x.numel();
|
66 |
+
int size_b = b.numel();
|
67 |
+
int step_b = 1;
|
68 |
+
|
69 |
+
for (int i = 1 + 1; i < x.dim(); i++) {
|
70 |
+
step_b *= x.size(i);
|
71 |
+
}
|
72 |
+
|
73 |
+
int loop_x = 4;
|
74 |
+
int block_size = 4 * 32;
|
75 |
+
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
76 |
+
|
77 |
+
auto y = torch::empty_like(x);
|
78 |
+
|
79 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
|
80 |
+
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
81 |
+
y.data_ptr<scalar_t>(),
|
82 |
+
x.data_ptr<scalar_t>(),
|
83 |
+
b.data_ptr<scalar_t>(),
|
84 |
+
ref.data_ptr<scalar_t>(),
|
85 |
+
act,
|
86 |
+
grad,
|
87 |
+
alpha,
|
88 |
+
scale,
|
89 |
+
loop_x,
|
90 |
+
size_x,
|
91 |
+
step_b,
|
92 |
+
size_b,
|
93 |
+
use_bias,
|
94 |
+
use_ref
|
95 |
+
);
|
96 |
+
});
|
97 |
+
|
98 |
+
return y;
|
99 |
+
}
|
Time-Travel-Rephotography/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
5 |
+
int up_x, int up_y, int down_x, int down_y,
|
6 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
7 |
+
|
8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
9 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
10 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
11 |
+
|
12 |
+
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
|
13 |
+
int up_x, int up_y, int down_x, int down_y,
|
14 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
15 |
+
CHECK_CUDA(input);
|
16 |
+
CHECK_CUDA(kernel);
|
17 |
+
|
18 |
+
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
|
19 |
+
}
|
20 |
+
|
21 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
+
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
23 |
+
}
|
Time-Travel-Rephotography/op/upfirdn2d.py
ADDED
@@ -0,0 +1,187 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.utils.cpp_extension import load
|
6 |
+
|
7 |
+
|
8 |
+
module_path = os.path.dirname(__file__)
|
9 |
+
upfirdn2d_op = load(
|
10 |
+
'upfirdn2d',
|
11 |
+
sources=[
|
12 |
+
os.path.join(module_path, 'upfirdn2d.cpp'),
|
13 |
+
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
|
14 |
+
],
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class UpFirDn2dBackward(Function):
|
19 |
+
@staticmethod
|
20 |
+
def forward(
|
21 |
+
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
22 |
+
):
|
23 |
+
|
24 |
+
up_x, up_y = up
|
25 |
+
down_x, down_y = down
|
26 |
+
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
27 |
+
|
28 |
+
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
29 |
+
|
30 |
+
grad_input = upfirdn2d_op.upfirdn2d(
|
31 |
+
grad_output,
|
32 |
+
grad_kernel,
|
33 |
+
down_x,
|
34 |
+
down_y,
|
35 |
+
up_x,
|
36 |
+
up_y,
|
37 |
+
g_pad_x0,
|
38 |
+
g_pad_x1,
|
39 |
+
g_pad_y0,
|
40 |
+
g_pad_y1,
|
41 |
+
)
|
42 |
+
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
43 |
+
|
44 |
+
ctx.save_for_backward(kernel)
|
45 |
+
|
46 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
47 |
+
|
48 |
+
ctx.up_x = up_x
|
49 |
+
ctx.up_y = up_y
|
50 |
+
ctx.down_x = down_x
|
51 |
+
ctx.down_y = down_y
|
52 |
+
ctx.pad_x0 = pad_x0
|
53 |
+
ctx.pad_x1 = pad_x1
|
54 |
+
ctx.pad_y0 = pad_y0
|
55 |
+
ctx.pad_y1 = pad_y1
|
56 |
+
ctx.in_size = in_size
|
57 |
+
ctx.out_size = out_size
|
58 |
+
|
59 |
+
return grad_input
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def backward(ctx, gradgrad_input):
|
63 |
+
kernel, = ctx.saved_tensors
|
64 |
+
|
65 |
+
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
66 |
+
|
67 |
+
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
68 |
+
gradgrad_input,
|
69 |
+
kernel,
|
70 |
+
ctx.up_x,
|
71 |
+
ctx.up_y,
|
72 |
+
ctx.down_x,
|
73 |
+
ctx.down_y,
|
74 |
+
ctx.pad_x0,
|
75 |
+
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = UpFirDn2d.apply(
        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
    )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    return out[:, ::down_y, ::down_x, :]
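A minimal usage sketch of the upfirdn2d wrapper above (illustrative only; it assumes the op package re-exports upfirdn2d as in the upstream StyleGAN2 port and that the CUDA extension builds). The separable 4-tap blur kernel and the pad values follow the usual 2x-upsampling convention; the shapes are just an example.

    # sketch: 2x upsample a feature map with a normalized 4x4 blur kernel
    import torch
    from op import upfirdn2d  # assumes op/__init__.py exposes upfirdn2d

    k = torch.tensor([1.0, 3.0, 3.0, 1.0])
    kernel = (k[None, :] * k[:, None])
    kernel = kernel / kernel.sum()

    x = torch.randn(1, 3, 64, 64, device="cuda")                 # NCHW input
    y = upfirdn2d(x, kernel.cuda() * 4, up=2, down=1, pad=(2, 1))  # gain of up**2
    print(y.shape)  # expected: torch.Size([1, 3, 128, 128])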
Time-Travel-Rephotography/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,272 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}


struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};


template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

        #pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
          #pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
        }
      }
    }
  }
}


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;

  auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h;
  int tile_out_w;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
        out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
      );

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
        out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
      );

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
        out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
      );

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
        out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
      );

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
        out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
      );

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
        out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
      );

      break;
    }
  });

  return out;
}
Time-Travel-Rephotography/optim/__init__.py
ADDED
@@ -0,0 +1,15 @@
from torch.optim import Adam
from torch.optim.lbfgs import LBFGS
from .radam import RAdam


OPTIMIZER_MAP = {
    "adam": Adam,
    "radam": RAdam,
    "lbfgs": LBFGS,
}


def get_optimizer_class(optimizer_name):
    name = optimizer_name.lower()
    return OPTIMIZER_MAP[name]
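A minimal usage sketch of get_optimizer_class (illustrative only; the parameter and learning rate are placeholders, and the key must be one of "adam", "radam", or "lbfgs" from OPTIMIZER_MAP above).

    # sketch: pick an optimizer by name and instantiate it
    import torch
    from optim import get_optimizer_class

    latent = torch.nn.Parameter(torch.zeros(512))
    optimizer = get_optimizer_class("radam")([latent], lr=0.1)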
Time-Travel-Rephotography/optim/radam.py
ADDED
@@ -0,0 +1,250 @@
import math
import torch
from torch.optim.optimizer import Optimizer, required


class RAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss


class PlainRAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                            N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(-step_size, exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss


class AdamW(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, warmup=warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['warmup'] > state['step']:
                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
                else:
                    scheduled_lr = group['lr']

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)

                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)

                p.data.copy_(p_data_fp32)

        return loss
Time-Travel-Rephotography/requirements.txt
ADDED
@@ -0,0 +1,25 @@
# Torch
#--find-links https://download.pytorch.org/whl/torch_stable.html
#torch==1.4.0+cu100
#torchvision==0.11.2+cu100
#torchaudio==0.10.1+cu100
#setuptools==59.5.0

Pillow
ninja
tqdm
opencv-python
scikit-image
numpy

tensorboard

# for face alignment
tensorflow
#keras
#bz2
dlib
scipy

matplotlib
pprintpp
Time-Travel-Rephotography/scripts/download_checkpoints.sh
ADDED
@@ -0,0 +1,14 @@
set -exo

mkdir -p checkpoint
gdown https://drive.google.com/uc?id=1hWc2JLM58_PkwfLG23Q5IH3Ysj2Mo1nr -O checkpoint/e4e_ffhq_encode.pt
gdown https://drive.google.com/uc?id=1hvAAql9Jo0wlmLBSHRIGrtXHcKQE-Whn -O checkpoint/stylegan2-ffhq-config-f.pt
gdown https://drive.google.com/uc?id=1mbGWbjivZxMGxZqyyOHbE310aOkYe2BR -O checkpoint/vgg_face_dag.pt
mkdir -p checkpoint/encoder
gdown https://drive.google.com/uc?id=1ha4WXsaIpZfMHsqNLvqOPlUXsgh9VawU -O checkpoint/encoder/checkpoint_b.pt
gdown https://drive.google.com/uc?id=1hfxDLujRIGU0G7pOdW9MMSBRzxZBmSKJ -O checkpoint/encoder/checkpoint_g.pt
gdown https://drive.google.com/uc?id=1htekHopgxaW-MIjs6pYy7pyIK0v7Q0iS -O checkpoint/encoder/checkpoint_gb.pt

pushd third_party/face_parsing
./scripts/download_checkpoints.sh
popd
Time-Travel-Rephotography/scripts/install.sh
ADDED
@@ -0,0 +1,6 @@
# conda create -n stylegan python=3.7
# conda activate stylegan
conda install -c conda-forge/label/gcc7 opencv --yes
conda install tensorflow-gpu=1.15 cudatoolkit=10.0 --yes
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch --yes
pip install -r requirements.txt
Time-Travel-Rephotography/scripts/run.sh
ADDED
@@ -0,0 +1,34 @@
set -x

# Example command
# ```
# ./scripts/run.sh b "dataset/Abraham Lincoln_01.png" 0.75
# ```

spectral_sensitivity="$1"
path="$2"
blur_radius="$3"


list="$(dirname "${path}")"
list="$(basename "${list}")"

if [ "${spectral_sensitivity}" == "b" ]; then
    FLAGS=(--spectral_sensitivity b --encoder_ckpt checkpoint/encoder/checkpoint_b.pt);
elif [ "${spectral_sensitivity}" == "gb" ]; then
    FLAGS=(--spectral_sensitivity "gb" --encoder_ckpt checkpoint/encoder/checkpoint_gb.pt);
else
    FLAGS=(--spectral_sensitivity "g" --encoder_ckpt checkpoint/encoder/checkpoint_g.pt);
fi

name="${path%.*}"
name="${name##*/}"
echo "${name}"

# TODO: I did l2 or cos for contextual
time python projector.py \
    "${path}" \
    --gaussian "${blur_radius}" \
    --log_dir "log/" \
    --results_dir "results/" \
    "${FLAGS[@]}"
Time-Travel-Rephotography/tools/__init__.py
ADDED
File without changes
Time-Travel-Rephotography/tools/data/__init__.py
ADDED
File without changes
Time-Travel-Rephotography/tools/data/align_images.py
ADDED
@@ -0,0 +1,117 @@
import argparse
import json
import os
from os.path import join as pjoin
import sys
import bz2
import numpy as np
import cv2
from tqdm import tqdm
from tensorflow.keras.utils import get_file
from utils.ffhq_dataset.face_alignment import image_align
from utils.ffhq_dataset.landmarks_detector import LandmarksDetector

LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'


def unpack_bz2(src_path):
    data = bz2.BZ2File(src_path).read()
    dst_path = src_path[:-4]
    with open(dst_path, 'wb') as fp:
        fp.write(data)
    return dst_path


class SizePathMap(dict):
    """{size: {aligned_face_path0, aligned_face_path1, ...}, ...}"""
    def add_item(self, size, path):
        if size not in self:
            self[size] = set()
        self[size].add(path)

    def get_sizes(self):
        sizes = []
        for key, paths in self.items():
            sizes.extend([key,] * len(paths))
        return sizes

    def serialize(self):
        result = {}
        for key, paths in self.items():
            result[key] = list(paths)
        return result


def main(args):
    landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
                                               LANDMARKS_MODEL_URL, cache_subdir='temp'))

    landmarks_detector = LandmarksDetector(landmarks_model_path)
    face_sizes = SizePathMap()
    raw_img_dir = args.raw_image_dir
    img_names = [n for n in os.listdir(raw_img_dir) if os.path.isfile(pjoin(raw_img_dir, n))]
    aligned_image_dir = args.aligned_image_dir
    os.makedirs(aligned_image_dir, exist_ok=True)
    pbar = tqdm(img_names)
    for img_name in pbar:
        pbar.set_description(img_name)
        if os.path.splitext(img_name)[-1] == '.txt':
            continue
        raw_img_path = os.path.join(raw_img_dir, img_name)
        try:
            for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
                face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
                aligned_face_path = os.path.join(aligned_image_dir, face_img_name)

                face_size = image_align(
                    raw_img_path, aligned_face_path, face_landmarks, resize=args.resize
                )
                face_sizes.add_item(face_size, aligned_face_path)
                pbar.set_description(f"{img_name}: {face_size}")

                if args.draw:
                    visual = LandmarksDetector.draw(cv2.imread(raw_img_path), face_landmarks)
                    cv2.imwrite(
                        pjoin(args.aligned_image_dir, os.path.splitext(face_img_name)[0] + "_landmarks.png"),
                        visual
                    )
        except Exception as e:
            print('[Error]', e, 'error happened when processing', raw_img_path)

    print(args.raw_image_dir, ':')
    sizes = face_sizes.get_sizes()
    results = {
        'mean_size': np.mean(sizes),
        'num_faces_detected': len(sizes),
        'num_images': len(img_names),
        'sizes': sizes,
        'size_path_dict': face_sizes.serialize(),
    }
    print('\t', results)
    if args.out_stats is not None:
        os.makedirs(os.path.dirname(args.out_stats), exist_ok=True)
        with open(args.out_stats, 'w') as f:
            json.dump(results, f)


def parse_args(args=None, namespace=None):
    parser = argparse.ArgumentParser(description="""
        Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
        python align_images.py /raw_images /aligned_images
        """
    )
    parser.add_argument('raw_image_dir')
    parser.add_argument('aligned_image_dir')
    parser.add_argument('--resize',
                        help="True if want to resize to 1024",
                        action='store_true')
    parser.add_argument('--draw',
                        help="True if want to visualize landmarks",
                        action='store_true')
    parser.add_argument('--out_stats',
                        help="output_fn for statistics of faces", default=None)
    return parser.parse_args(args=args, namespace=namespace)


if __name__ == "__main__":
    main(parse_args())
Time-Travel-Rephotography/tools/initialize.py
ADDED
@@ -0,0 +1,160 @@
from argparse import ArgumentParser, Namespace
from typing import (
    List,
    Tuple,
)

import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Grayscale,
    Resize,
    ToTensor,
)

from models.encoder import Encoder
from models.encoder4editing import (
    get_latents as get_e4e_latents,
    setup_model as setup_e4e_model,
)
from utils.misc import (
    optional_string,
    iterable_to_str,
    stem,
)


class ColorEncoderArguments:
    def __init__(self):
        parser = ArgumentParser("Encode an image via a feed-forward encoder")

        self.add_arguments(parser)

        self.parser = parser

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument("--encoder_ckpt", default=None,
                            help="encoder checkpoint path. initialize w with encoder output if specified")
        parser.add_argument("--encoder_size", type=int, default=256,
                            help="Resize to this size to pass as input to the encoder")


class InitializerArguments:
    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        ColorEncoderArguments.add_arguments(parser)
        cls.add_e4e_arguments(parser)
        parser.add_argument("--mix_layer_range", default=[10, 18], type=int, nargs=2,
                            help="replace layers <start> to <end> in the e4e code by the color code")

        parser.add_argument("--init_latent", default=None, help="path to init wp")

    @staticmethod
    def to_string(args: Namespace):
        return (f"init{stem(args.init_latent).lstrip('0')[:10]}" if args.init_latent
                else f"init({iterable_to_str(args.mix_layer_range)})")
        #+ optional_string(args.init_noise > 0, f"-initN{args.init_noise}")

    @staticmethod
    def add_e4e_arguments(parser: ArgumentParser):
        parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
                            help="e4e checkpoint path.")
        parser.add_argument("--e4e_size", type=int, default=256,
                            help="Resize to this size to pass as input to the e4e")


def create_color_encoder(args: Namespace):
    encoder = Encoder(1, args.encoder_size, 512)
    ckpt = torch.load(args.encoder_ckpt)
    encoder.load_state_dict(ckpt["model"])
    return encoder


def transform_input(img: Image):
    tsfm = Compose([
        Grayscale(),
        Resize(args.encoder_size),
        ToTensor(),
    ])
    return tsfm(img)


def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
    assert args.encoder_size is not None

    imgs = Resize(args.encoder_size)(imgs)

    color_encoder = create_color_encoder(args).to(imgs.device)
    color_encoder.eval()
    with torch.no_grad():
        latent = color_encoder(imgs)
    return latent.detach()


def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
    return F.interpolate(imgs, size=size, mode='bilinear')


class Initializer(nn.Module):
    def __init__(self, args: Namespace):
        super().__init__()

        self.path = None
        if args.init_latent is not None:
            self.path = args.init_latent
            return

        assert args.encoder_size is not None
        self.color_encoder = create_color_encoder(args)
        self.color_encoder.eval()
        self.color_encoder_size = args.encoder_size

        self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
        assert 'cars_' not in e4e_opts.dataset_type
        self.e4e.decoder.eval()
        self.e4e.eval()
        self.e4e_size = args.e4e_size

        self.mix_layer_range = args.mix_layer_range

    def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
        """
        Get the color W code
        """
        imgs = resize(imgs, self.color_encoder_size)

        latent = self.color_encoder(imgs)

        return latent

    def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
        imgs = resize(imgs, self.e4e_size)
        imgs = (imgs - 0.5) / 0.5
        if imgs.shape[1] == 1:  # 1 channel
            imgs = imgs.repeat(1, 3, 1, 1)
        return get_e4e_latents(self.e4e, imgs)

    def load(self, device: torch.device):
        latent_np = np.load(self.path)
        return torch.tensor(latent_np, device=device)[None, ...]

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        if self.path is not None:
            return self.load(imgs.device)

        shape_code = self.encode_shape(imgs)
        color_code = self.encode_color(imgs)

        # style mix
        latent = shape_code
        start, end = self.mix_layer_range
        latent[:, start:end] = color_code
        return latent
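A minimal sketch of how the Initializer above could be driven (illustrative only; the checkpoint paths assume scripts/download_checkpoints.sh has been run, and the input tensor stands in for a grayscale image batch in [0, 1]).

    # sketch: style-mix an e4e shape code with the color-encoder code
    import torch
    from argparse import Namespace
    from tools.initialize import Initializer

    args = Namespace(init_latent=None,
                     encoder_ckpt="checkpoint/encoder/checkpoint_b.pt", encoder_size=256,
                     e4e_ckpt="checkpoint/e4e_ffhq_encode.pt", e4e_size=256,
                     mix_layer_range=[10, 18])
    init = Initializer(args).cuda()
    with torch.no_grad():
        latent = init(torch.rand(1, 1, 1024, 1024, device="cuda"))  # W+ code, typically (1, 18, 512) for FFHQ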
Time-Travel-Rephotography/tools/match_histogram.py
ADDED
@@ -0,0 +1,167 @@
from argparse import (
    ArgumentParser,
    Namespace,
)
import os
from os.path import join as pjoin
from typing import Optional
import sys

import numpy as np
import cv2
from skimage import exposure


# sys.path.append('Face_Detection')
# from align_warp_back_multiple_dlib import match_histograms


def calculate_cdf(histogram):
    """
    This method calculates the cumulative distribution function
    :param array histogram: The values of the histogram
    :return: normalized_cdf: The normalized cumulative distribution function
    :rtype: array
    """
    # Get the cumulative sum of the elements
    cdf = histogram.cumsum()

    # Normalize the cdf
    normalized_cdf = cdf / float(cdf.max())

    return normalized_cdf


def calculate_lookup(src_cdf, ref_cdf):
    """
    This method creates the lookup table
    :param array src_cdf: The cdf for the source image
    :param array ref_cdf: The cdf for the reference image
    :return: lookup_table: The lookup table
    :rtype: array
    """
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val in range(len(src_cdf)):
        for ref_pixel_val in range(len(ref_cdf)):
            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
                lookup_val = ref_pixel_val
                break
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table


def match_histograms(src_image, ref_image, src_mask=None, ref_mask=None):
    """
    This method matches the source image histogram to the
    reference signal
    :param image src_image: The original source image
    :param image ref_image: The reference image
    :return: image_after_matching
    :rtype: image (array)
    """
    # Split the images into the different color channels
    # b means blue, g means green and r means red
    src_b, src_g, src_r = cv2.split(src_image)
    ref_b, ref_g, ref_r = cv2.split(ref_image)

    def rv(im):
        if ref_mask is None:
            return im.flatten()
        return im[ref_mask]

    def sv(im):
        if src_mask is None:
            return im.flatten()
        return im[src_mask]

    # Compute the b, g, and r histograms separately
    # The flatten() Numpy method returns a copy of the array c
    # collapsed into one dimension.
    src_hist_blue, bin_0 = np.histogram(sv(src_b), 256, [0, 256])
    src_hist_green, bin_1 = np.histogram(sv(src_g), 256, [0, 256])
    src_hist_red, bin_2 = np.histogram(sv(src_r), 256, [0, 256])
    ref_hist_blue, bin_3 = np.histogram(rv(ref_b), 256, [0, 256])
    ref_hist_green, bin_4 = np.histogram(rv(ref_g), 256, [0, 256])
    ref_hist_red, bin_5 = np.histogram(rv(ref_r), 256, [0, 256])

    # Compute the normalized cdf for the source and reference image
    src_cdf_blue = calculate_cdf(src_hist_blue)
    src_cdf_green = calculate_cdf(src_hist_green)
    src_cdf_red = calculate_cdf(src_hist_red)
    ref_cdf_blue = calculate_cdf(ref_hist_blue)
    ref_cdf_green = calculate_cdf(ref_hist_green)
    ref_cdf_red = calculate_cdf(ref_hist_red)

    # Make a separate lookup table for each color
    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)

    # Use the lookup function to transform the colors of the original
    # source image
    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
    green_after_transform = cv2.LUT(src_g, green_lookup_table)
    red_after_transform = cv2.LUT(src_r, red_lookup_table)

    # Put the image back together
    image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
    image_after_matching = cv2.convertScaleAbs(image_after_matching)

    return image_after_matching


def convert_to_BW(im, mode):
    if mode == "b":
        gray = im[..., 0]
    elif mode == "gb":
        gray = (im[..., 0].astype(float) + im[..., 1]) / 2.0
    else:
        gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gray = gray.astype(np.uint8)

    return np.stack([gray] * 3, axis=-1)


def parse_args(args=None, namespace: Optional[Namespace] = None):
    parser = ArgumentParser('match histogram of src to ref')
    parser.add_argument('src')
    parser.add_argument('ref')
    parser.add_argument('--out', default=None, help="converted src that matches ref")
    parser.add_argument('--src_mask', default=None, help="mask on which to match the histogram")
    parser.add_argument('--ref_mask', default=None, help="mask on which to match the histogram")
    parser.add_argument('--spectral_sensitivity', choices=['b', 'gb', 'g'], help="match the histogram of corresponding sensitive channel(s)")
    parser.add_argument('--crop', type=int, default=0, help="crop the boundary to match")
    return parser.parse_args(args=args, namespace=namespace)


def main(args):
    A = cv2.imread(args.ref)
    A = convert_to_BW(A, args.spectral_sensitivity)
    B = cv2.imread(args.src, 0)
    B = np.stack((B,) * 3, axis=-1)

    mask_A = cv2.resize(cv2.imread(args.ref_mask, 0), A.shape[:2][::-1],
                        interpolation=cv2.INTER_NEAREST) > 0 if args.ref_mask else None
    mask_B = cv2.resize(cv2.imread(args.src_mask, 0), B.shape[:2][::-1],
                        interpolation=cv2.INTER_NEAREST) > 0 if args.src_mask else None

    if args.crop > 0:
        c = args.crop
        bc = int(c / A.shape[0] * B.shape[0] + 0.5)
        A = A[c:-c, c:-c]
        B = B[bc:-bc, bc:-bc]

    B = match_histograms(B, A, src_mask=mask_B, ref_mask=mask_A)
    # B = exposure.match_histograms(B, A, multichannel=True)

    if args.out:
        os.makedirs(os.path.dirname(args.out), exist_ok=True)
        cv2.imwrite(args.out, B)

    return B


if __name__ == "__main__":
    main(parse_args())
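A minimal usage sketch of the per-channel match_histograms above (illustrative only; the file names are placeholders, and the optional masks are boolean arrays of the same height and width as each image).

    # sketch: match a grayscale input's histogram to a reference sibling
    import cv2
    from tools.match_histogram import match_histograms

    src = cv2.imread("input_bw.png")   # HxWx3, uint8
    ref = cv2.imread("sibling.png")    # HxWx3, uint8
    matched = match_histograms(src, ref)  # or pass src_mask=/ref_mask= to restrict the statistics to skin regions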
Time-Travel-Rephotography/tools/match_skin_histogram.py
ADDED
@@ -0,0 +1,67 @@
from argparse import Namespace
import os
from os.path import join as pjoin
from typing import Optional

import cv2
import torch

from tools import (
    parse_face,
    match_histogram,
)
from utils.torch_helpers import make_image
from utils.misc import stem


def match_skin_histogram(
    imgs: torch.Tensor,
    sibling_img: torch.Tensor,
    spectral_sensitivity,
    im_sibling_dir: str,
    mask_dir: str,
    matched_hist_fn: Optional[str] = None,
    normalize=None,  # normalize the range of the tensor
):
    """
    Extract the skin of the input and sibling images. Create a new input image by matching
    its histogram to the sibling.
    """
    # TODO: Currently only allows imgs of batch size 1
    im_sibling_dir = os.path.abspath(im_sibling_dir)
    mask_dir = os.path.abspath(mask_dir)

    img_np = make_image(imgs)[0]
    sibling_np = make_image(sibling_img)[0][..., ::-1]

    # save img, sibling
    os.makedirs(im_sibling_dir, exist_ok=True)
    im_name, sibling_name = 'input.png', 'sibling.png'
    cv2.imwrite(pjoin(im_sibling_dir, im_name), img_np)
    cv2.imwrite(pjoin(im_sibling_dir, sibling_name), sibling_np)

    # face parsing
    parse_face.main(
        Namespace(in_dir=im_sibling_dir, out_dir=mask_dir, include_hair=False)
    )

    # match_histogram
    mh_args = match_histogram.parse_args(
        args=[
            pjoin(im_sibling_dir, im_name),
            pjoin(im_sibling_dir, sibling_name),
        ],
        namespace=Namespace(
            out=matched_hist_fn if matched_hist_fn else pjoin(im_sibling_dir, "match_histogram.png"),
            src_mask=pjoin(mask_dir, im_name),
            ref_mask=pjoin(mask_dir, sibling_name),
            spectral_sensitivity=spectral_sensitivity,
        )
    )
    matched_np = match_histogram.main(mh_args) / 255.0  # [0, 1]
    matched = torch.FloatTensor(matched_np).permute(2, 0, 1)[None, ...]  # BCHW

    if normalize is not None:
        matched = normalize(matched)

    return matched
Time-Travel-Rephotography/tools/parse_face.py
ADDED
@@ -0,0 +1,55 @@
from argparse import ArgumentParser
import os
from os.path import join as pjoin
from subprocess import run

import numpy as np
import cv2
from tqdm import tqdm


def create_skin_mask(anno_dir, mask_dir, skin_thresh=13, include_hair=False):
    names = os.listdir(anno_dir)
    names = [n for n in names if n.endswith('.png')]
    os.makedirs(mask_dir, exist_ok=True)
    for name in tqdm(names):
        anno = cv2.imread(pjoin(anno_dir, name), 0)
        mask = np.logical_and(0 < anno, anno <= skin_thresh)
        if include_hair:
            mask |= anno == 17
        cv2.imwrite(pjoin(mask_dir, name), mask * 255)


def main(args):
    FACE_PARSING_DIR = 'third_party/face_parsing'

    main_env = os.getcwd()
    os.chdir(FACE_PARSING_DIR)
    tmp_parse_dir = pjoin(args.out_dir, 'face_parsing')
    cmd = [
        'python',
        'test.py',
        args.in_dir,
        tmp_parse_dir,
    ]
    print(' '.join(cmd))
    run(cmd)

    create_skin_mask(tmp_parse_dir, args.out_dir, include_hair=args.include_hair)

    os.chdir(main_env)


def parse_args(args=None, namespace=None):
    parser = ArgumentParser("Face Parsing and generate skin (& hair) mask")
    parser.add_argument('in_dir')
    parser.add_argument('out_dir')
    parser.add_argument('--include_hair', action="store_true", help="include hair in the mask")
    return parser.parse_args(args=args, namespace=namespace)


if __name__ == "__main__":
    main(parse_args())
Time-Travel-Rephotography/utils/__init__.py
ADDED
File without changes
Time-Travel-Rephotography/utils/ffhq_dataset/__init__.py
ADDED
File without changes
Time-Travel-Rephotography/utils/ffhq_dataset/face_alignment.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import scipy.ndimage
import os
import PIL.Image


def image_align(src_file, dst_file, face_landmarks, resize=True, output_size=1024, transform_size=4096, enable_padding=True):
    # Align function from FFHQ dataset pre-processing step
    # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

    lm = np.array(face_landmarks)
    lm_chin = lm[0 : 17]  # left-right
    lm_eyebrow_left = lm[17 : 22]  # left-right
    lm_eyebrow_right = lm[22 : 27]  # left-right
    lm_nose = lm[27 : 31]  # top-down
    lm_nostrils = lm[31 : 36]  # top-down
    lm_eye_left = lm[36 : 42]  # left-clockwise
    lm_eye_right = lm[42 : 48]  # left-clockwise
    lm_mouth_outer = lm[48 : 60]  # left-clockwise
    lm_mouth_inner = lm[60 : 68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)
    eye_right = np.mean(lm_eye_right, axis=0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_left = lm_mouth_outer[0]
    mouth_right = lm_mouth_outer[6]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # Load in-the-wild image.
    if not os.path.isfile(src_file):
        print('\nCannot find source image. Please run "--wilds" before "--align".')
        return
    #img = cv2.imread(src_file)
    #img = PIL.Image.fromarray(img)
    img = PIL.Image.open(src_file)

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, PIL.Image.ANTIALIAS)
        quad /= shrink
        qsize /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]

    # Pad.
    pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        img = np.float32(img)
        if img.ndim == 2:
            img = np.stack((img,)*3, axis=-1)
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    xmin, xmax = np.amin(quad[:,0]), np.amax(quad[:,0])
    ymin, ymax = np.amin(quad[:,1]), np.amax(quad[:,1])
    quad_size = int(max(xmax-xmin, ymax-ymin)+0.5)

    if not resize:
        transform_size = output_size = quad_size

    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

    # Save aligned image.
    os.makedirs(os.path.dirname(dst_file), exist_ok=True)
    img.save(dst_file, 'PNG')
    return quad_size
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import dlib
import cv2


class LandmarksDetector:
    def __init__(self, predictor_model_path):
        """
        :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
        """
        self.detector = dlib.get_frontal_face_detector()  # cnn_face_detection_model_v1 can also be used
        self.shape_predictor = dlib.shape_predictor(predictor_model_path)

    def get_landmarks(self, image):
        img = dlib.load_rgb_image(image)
        dets = self.detector(img, 1)
        #print('face bounding boxes', dets)

        for detection in dets:
            face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
            #print('face landmarks', face_landmarks)
            yield face_landmarks

    @staticmethod
    def draw(img, landmarks):
        for (x, y) in landmarks:
            cv2.circle(img, (x, y), 1, (0, 0, 255), -1)
        return img


class DNNLandmarksDetector:
    def __init__(self, predictor_model_path, DNN='TF'):
        """
        :param
            DNN: "TF" or "CAFFE"
            predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
        """
        # keep the face-detection network on the instance so detect_faces can reuse it
        if DNN == "CAFFE":
            modelFile = "res10_300x300_ssd_iter_140000_fp16.caffemodel"
            configFile = "deploy.prototxt"
            self.net = cv2.dnn.readNetFromCaffe(configFile, modelFile)
        else:
            modelFile = "opencv_face_detector_uint8.pb"
            configFile = "opencv_face_detector.pbtxt"
            self.net = cv2.dnn.readNetFromTensorflow(modelFile, configFile)

        self.shape_predictor = dlib.shape_predictor(predictor_model_path)

    def detect_faces(self, image, conf_threshold=0):
        H, W = image.shape[:2]
        blob = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123], False, False)
        self.net.setInput(blob)
        detections = self.net.forward()
        bboxes = []
        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence > conf_threshold:
                x1 = int(detections[0, 0, i, 3] * W)
                y1 = int(detections[0, 0, i, 4] * H)
                x2 = int(detections[0, 0, i, 5] * W)
                y2 = int(detections[0, 0, i, 6] * H)
                bboxes.append(dlib.rectangle(x1, y1, x2, y2))
        return bboxes

    def get_landmarks(self, image):
        img = cv2.imread(image)
        dets = self.detect_faces(img, 0)
        print('face bounding boxes', dets)

        for detection in dets:
            face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
            print('face landmarks', face_landmarks)
            yield face_landmarks
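A minimal usage sketch for LandmarksDetector; dlib's 68-point predictor file must be available locally, and the image path below is a placeholder:

detector = LandmarksDetector("shape_predictor_68_face_landmarks.dat")
for i, landmarks in enumerate(detector.get_landmarks("inputs/old_portrait.png"), start=1):
    print(f"face {i}: {len(landmarks)} landmarks")  # 68 (x, y) tuples per detected face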
Time-Travel-Rephotography/utils/misc.py
ADDED
@@ -0,0 +1,18 @@
import os
from typing import Iterable


def optional_string(condition: bool, string: str):
    return string if condition else ""


def parent_dir(path: str) -> str:
    return os.path.basename(os.path.dirname(path))


def stem(path: str) -> str:
    return os.path.splitext(os.path.basename(path))[0]


def iterable_to_str(iterable: Iterable) -> str:
    return ','.join([str(x) for x in iterable])
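These helpers are mainly used to build result-directory and experiment names elsewhere in the repo; the paths below are illustrative:

stem("data/abraham_lincoln.png")         # -> "abraham_lincoln"
parent_dir("results/projector/run.png")  # -> "projector"
iterable_to_str([250, 750])              # -> "250,750"
optional_string(True, "-b")              # -> "-b" (empty string when the condition is falsy)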
Time-Travel-Rephotography/utils/optimize.py
ADDED
@@ -0,0 +1,230 @@
import math
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms import Resize

#from optim import get_optimizer_class, OPTIMIZER_MAP
from losses.regularize_noise import NoiseRegularizer
from optim import RAdam
from utils.misc import (
    iterable_to_str,
    optional_string,
)


class OptimizerArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument('--coarse_min', type=int, default=32)
        parser.add_argument('--wplus_step', type=int, nargs="+", default=[250, 750], help="#step for optimizing w_plus")
        #parser.add_argument('--lr_rampup', type=float, default=0.05)
        #parser.add_argument('--lr_rampdown', type=float, default=0.25)
        parser.add_argument('--lr', type=float, default=0.1)
        parser.add_argument('--noise_strength', type=float, default=.0)
        parser.add_argument('--noise_ramp', type=float, default=0.75)
        #parser.add_argument('--optimize_noise', action="store_true")
        parser.add_argument('--camera_lr', type=float, default=0.01)

        parser.add_argument("--log_dir", default="log/projector", help="tensorboard log directory")
        parser.add_argument("--log_freq", type=int, default=10, help="log frequency")
        parser.add_argument("--log_visual_freq", type=int, default=50, help="visual log frequency")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"lr{args.lr}_{args.camera_lr}-c{args.coarse_min}"
            + f"-wp({iterable_to_str(args.wplus_step)})"
            + optional_string(args.noise_strength, f"-n{args.noise_strength}")
        )


class LatentNoiser(nn.Module):
    def __init__(
        self, generator: nn.Module,
        noise_ramp: float = 0.75, noise_strength: float = 0.05,
        n_mean_latent: int = 10000
    ):
        super().__init__()

        self.noise_ramp = noise_ramp
        self.noise_strength = noise_strength

        with torch.no_grad():
            # TODO: get 512 from generator
            noise_sample = torch.randn(n_mean_latent, 512, device=generator.device)
            latent_out = generator.style(noise_sample)

            latent_mean = latent_out.mean(0)
            self.latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    def forward(self, latent: torch.Tensor, t: float) -> torch.Tensor:
        strength = self.latent_std * self.noise_strength * max(0, 1 - t / self.noise_ramp) ** 2
        noise = torch.randn_like(latent) * strength
        return latent + noise


class Optimizer:
    @classmethod
    def optimize(
        cls,
        generator: nn.Module,
        criterion: nn.Module,
        degrade: nn.Module,
        target: torch.Tensor,  # only used in writer since it's mostly baked in criterion
        latent_init: torch.Tensor,
        noise_init: torch.Tensor,
        args: Namespace,
        writer: Optional[SummaryWriter] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # do not optimize generator
        generator = generator.eval()
        target = target.detach()
        # prepare parameters
        noises = []
        for n in noise_init:
            noise = n.detach().clone()
            noise.requires_grad = True
            noises.append(noise)

        def create_parameters(latent_coarse):
            parameters = [
                {'params': [latent_coarse], 'lr': args.lr},
                {'params': noises, 'lr': args.lr},
                {'params': degrade.parameters(), 'lr': args.camera_lr},
            ]
            return parameters

        device = target.device

        # start optimize
        total_steps = np.sum(args.wplus_step)
        max_coarse_size = (2 ** (len(args.wplus_step) - 1)) * args.coarse_min
        noiser = LatentNoiser(generator, noise_ramp=args.noise_ramp, noise_strength=args.noise_strength).to(device)
        latent = latent_init.detach().clone()
        for coarse_level, steps in enumerate(args.wplus_step):
            if criterion.weights["contextual"] > 0:
                with torch.no_grad():
                    # synthesize new sibling image using the current optimization results
                    # FIXME: update rgbs sibling
                    sibling, _, _ = generator([latent], input_is_latent=True, randomize_noise=True)
                    criterion.update_sibling(sibling)

            coarse_size = (2 ** coarse_level) * args.coarse_min
            latent_coarse, latent_fine = cls.split_latent(
                latent, generator.get_latent_size(coarse_size))
            parameters = create_parameters(latent_coarse)
            optimizer = RAdam(parameters)

            print(f"Optimizing {coarse_size}x{coarse_size}")
            pbar = tqdm(range(steps))
            for si in pbar:
                latent = torch.cat((latent_coarse, latent_fine), dim=1)
                niters = si + np.sum(args.wplus_step[:coarse_level])
                latent_noisy = noiser(latent, niters / total_steps)
                img_gen, _, rgbs = generator([latent_noisy], input_is_latent=True, noise=noises)
                # TODO: use coarse_size instead of args.coarse_size for rgb_level
                loss, losses = criterion(img_gen, degrade=degrade, noises=noises, rgbs=rgbs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                NoiseRegularizer.normalize(noises)

                # log
                pbar.set_description("; ".join([f"{k}: {v.item(): .3e}" for k, v in losses.items()]))

                if writer is not None and niters % args.log_freq == 0:
                    cls.log_losses(writer, niters, loss, losses, criterion.weights)
                    cls.log_parameters(writer, niters, degrade.named_parameters())
                if writer is not None and niters % args.log_visual_freq == 0:
                    cls.log_visuals(writer, niters, img_gen, target, degraded=degrade(img_gen), rgbs=rgbs)

            latent = torch.cat((latent_coarse, latent_fine), dim=1).detach()

        return latent, noises

    @staticmethod
    def split_latent(latent: torch.Tensor, coarse_latent_size: int):
        latent_coarse = latent[:, :coarse_latent_size]
        latent_coarse.requires_grad = True
        latent_fine = latent[:, coarse_latent_size:]
        latent_fine.requires_grad = False
        return latent_coarse, latent_fine

    @staticmethod
    def log_losses(
        writer: SummaryWriter,
        niters: int,
        loss_total: torch.Tensor,
        losses: Dict[str, torch.Tensor],
        weights: Optional[Dict[str, torch.Tensor]] = None
    ):
        writer.add_scalar("loss", loss_total.item(), niters)

        for name, loss in losses.items():
            writer.add_scalar(name, loss.item(), niters)
            if weights is not None:
                writer.add_scalar(f"weighted_{name}", weights[name] * loss.item(), niters)

    @staticmethod
    def log_parameters(
        writer: SummaryWriter,
        niters: int,
        named_parameters: Iterable[Tuple[str, torch.nn.Parameter]],
    ):
        for name, para in named_parameters:
            writer.add_scalar(name, para.item(), niters)

    @classmethod
    def log_visuals(
        cls,
        writer: SummaryWriter,
        niters: int,
        img: torch.Tensor,
        target: torch.Tensor,
        degraded=None,
        rgbs=None,
    ):
        if target.shape[-1] != img.shape[-1]:
            visual = make_grid(img, nrow=1, normalize=True, range=(-1, 1))
            writer.add_image("pred", visual, niters)

        def resize(img):
            return F.interpolate(img, size=target.shape[2:], mode="area")

        vis = resize(img)
        if degraded is not None:
            vis = torch.cat((resize(degraded), vis), dim=-1)
        visual = make_grid(torch.cat((target.repeat(1, vis.shape[1] // target.shape[1], 1, 1), vis), dim=-1), nrow=1, normalize=True, range=(-1, 1))
        writer.add_image("gnd[-degraded]-pred", visual, niters)

        # log to_rgbs
        if rgbs is not None:
            cls.log_torgbs(writer, niters, rgbs)

    @staticmethod
    def log_torgbs(writer: SummaryWriter, niters: int, rgbs: Iterable[torch.Tensor], prefix: str = ""):
        for ri, rgb in enumerate(rgbs):
            scale = 2 ** (-(len(rgbs) - ri))
            visual = make_grid(torch.cat((rgb, rgb / scale), dim=-1), nrow=1, normalize=True, range=(-1, 1))
            writer.add_image(f"{prefix}to_rgb_{2 ** (ri + 2)}", visual, niters)
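LatentNoiser perturbs the latent with noise whose magnitude decays quadratically over the first noise_ramp fraction of the optimization and is zero afterwards. A small self-contained sketch of just that schedule (assuming a latent_std of 1.0 for illustration):

def noise_scale(t, noise_strength=0.05, noise_ramp=0.75, latent_std=1.0):
    # t is the fraction of total optimization steps completed, in [0, 1]
    return latent_std * noise_strength * max(0, 1 - t / noise_ramp) ** 2

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(t, noise_scale(t))  # decays from 0.05 at t=0 to 0 at t=0.75 and beyond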
Time-Travel-Rephotography/utils/projector_arguments.py
ADDED
@@ -0,0 +1,76 @@
import os
from argparse import (
    ArgumentParser,
    Namespace,
)

from models.degrade import DegradeArguments
from tools.initialize import InitializerArguments
from losses.joint_loss import LossArguments
from utils.optimize import OptimizerArguments
from .misc import (
    optional_string,
    iterable_to_str,
)


class ProjectorArguments:
    def __init__(self):
        parser = ArgumentParser("Project image into stylegan2")
        self.add_arguments(parser)
        self.parser = parser

    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        parser.add_argument('--rand_seed', type=int, default=None,
                            help="random seed")
        cls.add_io_args(parser)
        cls.add_preprocess_args(parser)
        cls.add_stylegan_args(parser)

        InitializerArguments.add_arguments(parser)
        LossArguments.add_arguments(parser)
        OptimizerArguments.add_arguments(parser)
        DegradeArguments.add_arguments(parser)

    @staticmethod
    def add_stylegan_args(parser: ArgumentParser):
        parser.add_argument('--ckpt', type=str, default="checkpoint/stylegan2-ffhq-config-f.pt",
                            help="stylegan2 checkpoint")
        parser.add_argument('--generator_size', type=int, default=1024,
                            help="output size of the generator")

    @staticmethod
    def add_io_args(parser: ArgumentParser) -> ArgumentParser:
        parser.add_argument('input', type=str, help="input image path")
        parser.add_argument('--results_dir', default="results/projector", help="directory to save results.")

    @staticmethod
    def add_preprocess_args(parser: ArgumentParser):
        # parser.add_argument("--match_histogram", action='store_true', help="match the histogram of the input image to the sibling")
        pass

    def parse(self, args=None, namespace=None) -> Namespace:
        args = self.parser.parse_args(args, namespace=namespace)
        self.print(args)
        return args

    @staticmethod
    def print(args: Namespace):
        print("------------ Parameters -------------")
        args = vars(args)
        for k, v in sorted(args.items()):
            print(f"{k}: {v}")
        print("-------------------------------------")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return "-".join([
            #+ optional_string(args.no_camera_response, "-noCR")
            #+ optional_string(args.match_histogram, "-MH")
            DegradeArguments.to_string(args),
            InitializerArguments.to_string(args),
            LossArguments.to_string(args),
            OptimizerArguments.to_string(args),
        ]) + optional_string(args.rand_seed is not None, f"-S{args.rand_seed}")
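A hypothetical invocation sketch (the image path is a placeholder, and it assumes the imported argument groups register the options their to_string methods expect):

args = ProjectorArguments().parse(["inputs/abraham_lincoln.png", "--lr", "0.05"])
print(args.ckpt)                           # checkpoint/stylegan2-ffhq-config-f.pt by default
print(ProjectorArguments.to_string(args))  # encodes the degrade/init/loss/optimizer settings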
Time-Travel-Rephotography/utils/torch_helpers.py
ADDED
@@ -0,0 +1,36 @@
import torch
from torch import nn


def device(gpu_id=0):
    if torch.cuda.is_available():
        return torch.device(f"cuda:{gpu_id}")
    return torch.device("cpu")


def load_matching_state_dict(model: nn.Module, state_dict):
    model_dict = model.state_dict()
    filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}
    # strict=False: keys missing from the filtered dict keep the model's current values
    model.load_state_dict(filtered_dict, strict=False)


def resize(t: torch.Tensor, size: int) -> torch.Tensor:
    # block-average an NCHW tensor down to (size, size)
    B, C, H, W = t.shape
    t = t.reshape(B, C, size, H // size, size, W // size)
    return t.mean([3, 5])


def make_image(tensor):
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to('cpu')
        .numpy()
    )
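make_image converts generator output in [-1, 1] (NCHW) into uint8 HWC numpy arrays. A quick illustration with a stand-in tensor in place of real generator output:

fake = torch.randn(1, 3, 256, 256).clamp(-1, 1)  # NCHW, values in [-1, 1]
img = make_image(fake)[0]   # (256, 256, 3) uint8 numpy array, ready for PIL/imageio
small = resize(fake, 64)    # (1, 3, 64, 64) block-averaged tensor, still in [-1, 1]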