Browse files
- .gitignore +133 -0
- LICENSE +21 -0
- LICENSE-FID +201 -0
- LICENSE-LPIPS +24 -0
- LICENSE-NVIDIA +101 -0
- apply_factor.py +94 -0
- calc_inception.py +130 -0
- checkpoint/.gitignore +1 -0
- closed_form_factorization.py +33 -0
- convert_weight.py +301 -0
- dataset.py +40 -0
- distributed.py +126 -0
- fid.py +129 -0
- generate.py +84 -0
- inception.py +310 -0
- lpips/__init__.py +160 -0
- lpips/base_model.py +58 -0
- lpips/dist_model.py +284 -0
- lpips/networks_basic.py +187 -0
- lpips/pretrained_networks.py +181 -0
- lpips/weights/v0.0/alex.pth +3 -0
- lpips/weights/v0.0/squeeze.pth +3 -0
- lpips/weights/v0.0/vgg.pth +3 -0
- lpips/weights/v0.1/alex.pth +3 -0
- lpips/weights/v0.1/squeeze.pth +3 -0
- lpips/weights/v0.1/vgg.pth +3 -0
- model.py +698 -0
- non_leaking.py +465 -0
- op/__init__.py +2 -0
- op/conv2d_gradfix.py +227 -0
- op/fused_act.py +127 -0
- op/fused_bias_act.cpp +32 -0
- op/fused_bias_act_kernel.cu +105 -0
- op/upfirdn2d.cpp +31 -0
- op/upfirdn2d.py +209 -0
- op/upfirdn2d_kernel.cu +369 -0
- ppl.py +130 -0
- prepare_data.py +101 -0
- projector.py +248 -0
- sample/.gitignore +1 -0
- swagan.py +440 -0
- train.py +531 -0
.gitignore
ADDED
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

wandb/
*.lmdb/
*.pkl
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Kim Seonghyeon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSE-FID
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
LICENSE-LPIPS
ADDED
@@ -0,0 +1,24 @@
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
LICENSE-NVIDIA
ADDED
@@ -0,0 +1,101 @@
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.


Nvidia Source Code License-NC

=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

"Nvidia Processors" means any central processing unit (CPU), graphics
processing unit (GPU), field-programmable gate array (FPGA),
application-specific integrated circuit (ASIC) or any combination
thereof designed, made, sold, or provided by Nvidia or its affiliates.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. The Work or
derivative works thereof may be used or intended for use by Nvidia
or its affiliates commercially or non-commercially. As used herein,
"non-commercially" means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grants in Sections 2.1 and 2.2) will
terminate immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grants in Sections 2.1 and
2.2) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================
apply_factor.py
ADDED
@@ -0,0 +1,94 @@
import argparse

import torch
from torchvision import utils

from model import Generator


if __name__ == "__main__":
    torch.set_grad_enabled(False)

    parser = argparse.ArgumentParser(description="Apply closed form factorization")

    parser.add_argument(
        "-i", "--index", type=int, default=0, help="index of eigenvector"
    )
    parser.add_argument(
        "-d",
        "--degree",
        type=float,
        default=5,
        help="scalar factors for moving latent vectors along eigenvector",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help='channel multiplier factor. config-f = 2, else = 1',
    )
    parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
    parser.add_argument(
        "--size", type=int, default=256, help="output image size of the generator"
    )
    parser.add_argument(
        "-n", "--n_sample", type=int, default=7, help="number of samples created"
    )
    parser.add_argument(
        "--truncation", type=float, default=0.7, help="truncation factor"
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="device to run the model"
    )
    parser.add_argument(
        "--out_prefix",
        type=str,
        default="factor",
        help="filename prefix to result samples",
    )
    parser.add_argument(
        "factor",
        type=str,
        help="name of the closed form factorization result factor file",
    )

    args = parser.parse_args()

    eigvec = torch.load(args.factor)["eigvec"].to(args.device)
    ckpt = torch.load(args.ckpt)
    g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)
    g.load_state_dict(ckpt["g_ema"], strict=False)

    trunc = g.mean_latent(4096)

    latent = torch.randn(args.n_sample, 512, device=args.device)
    latent = g.get_latent(latent)

    direction = args.degree * eigvec[:, args.index].unsqueeze(0)

    img, _ = g(
        [latent],
        truncation=args.truncation,
        truncation_latent=trunc,
        input_is_latent=True,
    )
    img1, _ = g(
        [latent + direction],
        truncation=args.truncation,
        truncation_latent=trunc,
        input_is_latent=True,
    )
    img2, _ = g(
        [latent - direction],
        truncation=args.truncation,
        truncation_latent=trunc,
        input_is_latent=True,
    )

    grid = utils.save_image(
        torch.cat([img1, img, img2], 0),
        f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
        normalize=True,
        range=(-1, 1),
        nrow=args.n_sample,
    )
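Note: apply_factor.py consumes the factor file written by closed_form_factorization.py (listed further below). A minimal sketch of the direction it builds, assuming a factor file named factor.pt:

import torch

eigvec = torch.load("factor.pt")["eigvec"]           # columns are latent-space directions
index, degree = 0, 5.0                               # the script's -i / -d defaults
direction = degree * eigvec[:, index].unsqueeze(0)   # added to / subtracted from the W-space latents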
calc_inception.py
ADDED
@@ -0,0 +1,130 @@
import argparse
import pickle
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import inception_v3, Inception3
import numpy as np
from tqdm import tqdm

from inception import InceptionV3
from dataset import MultiResolutionDataset


class Inception3Feature(Inception3):
    def forward(self, x):
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)

        x = self.Conv2d_1a_3x3(x)  # 299 x 299 x 3
        x = self.Conv2d_2a_3x3(x)  # 149 x 149 x 32
        x = self.Conv2d_2b_3x3(x)  # 147 x 147 x 32
        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 147 x 147 x 64

        x = self.Conv2d_3b_1x1(x)  # 73 x 73 x 64
        x = self.Conv2d_4a_3x3(x)  # 73 x 73 x 80
        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 71 x 71 x 192

        x = self.Mixed_5b(x)  # 35 x 35 x 192
        x = self.Mixed_5c(x)  # 35 x 35 x 256
        x = self.Mixed_5d(x)  # 35 x 35 x 288

        x = self.Mixed_6a(x)  # 35 x 35 x 288
        x = self.Mixed_6b(x)  # 17 x 17 x 768
        x = self.Mixed_6c(x)  # 17 x 17 x 768
        x = self.Mixed_6d(x)  # 17 x 17 x 768
        x = self.Mixed_6e(x)  # 17 x 17 x 768

        x = self.Mixed_7a(x)  # 17 x 17 x 768
        x = self.Mixed_7b(x)  # 8 x 8 x 1280
        x = self.Mixed_7c(x)  # 8 x 8 x 2048

        x = F.avg_pool2d(x, kernel_size=8)  # 8 x 8 x 2048

        return x.view(x.shape[0], x.shape[1])  # 1 x 1 x 2048


def load_patched_inception_v3():
    # inception = inception_v3(pretrained=True)
    # inception_feat = Inception3Feature()
    # inception_feat.load_state_dict(inception.state_dict())
    inception_feat = InceptionV3([3], normalize_input=False)

    return inception_feat


@torch.no_grad()
def extract_features(loader, inception, device):
    pbar = tqdm(loader)

    feature_list = []

    for img in pbar:
        img = img.to(device)
        feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to("cpu"))

    features = torch.cat(feature_list, 0)

    return features


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser(
        description="Calculate Inception v3 features for datasets"
    )
    parser.add_argument(
        "--size",
        type=int,
        default=256,
        help="image sizes used for embedding calculation",
    )
    parser.add_argument(
        "--batch", default=64, type=int, help="batch size for inception networks"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=50000,
        help="number of samples used for embedding calculation",
    )
    parser.add_argument(
        "--flip", action="store_true", help="apply random flipping to real images"
    )
    parser.add_argument("path", metavar="PATH", help="path to dataset lmdb file")

    args = parser.parse_args()

    inception = load_patched_inception_v3()
    inception = nn.DataParallel(inception).eval().to(device)

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    features = features[: args.n_sample]

    print(f"extracted {features.shape[0]} features")

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    name = os.path.splitext(os.path.basename(args.path))[0]

    with open(f"inception_{name}.pkl", "wb") as f:
        pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)
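Note: the pickle written above holds the "real" statistics that fid.py later reads back through its --inception argument. A short sketch of reading it, assuming it was produced for an lmdb dataset named ffhq.lmdb:

import pickle

with open("inception_ffhq.pkl", "rb") as f:      # f"inception_{name}.pkl" from the script above
    embeds = pickle.load(f)

real_mean, real_cov = embeds["mean"], embeds["cov"]   # consumed by calc_fid() in fid.py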
checkpoint/.gitignore
ADDED
@@ -0,0 +1 @@
*.pt
closed_form_factorization.py
ADDED
@@ -0,0 +1,33 @@
import argparse

import torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract factor/eigenvectors of latent spaces using closed form factorization"
    )

    parser.add_argument(
        "--out", type=str, default="factor.pt", help="name of the result factor file"
    )
    parser.add_argument("ckpt", type=str, help="name of the model checkpoint")

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)
    modulate = {
        k: v
        for k, v in ckpt["g_ema"].items()
        if "modulation" in k and "to_rgbs" not in k and "weight" in k
    }

    weight_mat = []
    for k, v in modulate.items():
        weight_mat.append(v)

    W = torch.cat(weight_mat, 0)
    eigvec = torch.svd(W).V.to("cpu")

    torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)
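Note: the factorization itself is just an SVD of the stacked style-modulation weights; the right singular vectors are what gets saved as "eigvec". A self-contained sketch with a random stand-in matrix:

import torch

W = torch.randn(3 * 512, 512)   # stand-in for the concatenated modulation weight matrices
eigvec = torch.svd(W).V          # same call as above; columns span the latent directions
print(eigvec.shape)              # torch.Size([512, 512])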
convert_weight.py
ADDED
@@ -0,0 +1,301 @@
import argparse
import os
import sys
import pickle
import math

import torch
import numpy as np
from torchvision import utils

from model import Generator, Discriminator


def convert_modconv(vars, source_name, target_name, flip=False):
    weight = vars[source_name + "/weight"].value().eval()
    mod_weight = vars[source_name + "/mod_weight"].value().eval()
    mod_bias = vars[source_name + "/mod_bias"].value().eval()
    noise = vars[source_name + "/noise_strength"].value().eval()
    bias = vars[source_name + "/bias"].value().eval()

    dic = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        "conv.modulation.bias": mod_bias + 1,
        "noise.weight": np.array([noise]),
        "activate.bias": bias,
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    if flip:
        dic_torch[target_name + ".conv.weight"] = torch.flip(
            dic_torch[target_name + ".conv.weight"], [3, 4]
        )

    return dic_torch


def convert_conv(vars, source_name, target_name, bias=True, start=0):
    weight = vars[source_name + "/weight"].value().eval()

    dic = {"weight": weight.transpose((3, 2, 0, 1))}

    if bias:
        dic["bias"] = vars[source_name + "/bias"].value().eval()

    dic_torch = {}

    dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])

    if bias:
        dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])

    return dic_torch


def convert_torgb(vars, source_name, target_name):
    weight = vars[source_name + "/weight"].value().eval()
    mod_weight = vars[source_name + "/mod_weight"].value().eval()
    mod_bias = vars[source_name + "/mod_bias"].value().eval()
    bias = vars[source_name + "/bias"].value().eval()

    dic = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        "conv.modulation.bias": mod_bias + 1,
        "bias": bias.reshape((1, 3, 1, 1)),
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    return dic_torch


def convert_dense(vars, source_name, target_name):
    weight = vars[source_name + "/weight"].value().eval()
    bias = vars[source_name + "/bias"].value().eval()

    dic = {"weight": weight.transpose((1, 0)), "bias": bias}

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    return dic_torch


def update(state_dict, new):
    for k, v in new.items():
        if k not in state_dict:
            raise KeyError(k + " is not found")

        if v.shape != state_dict[k].shape:
            raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")

        state_dict[k] = v


def discriminator_fill_statedict(statedict, vars, size):
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))

    conv_i = 1

    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        update(
            statedict,
            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
            ),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
            ),
        )
        conv_i += 1

    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))

    return statedict


def fill_statedict(state_dict, vars, size, n_mlp):
    log_size = int(math.log(size, 2))

    for i in range(n_mlp):
        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))

    update(
        state_dict,
        {
            "input.input": torch.from_numpy(
                vars["G_synthesis/4x4/Const/const"].value().eval()
            )
        },
    )

    update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
        )

    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))

    conv_i = 0

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_modconv(
                vars,
                f"G_synthesis/{reso}x{reso}/Conv0_up",
                f"convs.{conv_i}",
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
            ),
        )
        conv_i += 2

    for i in range(0, (log_size - 2) * 2 + 1):
        update(
            state_dict,
            {
                f"noises.noise_{i}": torch.from_numpy(
                    vars[f"G_synthesis/noise{i}"].value().eval()
                )
            },
        )

    return state_dict


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Tensorflow to pytorch model checkpoint converter"
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="path to the offical StyleGAN2 repository with dnnlib/ folder",
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    parser.add_argument(
        "--disc", action="store_true", help="convert the discriminator weights"
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()

    sys.path.append(args.repo)

    import dnnlib
    from dnnlib import tflib

    tflib.init_tf()

    with open(args.path, "rb") as f:
        generator, discriminator, g_ema = pickle.load(f)

    size = g_ema.output_shape[2]

    n_mlp = 0
    mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
    for layer in mapping_layers_names:
        if layer[0].startswith('Dense'):
            n_mlp += 1

    g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)

    g.load_state_dict(state_dict)

    latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    if args.gen:
        g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
        ckpt["g"] = g_train_state

    if args.disc:
        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
        ckpt["d"] = d_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + ".pt")

    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            randomize_noise=False,
        )

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False
    img_tf = g_ema.run(z, None, **Gs_kwargs)
    img_tf = torch.from_numpy(img_tf).to(device)

    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
        0.0, 1.0
    )

    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)

    print(img_diff.abs().max())

    utils.save_image(
        img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )
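Note: the core of the conversion is a weight-layout change; TensorFlow stores conv kernels as (kh, kw, in_ch, out_ch) while PyTorch expects (out_ch, in_ch, kh, kw), which is what the transpose((3, 2, 0, 1)) calls above do. A tiny sketch with hypothetical shapes:

import numpy as np

tf_weight = np.zeros((3, 3, 64, 128), dtype=np.float32)  # (kh, kw, in_ch, out_ch)
pt_weight = tf_weight.transpose((3, 2, 0, 1))            # (out_ch, in_ch, kh, kw)
print(pt_weight.shape)                                    # (128, 64, 3, 3)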
dataset.py
ADDED
@@ -0,0 +1,40 @@
from io import BytesIO

import lmdb
from PIL import Image
from torch.utils.data import Dataset


class MultiResolutionDataset(Dataset):
    def __init__(self, path, transform, resolution=256):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img
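Note: a minimal usage sketch, mirroring how calc_inception.py builds its loader (the lmdb path is a placeholder):

from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import MultiResolutionDataset

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
)
dset = MultiResolutionDataset("data.lmdb", transform=transform, resolution=256)
loader = DataLoader(dset, batch_size=16, num_workers=4)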
distributed.py
ADDED
@@ -0,0 +1,126 @@
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def reduce_sum(tensor):
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor


def gather_grad(params):
    world_size = get_world_size()

    if world_size == 1:
        return

    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
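Note: these helpers degrade gracefully to no-ops on a single process. A sketch of the typical pattern (assuming torch.distributed has been initialized elsewhere when more than one process is used):

import torch

from distributed import get_rank, reduce_loss_dict, synchronize

loss_dict = {"d": torch.tensor(0.5), "g": torch.tensor(1.2)}  # hypothetical per-rank losses
loss_reduced = reduce_loss_dict(loss_dict)  # returned unchanged when world_size < 2
if get_rank() == 0:
    print({k: float(v) for k, v in loss_reduced.items()})
synchronize()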
fid.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import pickle

import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm

from model import Generator
from calc_inception import load_patched_inception_v3


@torch.no_grad()
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    n_batch = n_sample // batch_size
    resid = n_sample - (n_batch * batch_size)
    batch_sizes = [batch_size] * n_batch + [resid]
    features = []

    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, 512, device=device)
        # use the generator passed in as an argument (the original line referenced
        # the module-level `g`, which only exists when run as a script)
        img, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)[0].view(img.shape[0], -1)
        features.append(feat.to("cpu"))

    features = torch.cat(features, 0)

    return features


def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(cov_sqrt).all():
        print("product of cov matrices is singular")
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))

            raise ValueError(f"Imaginary component {m}")

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    fid = mean_norm + trace

    return fid


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Calculate FID scores")

    parser.add_argument("--truncation", type=float, default=1, help="truncation factor")
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of samples to calculate mean for truncation",
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the generator"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=50000,
        help="number of the samples for calculating FID",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="image sizes for generator"
    )
    parser.add_argument(
        "--inception",
        type=str,
        default=None,
        required=True,
        help="path to precomputed inception embedding",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
    )

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f"extracted {features.shape[0]} features")

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    with open(args.inception, "rb") as f:
        embeds = pickle.load(f)
        real_mean = embeds["mean"]
        real_cov = embeds["cov"]

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

    print("fid:", fid)
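A quick, hedged sanity check of calc_fid on small synthetic statistics (not part of the committed file; the module name `fid` and the toy 64-dimensional features are assumptions for illustration):

# Minimal sketch: exercise calc_fid on synthetic feature sets.
import numpy as np
from fid import calc_fid  # assumes the file above is importable as `fid`

rng = np.random.default_rng(0)
real = rng.normal(size=(1000, 64))           # stand-in for real Inception features
fake = rng.normal(loc=0.1, size=(1000, 64))  # stand-in for generated features

real_mean, real_cov = real.mean(0), np.cov(real, rowvar=False)
fake_mean, fake_cov = fake.mean(0), np.cov(fake, rowvar=False)

# should print a small positive number; identical statistics would give ~0
print(calc_fid(fake_mean, fake_cov, real_mean, real_cov))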
generate.py
ADDED
@@ -0,0 +1,84 @@
import argparse

import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm


def generate(args, g_ema, device, mean_latent):

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            sample_z = torch.randn(args.sample, args.latent, device=device)

            sample, _ = g_ema(
                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
            )

            utils.save_image(
                sample,
                f"sample/{str(i).zfill(6)}.png",
                nrow=1,
                normalize=True,
                range=(-1, 1),
            )


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Generate samples from the generator")

    parser.add_argument(
        "--size", type=int, default=1024, help="output image size of the generator"
    )
    parser.add_argument(
        "--sample",
        type=int,
        default=1,
        help="number of samples to be generated for each image",
    )
    parser.add_argument(
        "--pics", type=int, default=20, help="number of images to be generated"
    )
    parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of vectors to calculate mean for the truncation",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )

    args = parser.parse_args()

    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    checkpoint = torch.load(args.ckpt)

    g_ema.load_state_dict(checkpoint["g_ema"])

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)
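For reference, a hedged sketch of calling the same Generator API programmatically with truncation, mirroring generate.py; the checkpoint filename is the script's default and is assumed to exist, and a GPU is optional:

# Sketch of programmatic sampling with truncation (assumes model.Generator and
# a checkpoint containing a "g_ema" state dict, as in the script above).
import torch
from model import Generator

device = "cuda" if torch.cuda.is_available() else "cpu"
g_ema = Generator(1024, 512, 8).to(device)
state = torch.load("stylegan2-ffhq-config-f.pt", map_location=device)
g_ema.load_state_dict(state["g_ema"])
g_ema.eval()

with torch.no_grad():
    mean_latent = g_ema.mean_latent(4096)  # average W, used as truncation anchor
    z = torch.randn(4, 512, device=device)
    imgs, _ = g_ema([z], truncation=0.7, truncation_latent=mean_latent)
print(imgs.shape)  # expected: (4, 3, 1024, 1024), values roughly in [-1, 1]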
inception.py
ADDED
@@ -0,0 +1,310 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling featurs
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = models.inception_v3(num_classes=1008,
                                    aux_logits=False,
                                    pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation"""
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
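A short, hedged sketch of pulling the 2048-dimensional pool features used for FID from a batch of images in [0, 1]; it assumes the file above is importable as `inception` and that the FID weights can be downloaded on first use:

# Sketch: extract final-average-pool features with the class defined above.
import torch
from inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()

images = torch.rand(8, 3, 256, 256)   # placeholder batch, values in (0, 1)
with torch.no_grad():
    feat = model(images)[0]           # (8, 2048, 1, 1)
feat = feat.view(feat.size(0), -1)    # flatten to (8, 2048) for mean/cov stats
print(feat.shape)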
lpips/__init__.py
ADDED
@@ -0,0 +1,160 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from skimage.measure import compare_ssim
import torch
from torch.autograd import Variable

from lpips import dist_model

class PerceptualLoss(torch.nn.Module):
    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
    # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
        super(PerceptualLoss, self).__init__()
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model = dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
        print('...[%s] initialized'%self.model.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """
        Pred and target are Variables.
        If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
        If normalize is False, assumes the images are already between [-1,+1]

        Inputs pred and target are Nx3xHxW
        Output pytorch Variable N long
        """

        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        return self.model.forward(target, pred)

def normalize_tensor(in_feat,eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
    return in_feat/(norm_factor+eps)

def l2(p0, p1, range=255.):
    return .5*np.mean((p0 / range - p1 / range)**2)

def psnr(p0, p1, peak=255.):
    return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))

def dssim(p0, p1, range=255.):
    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.

def rgb2lab(in_img,mean_cent=False):
    from skimage import color
    img_lab = color.rgb2lab(in_img)
    if(mean_cent):
        img_lab[:,:,0] = img_lab[:,:,0]-50
    return img_lab

def tensor2np(tensor_obj):
    # change dimension of a tensor object into a numpy array
    return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))

def np2tensor(np_obj):
    # change dimenion of np array into tensor array
    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
    # image tensor to lab tensor
    from skimage import color

    img = tensor2im(image_tensor)
    img_lab = color.rgb2lab(img)
    if(mc_only):
        img_lab[:,:,0] = img_lab[:,:,0]-50
    if(to_norm and not mc_only):
        img_lab[:,:,0] = img_lab[:,:,0]-50
        img_lab = img_lab/100.

    return np2tensor(img_lab)

def tensorlab2tensor(lab_tensor,return_inbnd=False):
    from skimage import color
    import warnings
    warnings.filterwarnings("ignore")

    lab = tensor2np(lab_tensor)*100.
    lab[:,:,0] = lab[:,:,0]+50

    rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
    if(return_inbnd):
        # convert back to lab, see if we match
        lab_back = color.rgb2lab(rgb_back.astype('uint8'))
        mask = 1.*np.isclose(lab_back,lab,atol=2.)
        mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
        return (im2tensor(rgb_back),mask)
    else:
        return im2tensor(rgb_back)

def rgb2lab(input):
    from skimage import color
    return color.rgb2lab(input / 255.)

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2vec(vector_tensor):
    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]

def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11 point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
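A hedged usage sketch for the PerceptualLoss entry point above; it assumes the package is importable as `lpips`, that the bundled linear weights under lpips/weights are present, and an environment where skimage.measure.compare_ssim still exists (older scikit-image):

# Sketch: compute an LPIPS distance between two images in [-1, 1] on CPU.
import torch
import lpips

loss_fn = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False, gpu_ids=[])
img0 = torch.rand(1, 3, 64, 64) * 2 - 1   # inputs are expected in [-1, 1]
img1 = torch.rand(1, 3, 64, 64) * 2 - 1
d = loss_fn(img0, img1)                   # one distance per image in the batch
print(d.item())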
lpips/base_model.py
ADDED
@@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
    def __init__(self):
        pass;

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=[0]):
        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids

    def forward(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate():
        pass

    def get_image_paths(self):
        return self.image_paths

    def save_done(self, flag=False):
        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
lpips/dist_model.py
ADDED
@@ -0,0 +1,284 @@
from __future__ import absolute_import

import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm

from IPython import embed

from . import networks_basic as networks
import lpips as util

class DistModel(BaseModel):
    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
                   use_gpu=True, printNet=False, spatial=False,
                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - [0] by default, gpus to use
        '''
        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]'%(model,net)

        if(self.model == 'net-lin'): # pretrained net + linear layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                                        use_dropout=True, spatial=spatial, version=version, lpips=True)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                import inspect
                model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))

            if(not is_train):
                print('Loading model from: %s'%model_path)
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

        elif(self.model=='net'): # pretrained network
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2','l2']):
            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train: # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else: # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
        OUTPUT
            computed distances between in0 and in1
        '''

        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        for module in self.net.modules():
            if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)
        self.var_p1 = Variable(self.input_p1,requires_grad=True)

    def forward_train(self): # run forward pass
        # print(self.net.module.scaling_layer.shift)
        # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())

        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())

        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self,d0,d1,judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                               ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        zoom_factor = 256/self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        if(self.use_gpu):
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self,nepoch_decay):
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
        self.old_lr = lr

def score_2afc_dataset(data_loader, func, name=''):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
            (closer to "0" for left patch p0, "1" for right patch p1,
            "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''

    d0s = []
    d1s = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
        d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
        gts+=data['judge'].cpu().numpy().flatten().tolist()

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)
    scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5

    return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))

def score_jnd_dataset(data_loader, func, name=''):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test triplets in data_loader
    '''

    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
        gts+=data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1-sames_sorted)
    FNs = np.sum(sames_sorted)-TPs

    precs = TPs/(TPs+FPs)
    recs = TPs/(TPs+FNs)
    score = util.voc_ap(recs,precs)

    return(score, dict(ds=ds,sames=sames))
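A toy, hedged check of the 2AFC scoring rule used in score_2afc_dataset above, with hand-picked numbers rather than a real data loader:

# The score is the fraction of pairs where the smaller distance agrees with
# the human preference; ties contribute 0.5.
import numpy as np

d0s = np.array([0.2, 0.8, 0.5])   # distances ref -> p0
d1s = np.array([0.6, 0.1, 0.5])   # distances ref -> p1
gts = np.array([0.0, 1.0, 0.5])   # fraction of people preferring p1

scores = (d0s < d1s) * (1.0 - gts) + (d1s < d0s) * gts + (d1s == d0s) * 0.5
print(scores.mean())  # ~0.833: both decisive pairs agree with humans, tie adds 0.5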
lpips/networks_basic.py
ADDED
@@ -0,0 +1,187 @@
from __future__ import absolute_import

import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn

import lpips as util

def spatial_average(in_tens, keepdim=True):
    return in_tens.mean([2,3],keepdim=keepdim)

def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
    in_H = in_tens.shape[2]
    scale_factor = 1.*out_H/in_H

    return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)

# Learned perceptual metric
class PNetLin(nn.Module):
    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if(self.pnet_type in ['vgg','vgg16']):
            net_type = pn.vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        val = res[0]
        for l in range(1,self.L):
            val += res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val

class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
        self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(),] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
        self.model = nn.Sequential(*layers)


class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
        if(use_sigmoid):
            layers += [nn.Sigmoid(),]
        self.model = nn.Sequential(*layers)

    def forward(self,d0,d1,eps=0.1):
        return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))

class BCERankingLoss(nn.Module):
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        per = (judge+1.)/2.
        self.logit = self.net.forward(d0,d1)
        return self.loss(self.logit, per)

# L2, DSSIM metrics
class FakeNet(nn.Module):
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace=colorspace

class L2(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            (N,C,X,Y) = in0.size()
            value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
            return value
        elif(self.colorspace=='Lab'):
            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
            ret_var = Variable( torch.Tensor((value,) ) )
            if(self.use_gpu):
                ret_var = ret_var.cuda()
            return ret_var

class DSSIM(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        elif(self.colorspace=='Lab'):
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        ret_var = Variable( torch.Tensor((value,) ) )
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Network',net)
    print('Total number of parameters: %d' % num_params)
lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,181 @@
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
from IPython import embed
|
5 |
+
|
6 |
+
class squeezenet(torch.nn.Module):
|
7 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
8 |
+
super(squeezenet, self).__init__()
|
9 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
10 |
+
self.slice1 = torch.nn.Sequential()
|
11 |
+
self.slice2 = torch.nn.Sequential()
|
12 |
+
self.slice3 = torch.nn.Sequential()
|
13 |
+
self.slice4 = torch.nn.Sequential()
|
14 |
+
self.slice5 = torch.nn.Sequential()
|
15 |
+
self.slice6 = torch.nn.Sequential()
|
16 |
+
self.slice7 = torch.nn.Sequential()
|
17 |
+
self.N_slices = 7
|
18 |
+
for x in range(2):
|
19 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
20 |
+
for x in range(2,5):
|
21 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
22 |
+
for x in range(5, 8):
|
23 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
24 |
+
for x in range(8, 10):
|
25 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
26 |
+
for x in range(10, 11):
|
27 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
28 |
+
for x in range(11, 12):
|
29 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
30 |
+
for x in range(12, 13):
|
31 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
32 |
+
if not requires_grad:
|
33 |
+
for param in self.parameters():
|
34 |
+
param.requires_grad = False
|
35 |
+
|
36 |
+
def forward(self, X):
|
37 |
+
h = self.slice1(X)
|
38 |
+
h_relu1 = h
|
39 |
+
h = self.slice2(h)
|
40 |
+
h_relu2 = h
|
41 |
+
h = self.slice3(h)
|
42 |
+
h_relu3 = h
|
43 |
+
h = self.slice4(h)
|
44 |
+
h_relu4 = h
|
45 |
+
h = self.slice5(h)
|
46 |
+
h_relu5 = h
|
47 |
+
h = self.slice6(h)
|
48 |
+
h_relu6 = h
|
49 |
+
h = self.slice7(h)
|
50 |
+
h_relu7 = h
|
51 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
|
52 |
+
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
|
53 |
+
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class alexnet(torch.nn.Module):
|
58 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
59 |
+
super(alexnet, self).__init__()
|
60 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
61 |
+
self.slice1 = torch.nn.Sequential()
|
62 |
+
self.slice2 = torch.nn.Sequential()
|
63 |
+
self.slice3 = torch.nn.Sequential()
|
64 |
+
self.slice4 = torch.nn.Sequential()
|
65 |
+
self.slice5 = torch.nn.Sequential()
|
66 |
+
self.N_slices = 5
|
67 |
+
for x in range(2):
|
68 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
69 |
+
for x in range(2, 5):
|
70 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
71 |
+
for x in range(5, 8):
|
72 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
73 |
+
for x in range(8, 10):
|
74 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
75 |
+
for x in range(10, 12):
|
76 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

        return out


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out


class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if num == 18:
            self.net = tv.resnet18(pretrained=pretrained)
        elif num == 34:
            self.net = tv.resnet34(pretrained=pretrained)
        elif num == 50:
            self.net = tv.resnet50(pretrained=pretrained)
        elif num == 101:
            self.net = tv.resnet101(pretrained=pretrained)
        elif num == 152:
            self.net = tv.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
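These wrappers only expose frozen intermediate activations. A minimal usage sketch (assuming `tv` is torchvision.models, as imported at the top of this file, and that the pretrained weights can be downloaded) for pulling the five VGG feature maps that LPIPS-style metrics compare:

# Hedged sketch: extract the five relu activations compared by LPIPS.
# `vgg16` is the wrapper class defined above in this file.
import torch

extractor = vgg16(requires_grad=False, pretrained=True).eval()

x = torch.randn(1, 3, 64, 64)        # an image batch, typically normalized first
with torch.no_grad():
    feats = extractor(x)             # namedtuple of relu1_2 ... relu5_3

for name, f in zip(feats._fields, feats):
    print(name, tuple(f.shape))      # channel depth grows while spatial size shrinks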
lpips/weights/v0.0/alex.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
size 5455
lpips/weights/v0.0/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
size 10057
lpips/weights/v0.0/vgg.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
size 6735
lpips/weights/v0.1/alex.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
size 6009
lpips/weights/v0.1/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
size 10811
lpips/weights/v0.1/vgg.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
size 7289
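The six .pth entries above are Git LFS pointer files, not the weights themselves; after `git lfs pull` the real files replace them. A hedged sketch for confirming a downloaded weight against the oid and size recorded in its pointer (the helper name `check_lfs_object` is ours, not part of the repo):

# Hedged sketch: verify a git-lfs object against its pointer's sha256/size.
import hashlib
import os

def check_lfs_object(path, expected_sha256, expected_size):
    assert os.path.getsize(path) == expected_size, "size mismatch"
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == expected_sha256, "sha256 mismatch"

check_lfs_object(
    "lpips/weights/v0.1/vgg.pth",
    "a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868",
    7289,
)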
model.py
ADDED
@@ -0,0 +1,698 @@
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import functools
|
4 |
+
import operator
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torch.autograd import Function
|
10 |
+
|
11 |
+
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
|
12 |
+
|
13 |
+
|
14 |
+
class PixelNorm(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
def forward(self, input):
|
19 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
20 |
+
|
21 |
+
|
22 |
+
def make_kernel(k):
|
23 |
+
k = torch.tensor(k, dtype=torch.float32)
|
24 |
+
|
25 |
+
if k.ndim == 1:
|
26 |
+
k = k[None, :] * k[:, None]
|
27 |
+
|
28 |
+
k /= k.sum()
|
29 |
+
|
30 |
+
return k
|
31 |
+
|
32 |
+
|
33 |
+
class Upsample(nn.Module):
|
34 |
+
def __init__(self, kernel, factor=2):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.factor = factor
|
38 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
39 |
+
self.register_buffer("kernel", kernel)
|
40 |
+
|
41 |
+
p = kernel.shape[0] - factor
|
42 |
+
|
43 |
+
pad0 = (p + 1) // 2 + factor - 1
|
44 |
+
pad1 = p // 2
|
45 |
+
|
46 |
+
self.pad = (pad0, pad1)
|
47 |
+
|
48 |
+
def forward(self, input):
|
49 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
50 |
+
|
51 |
+
return out
|
52 |
+
|
53 |
+
|
54 |
+
class Downsample(nn.Module):
|
55 |
+
def __init__(self, kernel, factor=2):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
self.factor = factor
|
59 |
+
kernel = make_kernel(kernel)
|
60 |
+
self.register_buffer("kernel", kernel)
|
61 |
+
|
62 |
+
p = kernel.shape[0] - factor
|
63 |
+
|
64 |
+
pad0 = (p + 1) // 2
|
65 |
+
pad1 = p // 2
|
66 |
+
|
67 |
+
self.pad = (pad0, pad1)
|
68 |
+
|
69 |
+
def forward(self, input):
|
70 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
71 |
+
|
72 |
+
return out
|
73 |
+
|
74 |
+
|
75 |
+
class Blur(nn.Module):
|
76 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
kernel = make_kernel(kernel)
|
80 |
+
|
81 |
+
if upsample_factor > 1:
|
82 |
+
kernel = kernel * (upsample_factor ** 2)
|
83 |
+
|
84 |
+
self.register_buffer("kernel", kernel)
|
85 |
+
|
86 |
+
self.pad = pad
|
87 |
+
|
88 |
+
def forward(self, input):
|
89 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class EqualConv2d(nn.Module):
|
95 |
+
def __init__(
|
96 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.weight = nn.Parameter(
|
101 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
102 |
+
)
|
103 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
104 |
+
|
105 |
+
self.stride = stride
|
106 |
+
self.padding = padding
|
107 |
+
|
108 |
+
if bias:
|
109 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
110 |
+
|
111 |
+
else:
|
112 |
+
self.bias = None
|
113 |
+
|
114 |
+
def forward(self, input):
|
115 |
+
out = conv2d_gradfix.conv2d(
|
116 |
+
input,
|
117 |
+
self.weight * self.scale,
|
118 |
+
bias=self.bias,
|
119 |
+
stride=self.stride,
|
120 |
+
padding=self.padding,
|
121 |
+
)
|
122 |
+
|
123 |
+
return out
|
124 |
+
|
125 |
+
def __repr__(self):
|
126 |
+
return (
|
127 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
128 |
+
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
129 |
+
)
|
130 |
+
|
131 |
+
|
132 |
+
class EqualLinear(nn.Module):
|
133 |
+
def __init__(
|
134 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
|
138 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
139 |
+
|
140 |
+
if bias:
|
141 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
142 |
+
|
143 |
+
else:
|
144 |
+
self.bias = None
|
145 |
+
|
146 |
+
self.activation = activation
|
147 |
+
|
148 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
149 |
+
self.lr_mul = lr_mul
|
150 |
+
|
151 |
+
def forward(self, input):
|
152 |
+
if self.activation:
|
153 |
+
out = F.linear(input, self.weight * self.scale)
|
154 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
155 |
+
|
156 |
+
else:
|
157 |
+
out = F.linear(
|
158 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
159 |
+
)
|
160 |
+
|
161 |
+
return out
|
162 |
+
|
163 |
+
def __repr__(self):
|
164 |
+
return (
|
165 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
class ModulatedConv2d(nn.Module):
|
170 |
+
def __init__(
|
171 |
+
self,
|
172 |
+
in_channel,
|
173 |
+
out_channel,
|
174 |
+
kernel_size,
|
175 |
+
style_dim,
|
176 |
+
demodulate=True,
|
177 |
+
upsample=False,
|
178 |
+
downsample=False,
|
179 |
+
blur_kernel=[1, 3, 3, 1],
|
180 |
+
fused=True,
|
181 |
+
):
|
182 |
+
super().__init__()
|
183 |
+
|
184 |
+
self.eps = 1e-8
|
185 |
+
self.kernel_size = kernel_size
|
186 |
+
self.in_channel = in_channel
|
187 |
+
self.out_channel = out_channel
|
188 |
+
self.upsample = upsample
|
189 |
+
self.downsample = downsample
|
190 |
+
|
191 |
+
if upsample:
|
192 |
+
factor = 2
|
193 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
194 |
+
pad0 = (p + 1) // 2 + factor - 1
|
195 |
+
pad1 = p // 2 + 1
|
196 |
+
|
197 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
198 |
+
|
199 |
+
if downsample:
|
200 |
+
factor = 2
|
201 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
202 |
+
pad0 = (p + 1) // 2
|
203 |
+
pad1 = p // 2
|
204 |
+
|
205 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
206 |
+
|
207 |
+
fan_in = in_channel * kernel_size ** 2
|
208 |
+
self.scale = 1 / math.sqrt(fan_in)
|
209 |
+
self.padding = kernel_size // 2
|
210 |
+
|
211 |
+
self.weight = nn.Parameter(
|
212 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
213 |
+
)
|
214 |
+
|
215 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
216 |
+
|
217 |
+
self.demodulate = demodulate
|
218 |
+
self.fused = fused
|
219 |
+
|
220 |
+
def __repr__(self):
|
221 |
+
return (
|
222 |
+
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
|
223 |
+
f"upsample={self.upsample}, downsample={self.downsample})"
|
224 |
+
)
|
225 |
+
|
226 |
+
def forward(self, input, style):
|
227 |
+
batch, in_channel, height, width = input.shape
|
228 |
+
|
229 |
+
if not self.fused:
|
230 |
+
weight = self.scale * self.weight.squeeze(0)
|
231 |
+
style = self.modulation(style)
|
232 |
+
|
233 |
+
if self.demodulate:
|
234 |
+
w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
|
235 |
+
dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
|
236 |
+
|
237 |
+
input = input * style.reshape(batch, in_channel, 1, 1)
|
238 |
+
|
239 |
+
if self.upsample:
|
240 |
+
weight = weight.transpose(0, 1)
|
241 |
+
out = conv2d_gradfix.conv_transpose2d(
|
242 |
+
input, weight, padding=0, stride=2
|
243 |
+
)
|
244 |
+
out = self.blur(out)
|
245 |
+
|
246 |
+
elif self.downsample:
|
247 |
+
input = self.blur(input)
|
248 |
+
out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
|
249 |
+
|
250 |
+
else:
|
251 |
+
out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
|
252 |
+
|
253 |
+
if self.demodulate:
|
254 |
+
out = out * dcoefs.view(batch, -1, 1, 1)
|
255 |
+
|
256 |
+
return out
|
257 |
+
|
258 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
259 |
+
weight = self.scale * self.weight * style
|
260 |
+
|
261 |
+
if self.demodulate:
|
262 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
263 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
264 |
+
|
265 |
+
weight = weight.view(
|
266 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
267 |
+
)
|
268 |
+
|
269 |
+
if self.upsample:
|
270 |
+
input = input.view(1, batch * in_channel, height, width)
|
271 |
+
weight = weight.view(
|
272 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
273 |
+
)
|
274 |
+
weight = weight.transpose(1, 2).reshape(
|
275 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
276 |
+
)
|
277 |
+
out = conv2d_gradfix.conv_transpose2d(
|
278 |
+
input, weight, padding=0, stride=2, groups=batch
|
279 |
+
)
|
280 |
+
_, _, height, width = out.shape
|
281 |
+
out = out.view(batch, self.out_channel, height, width)
|
282 |
+
out = self.blur(out)
|
283 |
+
|
284 |
+
elif self.downsample:
|
285 |
+
input = self.blur(input)
|
286 |
+
_, _, height, width = input.shape
|
287 |
+
input = input.view(1, batch * in_channel, height, width)
|
288 |
+
out = conv2d_gradfix.conv2d(
|
289 |
+
input, weight, padding=0, stride=2, groups=batch
|
290 |
+
)
|
291 |
+
_, _, height, width = out.shape
|
292 |
+
out = out.view(batch, self.out_channel, height, width)
|
293 |
+
|
294 |
+
else:
|
295 |
+
input = input.view(1, batch * in_channel, height, width)
|
296 |
+
out = conv2d_gradfix.conv2d(
|
297 |
+
input, weight, padding=self.padding, groups=batch
|
298 |
+
)
|
299 |
+
_, _, height, width = out.shape
|
300 |
+
out = out.view(batch, self.out_channel, height, width)
|
301 |
+
|
302 |
+
return out
|
303 |
+
|
304 |
+
|
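The fused branch folds the per-sample style into the convolution weight and then demodulates it. A standalone sketch of that arithmetic, without the grouped convolution (shapes are illustrative):

# Hedged sketch: StyleGAN2 weight modulation/demodulation on its own.
import torch

batch, out_ch, in_ch, k = 2, 8, 4, 3
weight = torch.randn(1, out_ch, in_ch, k, k)
style = torch.randn(batch, 1, in_ch, 1, 1)            # output of the modulation FC

w = weight * style                                     # modulate input channels
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)    # per (sample, out_ch) norm
w = w * demod.view(batch, out_ch, 1, 1, 1)             # demodulate

# each output filter now has (almost) unit L2 norm per sample:
print(w.pow(2).sum([2, 3, 4]).sqrt()[0, :3])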
305 |
+
class NoiseInjection(nn.Module):
|
306 |
+
def __init__(self):
|
307 |
+
super().__init__()
|
308 |
+
|
309 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
310 |
+
|
311 |
+
def forward(self, image, noise=None):
|
312 |
+
if noise is None:
|
313 |
+
batch, _, height, width = image.shape
|
314 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
315 |
+
|
316 |
+
return image + self.weight * noise
|
317 |
+
|
318 |
+
|
319 |
+
class ConstantInput(nn.Module):
|
320 |
+
def __init__(self, channel, size=4):
|
321 |
+
super().__init__()
|
322 |
+
|
323 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
324 |
+
|
325 |
+
def forward(self, input):
|
326 |
+
batch = input.shape[0]
|
327 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
328 |
+
|
329 |
+
return out
|
330 |
+
|
331 |
+
|
332 |
+
class StyledConv(nn.Module):
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
in_channel,
|
336 |
+
out_channel,
|
337 |
+
kernel_size,
|
338 |
+
style_dim,
|
339 |
+
upsample=False,
|
340 |
+
blur_kernel=[1, 3, 3, 1],
|
341 |
+
demodulate=True,
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
|
345 |
+
self.conv = ModulatedConv2d(
|
346 |
+
in_channel,
|
347 |
+
out_channel,
|
348 |
+
kernel_size,
|
349 |
+
style_dim,
|
350 |
+
upsample=upsample,
|
351 |
+
blur_kernel=blur_kernel,
|
352 |
+
demodulate=demodulate,
|
353 |
+
)
|
354 |
+
|
355 |
+
self.noise = NoiseInjection()
|
356 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
357 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
358 |
+
self.activate = FusedLeakyReLU(out_channel)
|
359 |
+
|
360 |
+
def forward(self, input, style, noise=None):
|
361 |
+
out = self.conv(input, style)
|
362 |
+
out = self.noise(out, noise=noise)
|
363 |
+
# out = out + self.bias
|
364 |
+
out = self.activate(out)
|
365 |
+
|
366 |
+
return out
|
367 |
+
|
368 |
+
|
369 |
+
class ToRGB(nn.Module):
|
370 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
371 |
+
super().__init__()
|
372 |
+
|
373 |
+
if upsample:
|
374 |
+
self.upsample = Upsample(blur_kernel)
|
375 |
+
|
376 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
377 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
378 |
+
|
379 |
+
def forward(self, input, style, skip=None):
|
380 |
+
out = self.conv(input, style)
|
381 |
+
out = out + self.bias
|
382 |
+
|
383 |
+
if skip is not None:
|
384 |
+
skip = self.upsample(skip)
|
385 |
+
|
386 |
+
out = out + skip
|
387 |
+
|
388 |
+
return out
|
389 |
+
|
390 |
+
|
391 |
+
class Generator(nn.Module):
|
392 |
+
def __init__(
|
393 |
+
self,
|
394 |
+
size,
|
395 |
+
style_dim,
|
396 |
+
n_mlp,
|
397 |
+
channel_multiplier=2,
|
398 |
+
blur_kernel=[1, 3, 3, 1],
|
399 |
+
lr_mlp=0.01,
|
400 |
+
):
|
401 |
+
super().__init__()
|
402 |
+
|
403 |
+
self.size = size
|
404 |
+
|
405 |
+
self.style_dim = style_dim
|
406 |
+
|
407 |
+
layers = [PixelNorm()]
|
408 |
+
|
409 |
+
for i in range(n_mlp):
|
410 |
+
layers.append(
|
411 |
+
EqualLinear(
|
412 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
413 |
+
)
|
414 |
+
)
|
415 |
+
|
416 |
+
self.style = nn.Sequential(*layers)
|
417 |
+
|
418 |
+
self.channels = {
|
419 |
+
4: 512,
|
420 |
+
8: 512,
|
421 |
+
16: 512,
|
422 |
+
32: 512,
|
423 |
+
64: 256 * channel_multiplier,
|
424 |
+
128: 128 * channel_multiplier,
|
425 |
+
256: 64 * channel_multiplier,
|
426 |
+
512: 32 * channel_multiplier,
|
427 |
+
1024: 16 * channel_multiplier,
|
428 |
+
}
|
429 |
+
|
430 |
+
self.input = ConstantInput(self.channels[4])
|
431 |
+
self.conv1 = StyledConv(
|
432 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
433 |
+
)
|
434 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
435 |
+
|
436 |
+
self.log_size = int(math.log(size, 2))
|
437 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
438 |
+
|
439 |
+
self.convs = nn.ModuleList()
|
440 |
+
self.upsamples = nn.ModuleList()
|
441 |
+
self.to_rgbs = nn.ModuleList()
|
442 |
+
self.noises = nn.Module()
|
443 |
+
|
444 |
+
in_channel = self.channels[4]
|
445 |
+
|
446 |
+
for layer_idx in range(self.num_layers):
|
447 |
+
res = (layer_idx + 5) // 2
|
448 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
449 |
+
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
|
450 |
+
|
451 |
+
for i in range(3, self.log_size + 1):
|
452 |
+
out_channel = self.channels[2 ** i]
|
453 |
+
|
454 |
+
self.convs.append(
|
455 |
+
StyledConv(
|
456 |
+
in_channel,
|
457 |
+
out_channel,
|
458 |
+
3,
|
459 |
+
style_dim,
|
460 |
+
upsample=True,
|
461 |
+
blur_kernel=blur_kernel,
|
462 |
+
)
|
463 |
+
)
|
464 |
+
|
465 |
+
self.convs.append(
|
466 |
+
StyledConv(
|
467 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
468 |
+
)
|
469 |
+
)
|
470 |
+
|
471 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
472 |
+
|
473 |
+
in_channel = out_channel
|
474 |
+
|
475 |
+
self.n_latent = self.log_size * 2 - 2
|
476 |
+
|
477 |
+
def make_noise(self):
|
478 |
+
device = self.input.input.device
|
479 |
+
|
480 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
481 |
+
|
482 |
+
for i in range(3, self.log_size + 1):
|
483 |
+
for _ in range(2):
|
484 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
485 |
+
|
486 |
+
return noises
|
487 |
+
|
488 |
+
def mean_latent(self, n_latent):
|
489 |
+
latent_in = torch.randn(
|
490 |
+
n_latent, self.style_dim, device=self.input.input.device
|
491 |
+
)
|
492 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
493 |
+
|
494 |
+
return latent
|
495 |
+
|
496 |
+
def get_latent(self, input):
|
497 |
+
return self.style(input)
|
498 |
+
|
499 |
+
def forward(
|
500 |
+
self,
|
501 |
+
styles,
|
502 |
+
return_latents=False,
|
503 |
+
inject_index=None,
|
504 |
+
truncation=1,
|
505 |
+
truncation_latent=None,
|
506 |
+
input_is_latent=False,
|
507 |
+
noise=None,
|
508 |
+
randomize_noise=True,
|
509 |
+
):
|
510 |
+
if not input_is_latent:
|
511 |
+
styles = [self.style(s) for s in styles]
|
512 |
+
|
513 |
+
if noise is None:
|
514 |
+
if randomize_noise:
|
515 |
+
noise = [None] * self.num_layers
|
516 |
+
else:
|
517 |
+
noise = [
|
518 |
+
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
|
519 |
+
]
|
520 |
+
|
521 |
+
if truncation < 1:
|
522 |
+
style_t = []
|
523 |
+
|
524 |
+
for style in styles:
|
525 |
+
style_t.append(
|
526 |
+
truncation_latent + truncation * (style - truncation_latent)
|
527 |
+
)
|
528 |
+
|
529 |
+
styles = style_t
|
530 |
+
|
531 |
+
if len(styles) < 2:
|
532 |
+
inject_index = self.n_latent
|
533 |
+
|
534 |
+
if styles[0].ndim < 3:
|
535 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
536 |
+
|
537 |
+
else:
|
538 |
+
latent = styles[0]
|
539 |
+
|
540 |
+
else:
|
541 |
+
if inject_index is None:
|
542 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
543 |
+
|
544 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
545 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
546 |
+
|
547 |
+
latent = torch.cat([latent, latent2], 1)
|
548 |
+
|
549 |
+
out = self.input(latent)
|
550 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
551 |
+
|
552 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
553 |
+
|
554 |
+
i = 1
|
555 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
556 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
557 |
+
):
|
558 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
559 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
560 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
561 |
+
|
562 |
+
i += 2
|
563 |
+
|
564 |
+
image = skip
|
565 |
+
|
566 |
+
if return_latents:
|
567 |
+
return image, latent
|
568 |
+
|
569 |
+
else:
|
570 |
+
return image, None
|
571 |
+
|
572 |
+
|
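The generator takes a list of latent codes (to allow style mixing) and returns the image plus, optionally, the W-space latents. A minimal usage sketch (sizes are illustrative; a CUDA device and the compiled ops are assumed for the fused kernels):

# Hedged sketch: instantiating and sampling the Generator defined above.
import torch

device = "cuda"
g = Generator(size=256, style_dim=512, n_mlp=8).to(device)

z = torch.randn(4, 512, device=device)   # one latent per sample
images, _ = g([z])                       # images: tensor of shape (4, 3, 256, 256)

# truncation toward the mean latent for higher-fidelity samples
mean_w = g.mean_latent(4096)
images, _ = g([z], truncation=0.7, truncation_latent=mean_w)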
573 |
+
class ConvLayer(nn.Sequential):
|
574 |
+
def __init__(
|
575 |
+
self,
|
576 |
+
in_channel,
|
577 |
+
out_channel,
|
578 |
+
kernel_size,
|
579 |
+
downsample=False,
|
580 |
+
blur_kernel=[1, 3, 3, 1],
|
581 |
+
bias=True,
|
582 |
+
activate=True,
|
583 |
+
):
|
584 |
+
layers = []
|
585 |
+
|
586 |
+
if downsample:
|
587 |
+
factor = 2
|
588 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
589 |
+
pad0 = (p + 1) // 2
|
590 |
+
pad1 = p // 2
|
591 |
+
|
592 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
593 |
+
|
594 |
+
stride = 2
|
595 |
+
self.padding = 0
|
596 |
+
|
597 |
+
else:
|
598 |
+
stride = 1
|
599 |
+
self.padding = kernel_size // 2
|
600 |
+
|
601 |
+
layers.append(
|
602 |
+
EqualConv2d(
|
603 |
+
in_channel,
|
604 |
+
out_channel,
|
605 |
+
kernel_size,
|
606 |
+
padding=self.padding,
|
607 |
+
stride=stride,
|
608 |
+
bias=bias and not activate,
|
609 |
+
)
|
610 |
+
)
|
611 |
+
|
612 |
+
if activate:
|
613 |
+
layers.append(FusedLeakyReLU(out_channel, bias=bias))
|
614 |
+
|
615 |
+
super().__init__(*layers)
|
616 |
+
|
617 |
+
|
618 |
+
class ResBlock(nn.Module):
|
619 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
620 |
+
super().__init__()
|
621 |
+
|
622 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
623 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
624 |
+
|
625 |
+
self.skip = ConvLayer(
|
626 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
627 |
+
)
|
628 |
+
|
629 |
+
def forward(self, input):
|
630 |
+
out = self.conv1(input)
|
631 |
+
out = self.conv2(out)
|
632 |
+
|
633 |
+
skip = self.skip(input)
|
634 |
+
out = (out + skip) / math.sqrt(2)
|
635 |
+
|
636 |
+
return out
|
637 |
+
|
638 |
+
|
639 |
+
class Discriminator(nn.Module):
|
640 |
+
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
641 |
+
super().__init__()
|
642 |
+
|
643 |
+
channels = {
|
644 |
+
4: 512,
|
645 |
+
8: 512,
|
646 |
+
16: 512,
|
647 |
+
32: 512,
|
648 |
+
64: 256 * channel_multiplier,
|
649 |
+
128: 128 * channel_multiplier,
|
650 |
+
256: 64 * channel_multiplier,
|
651 |
+
512: 32 * channel_multiplier,
|
652 |
+
1024: 16 * channel_multiplier,
|
653 |
+
}
|
654 |
+
|
655 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
656 |
+
|
657 |
+
log_size = int(math.log(size, 2))
|
658 |
+
|
659 |
+
in_channel = channels[size]
|
660 |
+
|
661 |
+
for i in range(log_size, 2, -1):
|
662 |
+
out_channel = channels[2 ** (i - 1)]
|
663 |
+
|
664 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
665 |
+
|
666 |
+
in_channel = out_channel
|
667 |
+
|
668 |
+
self.convs = nn.Sequential(*convs)
|
669 |
+
|
670 |
+
self.stddev_group = 4
|
671 |
+
self.stddev_feat = 1
|
672 |
+
|
673 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
674 |
+
self.final_linear = nn.Sequential(
|
675 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
|
676 |
+
EqualLinear(channels[4], 1),
|
677 |
+
)
|
678 |
+
|
679 |
+
def forward(self, input):
|
680 |
+
out = self.convs(input)
|
681 |
+
|
682 |
+
batch, channel, height, width = out.shape
|
683 |
+
group = min(batch, self.stddev_group)
|
684 |
+
stddev = out.view(
|
685 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
686 |
+
)
|
687 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
688 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
689 |
+
stddev = stddev.repeat(group, 1, height, width)
|
690 |
+
out = torch.cat([out, stddev], 1)
|
691 |
+
|
692 |
+
out = self.final_conv(out)
|
693 |
+
|
694 |
+
out = out.view(batch, -1)
|
695 |
+
out = self.final_linear(out)
|
696 |
+
|
697 |
+
return out
|
698 |
+
|
non_leaking.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import autograd
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from distributed import reduce_sum
|
9 |
+
from op import upfirdn2d
|
10 |
+
|
11 |
+
|
12 |
+
class AdaptiveAugment:
|
13 |
+
def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
|
14 |
+
self.ada_aug_target = ada_aug_target
|
15 |
+
self.ada_aug_len = ada_aug_len
|
16 |
+
self.update_every = update_every
|
17 |
+
|
18 |
+
self.ada_update = 0
|
19 |
+
self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
|
20 |
+
self.r_t_stat = 0
|
21 |
+
self.ada_aug_p = 0
|
22 |
+
|
23 |
+
@torch.no_grad()
|
24 |
+
def tune(self, real_pred):
|
25 |
+
self.ada_aug_buf += torch.tensor(
|
26 |
+
(torch.sign(real_pred).sum().item(), real_pred.shape[0]),
|
27 |
+
device=real_pred.device,
|
28 |
+
)
|
29 |
+
self.ada_update += 1
|
30 |
+
|
31 |
+
if self.ada_update % self.update_every == 0:
|
32 |
+
self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
|
33 |
+
pred_signs, n_pred = self.ada_aug_buf.tolist()
|
34 |
+
|
35 |
+
self.r_t_stat = pred_signs / n_pred
|
36 |
+
|
37 |
+
if self.r_t_stat > self.ada_aug_target:
|
38 |
+
sign = 1
|
39 |
+
|
40 |
+
else:
|
41 |
+
sign = -1
|
42 |
+
|
43 |
+
self.ada_aug_p += sign * n_pred / self.ada_aug_len
|
44 |
+
self.ada_aug_p = min(1, max(0, self.ada_aug_p))
|
45 |
+
self.ada_aug_buf.mul_(0)
|
46 |
+
self.ada_update = 0
|
47 |
+
|
48 |
+
return self.ada_aug_p
|
49 |
+
|
50 |
+
|
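AdaptiveAugment is the ADA controller: it tracks the sign of the discriminator's predictions on real images (r_t) and nudges the augmentation probability toward the target. A toy sketch of the update direction computed in tune() (the buffer/distributed bookkeeping is omitted; the numbers are illustrative):

# Hedged sketch: direction of the ADA probability update performed in tune().
import torch

ada_aug_target = 0.6      # target fraction of positive real predictions
ada_aug_len = 500_000     # images over which p would sweep the full [0, 1] range

real_pred = torch.randn(32) + 0.5          # pretend discriminator logits on reals
r_t = torch.sign(real_pred).sum().item() / real_pred.numel()

sign = 1 if r_t > ada_aug_target else -1   # overfitting -> augment more
delta = sign * real_pred.numel() / ada_aug_len
print(r_t, delta)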
51 |
+
SYM6 = (
|
52 |
+
0.015404109327027373,
|
53 |
+
0.0034907120842174702,
|
54 |
+
-0.11799011114819057,
|
55 |
+
-0.048311742585633,
|
56 |
+
0.4910559419267466,
|
57 |
+
0.787641141030194,
|
58 |
+
0.3379294217276218,
|
59 |
+
-0.07263752278646252,
|
60 |
+
-0.021060292512300564,
|
61 |
+
0.04472490177066578,
|
62 |
+
0.0017677118642428036,
|
63 |
+
-0.007800708325034148,
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def translate_mat(t_x, t_y, device="cpu"):
|
68 |
+
batch = t_x.shape[0]
|
69 |
+
|
70 |
+
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
|
71 |
+
translate = torch.stack((t_x, t_y), 1)
|
72 |
+
mat[:, :2, 2] = translate
|
73 |
+
|
74 |
+
return mat
|
75 |
+
|
76 |
+
|
77 |
+
def rotate_mat(theta, device="cpu"):
|
78 |
+
batch = theta.shape[0]
|
79 |
+
|
80 |
+
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
|
81 |
+
sin_t = torch.sin(theta)
|
82 |
+
cos_t = torch.cos(theta)
|
83 |
+
rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
|
84 |
+
mat[:, :2, :2] = rot
|
85 |
+
|
86 |
+
return mat
|
87 |
+
|
88 |
+
|
89 |
+
def scale_mat(s_x, s_y, device="cpu"):
|
90 |
+
batch = s_x.shape[0]
|
91 |
+
|
92 |
+
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
|
93 |
+
mat[:, 0, 0] = s_x
|
94 |
+
mat[:, 1, 1] = s_y
|
95 |
+
|
96 |
+
return mat
|
97 |
+
|
98 |
+
|
99 |
+
def translate3d_mat(t_x, t_y, t_z):
|
100 |
+
batch = t_x.shape[0]
|
101 |
+
|
102 |
+
mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
103 |
+
translate = torch.stack((t_x, t_y, t_z), 1)
|
104 |
+
mat[:, :3, 3] = translate
|
105 |
+
|
106 |
+
return mat
|
107 |
+
|
108 |
+
|
109 |
+
def rotate3d_mat(axis, theta):
|
110 |
+
batch = theta.shape[0]
|
111 |
+
|
112 |
+
u_x, u_y, u_z = axis
|
113 |
+
|
114 |
+
eye = torch.eye(3).unsqueeze(0)
|
115 |
+
cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
|
116 |
+
outer = torch.tensor(axis)
|
117 |
+
outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
|
118 |
+
|
119 |
+
sin_t = torch.sin(theta).view(-1, 1, 1)
|
120 |
+
cos_t = torch.cos(theta).view(-1, 1, 1)
|
121 |
+
|
122 |
+
rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
|
123 |
+
|
124 |
+
eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
125 |
+
eye_4[:, :3, :3] = rot
|
126 |
+
|
127 |
+
return eye_4
|
128 |
+
|
129 |
+
|
130 |
+
def scale3d_mat(s_x, s_y, s_z):
|
131 |
+
batch = s_x.shape[0]
|
132 |
+
|
133 |
+
mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
134 |
+
mat[:, 0, 0] = s_x
|
135 |
+
mat[:, 1, 1] = s_y
|
136 |
+
mat[:, 2, 2] = s_z
|
137 |
+
|
138 |
+
return mat
|
139 |
+
|
140 |
+
|
141 |
+
def luma_flip_mat(axis, i):
|
142 |
+
batch = i.shape[0]
|
143 |
+
|
144 |
+
eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
145 |
+
axis = torch.tensor(axis + (0,))
|
146 |
+
flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
|
147 |
+
|
148 |
+
return eye - flip
|
149 |
+
|
150 |
+
|
151 |
+
def saturation_mat(axis, i):
|
152 |
+
batch = i.shape[0]
|
153 |
+
|
154 |
+
eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
155 |
+
axis = torch.tensor(axis + (0,))
|
156 |
+
axis = torch.ger(axis, axis)
|
157 |
+
saturate = axis + (eye - axis) * i.view(-1, 1, 1)
|
158 |
+
|
159 |
+
return saturate
|
160 |
+
|
161 |
+
|
162 |
+
def lognormal_sample(size, mean=0, std=1, device="cpu"):
|
163 |
+
return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
|
164 |
+
|
165 |
+
|
166 |
+
def category_sample(size, categories, device="cpu"):
|
167 |
+
category = torch.tensor(categories, device=device)
|
168 |
+
sample = torch.randint(high=len(categories), size=(size,), device=device)
|
169 |
+
|
170 |
+
return category[sample]
|
171 |
+
|
172 |
+
|
173 |
+
def uniform_sample(size, low, high, device="cpu"):
|
174 |
+
return torch.empty(size, device=device).uniform_(low, high)
|
175 |
+
|
176 |
+
|
177 |
+
def normal_sample(size, mean=0, std=1, device="cpu"):
|
178 |
+
return torch.empty(size, device=device).normal_(mean, std)
|
179 |
+
|
180 |
+
|
181 |
+
def bernoulli_sample(size, p, device="cpu"):
|
182 |
+
return torch.empty(size, device=device).bernoulli_(p)
|
183 |
+
|
184 |
+
|
185 |
+
def random_mat_apply(p, transform, prev, eye, device="cpu"):
|
186 |
+
size = transform.shape[0]
|
187 |
+
select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
|
188 |
+
select_transform = select * transform + (1 - select) * eye
|
189 |
+
|
190 |
+
return select_transform @ prev
|
191 |
+
|
192 |
+
|
193 |
+
def sample_affine(p, size, height, width, device="cpu"):
|
194 |
+
G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
|
195 |
+
eye = G
|
196 |
+
|
197 |
+
# flip
|
198 |
+
param = category_sample(size, (0, 1))
|
199 |
+
Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
|
200 |
+
G = random_mat_apply(p, Gc, G, eye, device=device)
|
201 |
+
# print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
|
202 |
+
|
203 |
+
# 90 rotate
|
204 |
+
param = category_sample(size, (0, 3))
|
205 |
+
Gc = rotate_mat(-math.pi / 2 * param, device=device)
|
206 |
+
G = random_mat_apply(p, Gc, G, eye, device=device)
|
207 |
+
# print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
|
208 |
+
|
209 |
+
# integer translate
|
210 |
+
param = uniform_sample((2, size), -0.125, 0.125)
|
211 |
+
param_height = torch.round(param[0] * height)
|
212 |
+
param_width = torch.round(param[1] * width)
|
213 |
+
Gc = translate_mat(param_width, param_height, device=device)
|
214 |
+
G = random_mat_apply(p, Gc, G, eye, device=device)
|
215 |
+
# print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
|
216 |
+
|
217 |
+
# isotropic scale
|
218 |
+
param = lognormal_sample(size, std=0.2 * math.log(2))
|
219 |
+
Gc = scale_mat(param, param, device=device)
|
220 |
+
G = random_mat_apply(p, Gc, G, eye, device=device)
|
221 |
+
# print('isotropic scale', G, scale_mat(param, param), sep='\n')
|
222 |
+
|
223 |
+
p_rot = 1 - math.sqrt(1 - p)
|
224 |
+
|
225 |
+
# pre-rotate
|
226 |
+
param = uniform_sample(size, -math.pi, math.pi)
|
227 |
+
Gc = rotate_mat(-param, device=device)
|
228 |
+
G = random_mat_apply(p_rot, Gc, G, eye, device=device)
|
229 |
+
# print('pre-rotate', G, rotate_mat(-param), sep='\n')
|
230 |
+
|
231 |
+
# anisotropic scale
|
232 |
+
param = lognormal_sample(size, std=0.2 * math.log(2))
|
233 |
+
Gc = scale_mat(param, 1 / param, device=device)
|
234 |
+
G = random_mat_apply(p, Gc, G, eye, device=device)
|
235 |
+
# print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
|
236 |
+
|
237 |
+
# post-rotate
|
238 |
+
param = uniform_sample(size, -math.pi, math.pi)
|
239 |
+
Gc = rotate_mat(-param, device=device)
|
240 |
+
G = random_mat_apply(p_rot, Gc, G, eye, device=device)
|
241 |
+
# print('post-rotate', G, rotate_mat(-param), sep='\n')
|
242 |
+
|
243 |
+
# fractional translate
|
244 |
+
param = normal_sample((2, size), std=0.125)
|
245 |
+
Gc = translate_mat(param[1] * width, param[0] * height, device=device)
|
246 |
+
G = random_mat_apply(p, Gc, G, eye, device=device)
|
247 |
+
# print('fractional translate', G, translate_mat(param, param), sep='\n')
|
248 |
+
|
249 |
+
return G
|
250 |
+
|
251 |
+
|
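sample_affine composes flips, 90-degree rotations, translations, and scales as 3x3 homogeneous matrices, each applied with probability p. A short deterministic sketch of how such matrices compose, using the helpers defined above in this file:

# Hedged sketch: composing homogeneous 2-D transforms like sample_affine does.
import math
import torch

theta = torch.tensor([math.pi / 2])
t_x, t_y = torch.tensor([3.0]), torch.tensor([0.0])

R = rotate_mat(theta)                  # defined above
T = translate_mat(t_x, t_y)            # defined above

G = T @ R                              # rotate first, then translate
point = torch.tensor([1.0, 0.0, 1.0])  # homogeneous coordinates
print(G[0] @ point)                    # ~ (3, 1, 1): rotated onto y, shifted in x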
252 |
+
def sample_color(p, size):
|
253 |
+
C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
|
254 |
+
eye = C
|
255 |
+
axis_val = 1 / math.sqrt(3)
|
256 |
+
axis = (axis_val, axis_val, axis_val)
|
257 |
+
|
258 |
+
# brightness
|
259 |
+
param = normal_sample(size, std=0.2)
|
260 |
+
Cc = translate3d_mat(param, param, param)
|
261 |
+
C = random_mat_apply(p, Cc, C, eye)
|
262 |
+
|
263 |
+
# contrast
|
264 |
+
param = lognormal_sample(size, std=0.5 * math.log(2))
|
265 |
+
Cc = scale3d_mat(param, param, param)
|
266 |
+
C = random_mat_apply(p, Cc, C, eye)
|
267 |
+
|
268 |
+
# luma flip
|
269 |
+
param = category_sample(size, (0, 1))
|
270 |
+
Cc = luma_flip_mat(axis, param)
|
271 |
+
C = random_mat_apply(p, Cc, C, eye)
|
272 |
+
|
273 |
+
# hue rotation
|
274 |
+
param = uniform_sample(size, -math.pi, math.pi)
|
275 |
+
Cc = rotate3d_mat(axis, param)
|
276 |
+
C = random_mat_apply(p, Cc, C, eye)
|
277 |
+
|
278 |
+
# saturation
|
279 |
+
param = lognormal_sample(size, std=1 * math.log(2))
|
280 |
+
Cc = saturation_mat(axis, param)
|
281 |
+
C = random_mat_apply(p, Cc, C, eye)
|
282 |
+
|
283 |
+
return C
|
284 |
+
|
285 |
+
|
286 |
+
def make_grid(shape, x0, x1, y0, y1, device):
|
287 |
+
n, c, h, w = shape
|
288 |
+
grid = torch.empty(n, h, w, 3, device=device)
|
289 |
+
grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
|
290 |
+
grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
|
291 |
+
grid[:, :, :, 2] = 1
|
292 |
+
|
293 |
+
return grid
|
294 |
+
|
295 |
+
|
296 |
+
def affine_grid(grid, mat):
|
297 |
+
n, h, w, _ = grid.shape
|
298 |
+
return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
|
299 |
+
|
300 |
+
|
301 |
+
def get_padding(G, height, width, kernel_size):
|
302 |
+
device = G.device
|
303 |
+
|
304 |
+
cx = (width - 1) / 2
|
305 |
+
cy = (height - 1) / 2
|
306 |
+
cp = torch.tensor(
|
307 |
+
[(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
|
308 |
+
)
|
309 |
+
cp = G @ cp.T
|
310 |
+
|
311 |
+
pad_k = kernel_size // 4
|
312 |
+
|
313 |
+
pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
|
314 |
+
pad = torch.cat((-pad, pad)).max(1).values
|
315 |
+
pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
|
316 |
+
pad = pad.max(torch.tensor([0, 0] * 2, device=device))
|
317 |
+
pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
|
318 |
+
|
319 |
+
pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
|
320 |
+
|
321 |
+
return pad_x1, pad_x2, pad_y1, pad_y2
|
322 |
+
|
323 |
+
|
324 |
+
def try_sample_affine_and_pad(img, p, kernel_size, G=None):
|
325 |
+
batch, _, height, width = img.shape
|
326 |
+
|
327 |
+
G_try = G
|
328 |
+
|
329 |
+
if G is None:
|
330 |
+
G_try = torch.inverse(sample_affine(p, batch, height, width))
|
331 |
+
|
332 |
+
pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
|
333 |
+
|
334 |
+
img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
|
335 |
+
|
336 |
+
return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
|
337 |
+
|
338 |
+
|
339 |
+
class GridSampleForward(autograd.Function):
|
340 |
+
@staticmethod
|
341 |
+
def forward(ctx, input, grid):
|
342 |
+
out = F.grid_sample(
|
343 |
+
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
|
344 |
+
)
|
345 |
+
ctx.save_for_backward(input, grid)
|
346 |
+
|
347 |
+
return out
|
348 |
+
|
349 |
+
@staticmethod
|
350 |
+
def backward(ctx, grad_output):
|
351 |
+
input, grid = ctx.saved_tensors
|
352 |
+
grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
|
353 |
+
|
354 |
+
return grad_input, grad_grid
|
355 |
+
|
356 |
+
|
357 |
+
class GridSampleBackward(autograd.Function):
|
358 |
+
@staticmethod
|
359 |
+
def forward(ctx, grad_output, input, grid):
|
360 |
+
op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
|
361 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
362 |
+
ctx.save_for_backward(grid)
|
363 |
+
|
364 |
+
return grad_input, grad_grid
|
365 |
+
|
366 |
+
@staticmethod
|
367 |
+
def backward(ctx, grad_grad_input, grad_grad_grid):
|
368 |
+
(grid,) = ctx.saved_tensors
|
369 |
+
grad_grad_output = None
|
370 |
+
|
371 |
+
if ctx.needs_input_grad[0]:
|
372 |
+
grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
|
373 |
+
|
374 |
+
return grad_grad_output, None, None
|
375 |
+
|
376 |
+
|
377 |
+
grid_sample = GridSampleForward.apply
|
378 |
+
|
379 |
+
|
380 |
+
def scale_mat_single(s_x, s_y):
|
381 |
+
return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
|
382 |
+
|
383 |
+
|
384 |
+
def translate_mat_single(t_x, t_y):
|
385 |
+
return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
|
386 |
+
|
387 |
+
|
388 |
+
def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
|
389 |
+
kernel = antialiasing_kernel
|
390 |
+
len_k = len(kernel)
|
391 |
+
|
392 |
+
kernel = torch.as_tensor(kernel).to(img)
|
393 |
+
# kernel = torch.ger(kernel, kernel).to(img)
|
394 |
+
kernel_flip = torch.flip(kernel, (0,))
|
395 |
+
|
396 |
+
img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
|
397 |
+
img, p, len_k, G
|
398 |
+
)
|
399 |
+
|
400 |
+
G_inv = (
|
401 |
+
translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
|
402 |
+
@ G
|
403 |
+
)
|
404 |
+
up_pad = (
|
405 |
+
(len_k + 2 - 1) // 2,
|
406 |
+
(len_k - 2) // 2,
|
407 |
+
(len_k + 2 - 1) // 2,
|
408 |
+
(len_k - 2) // 2,
|
409 |
+
)
|
410 |
+
img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
|
411 |
+
img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
|
412 |
+
G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
|
413 |
+
G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
|
414 |
+
batch_size, channel, height, width = img.shape
|
415 |
+
pad_k = len_k // 4
|
416 |
+
shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
|
417 |
+
G_inv = (
|
418 |
+
scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
|
419 |
+
@ G_inv
|
420 |
+
@ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
|
421 |
+
)
|
422 |
+
grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
|
423 |
+
img_affine = grid_sample(img_2x, grid)
|
424 |
+
d_p = -pad_k * 2
|
425 |
+
down_pad = (
|
426 |
+
d_p + (len_k - 2 + 1) // 2,
|
427 |
+
d_p + (len_k - 2) // 2,
|
428 |
+
d_p + (len_k - 2 + 1) // 2,
|
429 |
+
d_p + (len_k - 2) // 2,
|
430 |
+
)
|
431 |
+
img_down = upfirdn2d(
|
432 |
+
img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
|
433 |
+
)
|
434 |
+
img_down = upfirdn2d(
|
435 |
+
img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
|
436 |
+
)
|
437 |
+
|
438 |
+
return img_down, G
|
439 |
+
|
440 |
+
|
441 |
+
def apply_color(img, mat):
|
442 |
+
batch = img.shape[0]
|
443 |
+
img = img.permute(0, 2, 3, 1)
|
444 |
+
mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
|
445 |
+
mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
|
446 |
+
img = img @ mat_mul + mat_add
|
447 |
+
img = img.permute(0, 3, 1, 2)
|
448 |
+
|
449 |
+
return img
|
450 |
+
|
451 |
+
|
452 |
+
def random_apply_color(img, p, C=None):
|
453 |
+
if C is None:
|
454 |
+
C = sample_color(p, img.shape[0])
|
455 |
+
|
456 |
+
img = apply_color(img, C.to(img))
|
457 |
+
|
458 |
+
return img, C
|
459 |
+
|
460 |
+
|
461 |
+
def augment(img, p, transform_matrix=(None, None)):
|
462 |
+
img, G = random_apply_affine(img, p, transform_matrix[0])
|
463 |
+
img, C = random_apply_color(img, p, transform_matrix[1])
|
464 |
+
|
465 |
+
return img, (G, C)
|
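augment() is the entry point used during ADA training: it applies the sampled geometric and color transforms with probability p and returns the matrices so the same transform could be replayed. A minimal usage sketch inside a discriminator step (batch shape and device are illustrative; the custom upfirdn2d op needs CUDA):

# Hedged sketch: how augment() from this file is typically called in training.
import torch
# from non_leaking import augment   # i.e. this file

p = 0.2                                        # current ADA probability
real_img = torch.randn(8, 3, 256, 256, device="cuda")
fake_img = torch.randn(8, 3, 256, 256, device="cuda")

real_aug, _ = augment(real_img, p)             # fresh transforms for reals
fake_aug, _ = augment(fake_img, p)             # independent transforms for fakes
# both augmented batches are then fed to the discriminator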
op/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
op/conv2d_gradfix.py
ADDED
@@ -0,0 +1,227 @@
1 |
+
import contextlib
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import autograd
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
enabled = True
|
9 |
+
weight_gradients_disabled = False
|
10 |
+
|
11 |
+
|
12 |
+
@contextlib.contextmanager
|
13 |
+
def no_weight_gradients():
|
14 |
+
global weight_gradients_disabled
|
15 |
+
|
16 |
+
old = weight_gradients_disabled
|
17 |
+
weight_gradients_disabled = True
|
18 |
+
yield
|
19 |
+
weight_gradients_disabled = old
|
20 |
+
|
21 |
+
|
22 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
23 |
+
if could_use_op(input):
|
24 |
+
return conv2d_gradfix(
|
25 |
+
transpose=False,
|
26 |
+
weight_shape=weight.shape,
|
27 |
+
stride=stride,
|
28 |
+
padding=padding,
|
29 |
+
output_padding=0,
|
30 |
+
dilation=dilation,
|
31 |
+
groups=groups,
|
32 |
+
).apply(input, weight, bias)
|
33 |
+
|
34 |
+
return F.conv2d(
|
35 |
+
input=input,
|
36 |
+
weight=weight,
|
37 |
+
bias=bias,
|
38 |
+
stride=stride,
|
39 |
+
padding=padding,
|
40 |
+
dilation=dilation,
|
41 |
+
groups=groups,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def conv_transpose2d(
|
46 |
+
input,
|
47 |
+
weight,
|
48 |
+
bias=None,
|
49 |
+
stride=1,
|
50 |
+
padding=0,
|
51 |
+
output_padding=0,
|
52 |
+
groups=1,
|
53 |
+
dilation=1,
|
54 |
+
):
|
55 |
+
if could_use_op(input):
|
56 |
+
return conv2d_gradfix(
|
57 |
+
transpose=True,
|
58 |
+
weight_shape=weight.shape,
|
59 |
+
stride=stride,
|
60 |
+
padding=padding,
|
61 |
+
output_padding=output_padding,
|
62 |
+
groups=groups,
|
63 |
+
dilation=dilation,
|
64 |
+
).apply(input, weight, bias)
|
65 |
+
|
66 |
+
return F.conv_transpose2d(
|
67 |
+
input=input,
|
68 |
+
weight=weight,
|
69 |
+
bias=bias,
|
70 |
+
stride=stride,
|
71 |
+
padding=padding,
|
72 |
+
output_padding=output_padding,
|
73 |
+
dilation=dilation,
|
74 |
+
groups=groups,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def could_use_op(input):
|
79 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
80 |
+
return False
|
81 |
+
|
82 |
+
if input.device.type != "cuda":
|
83 |
+
return False
|
84 |
+
|
85 |
+
if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
|
86 |
+
return True
|
87 |
+
|
88 |
+
warnings.warn(
|
89 |
+
f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
|
90 |
+
)
|
91 |
+
|
92 |
+
return False
|
93 |
+
|
94 |
+
|
95 |
+
def ensure_tuple(xs, ndim):
|
96 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
97 |
+
|
98 |
+
return xs
|
99 |
+
|
100 |
+
|
101 |
+
conv2d_gradfix_cache = dict()
|
102 |
+
|
103 |
+
|
104 |
+
def conv2d_gradfix(
|
105 |
+
transpose, weight_shape, stride, padding, output_padding, dilation, groups
|
106 |
+
):
|
107 |
+
ndim = 2
|
108 |
+
weight_shape = tuple(weight_shape)
|
109 |
+
stride = ensure_tuple(stride, ndim)
|
110 |
+
padding = ensure_tuple(padding, ndim)
|
111 |
+
output_padding = ensure_tuple(output_padding, ndim)
|
112 |
+
dilation = ensure_tuple(dilation, ndim)
|
113 |
+
|
114 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
115 |
+
if key in conv2d_gradfix_cache:
|
116 |
+
return conv2d_gradfix_cache[key]
|
117 |
+
|
118 |
+
common_kwargs = dict(
|
119 |
+
stride=stride, padding=padding, dilation=dilation, groups=groups
|
120 |
+
)
|
121 |
+
|
122 |
+
def calc_output_padding(input_shape, output_shape):
|
123 |
+
if transpose:
|
124 |
+
return [0, 0]
|
125 |
+
|
126 |
+
return [
|
127 |
+
input_shape[i + 2]
|
128 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
129 |
+
- (1 - 2 * padding[i])
|
130 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
131 |
+
for i in range(ndim)
|
132 |
+
]
|
133 |
+
|
134 |
+
class Conv2d(autograd.Function):
|
135 |
+
@staticmethod
|
136 |
+
def forward(ctx, input, weight, bias):
|
137 |
+
if not transpose:
|
138 |
+
out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
139 |
+
|
140 |
+
else:
|
141 |
+
out = F.conv_transpose2d(
|
142 |
+
input=input,
|
143 |
+
weight=weight,
|
144 |
+
bias=bias,
|
145 |
+
output_padding=output_padding,
|
146 |
+
**common_kwargs,
|
147 |
+
)
|
148 |
+
|
149 |
+
ctx.save_for_backward(input, weight)
|
150 |
+
|
151 |
+
return out
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def backward(ctx, grad_output):
|
155 |
+
input, weight = ctx.saved_tensors
|
156 |
+
grad_input, grad_weight, grad_bias = None, None, None
|
157 |
+
|
158 |
+
if ctx.needs_input_grad[0]:
|
159 |
+
p = calc_output_padding(
|
160 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
161 |
+
)
|
162 |
+
grad_input = conv2d_gradfix(
|
163 |
+
transpose=(not transpose),
|
164 |
+
weight_shape=weight_shape,
|
165 |
+
output_padding=p,
|
166 |
+
**common_kwargs,
|
167 |
+
).apply(grad_output, weight, None)
|
168 |
+
|
169 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
170 |
+
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
171 |
+
|
172 |
+
if ctx.needs_input_grad[2]:
|
173 |
+
grad_bias = grad_output.sum((0, 2, 3))
|
174 |
+
|
175 |
+
return grad_input, grad_weight, grad_bias
|
176 |
+
|
177 |
+
class Conv2dGradWeight(autograd.Function):
|
178 |
+
@staticmethod
|
179 |
+
def forward(ctx, grad_output, input):
|
180 |
+
op = torch._C._jit_get_operation(
|
181 |
+
"aten::cudnn_convolution_backward_weight"
|
182 |
+
if not transpose
|
183 |
+
else "aten::cudnn_convolution_transpose_backward_weight"
|
184 |
+
)
|
185 |
+
flags = [
|
186 |
+
torch.backends.cudnn.benchmark,
|
187 |
+
torch.backends.cudnn.deterministic,
|
188 |
+
torch.backends.cudnn.allow_tf32,
|
189 |
+
]
|
190 |
+
grad_weight = op(
|
191 |
+
weight_shape,
|
192 |
+
grad_output,
|
193 |
+
input,
|
194 |
+
padding,
|
195 |
+
stride,
|
196 |
+
dilation,
|
197 |
+
groups,
|
198 |
+
*flags,
|
199 |
+
)
|
200 |
+
ctx.save_for_backward(grad_output, input)
|
201 |
+
|
202 |
+
return grad_weight
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def backward(ctx, grad_grad_weight):
|
206 |
+
grad_output, input = ctx.saved_tensors
|
207 |
+
grad_grad_output, grad_grad_input = None, None
|
208 |
+
|
209 |
+
if ctx.needs_input_grad[0]:
|
210 |
+
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
211 |
+
|
212 |
+
if ctx.needs_input_grad[1]:
|
213 |
+
p = calc_output_padding(
|
214 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
215 |
+
)
|
216 |
+
grad_grad_input = conv2d_gradfix(
|
217 |
+
transpose=(not transpose),
|
218 |
+
weight_shape=weight_shape,
|
219 |
+
output_padding=p,
|
220 |
+
**common_kwargs,
|
221 |
+
).apply(grad_output, grad_grad_weight, None)
|
222 |
+
|
223 |
+
return grad_grad_output, grad_grad_input
|
224 |
+
|
225 |
+
conv2d_gradfix_cache[key] = Conv2d
|
226 |
+
|
227 |
+
return Conv2d
|
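conv2d_gradfix exists so that second derivatives of convolutions (needed for the R1 and path-length regularizers) work reliably, and no_weight_gradients() skips the expensive weight-gradient kernel while those penalties are computed. A minimal sketch of the intended call pattern (shapes and device are illustrative):

# Hedged sketch: double-backward through conv2d_gradfix, as used for an R1-style penalty.
import torch
from op import conv2d_gradfix

x = torch.randn(2, 3, 16, 16, device="cuda", requires_grad=True)
w = torch.randn(8, 3, 3, 3, device="cuda", requires_grad=True)

with conv2d_gradfix.no_weight_gradients():
    out = conv2d_gradfix.conv2d(x, w, padding=1)
    score = out.sum()
    (grad,) = torch.autograd.grad(score, x, create_graph=True)
    penalty = grad.pow(2).sum()    # still differentiable, weight grads skipped
    penalty.backward()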
op/fused_act.py
ADDED
@@ -0,0 +1,127 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.autograd import Function
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
|
10 |
+
module_path = os.path.dirname(__file__)
|
11 |
+
fused = load(
|
12 |
+
"fused",
|
13 |
+
sources=[
|
14 |
+
os.path.join(module_path, "fused_bias_act.cpp"),
|
15 |
+
os.path.join(module_path, "fused_bias_act_kernel.cu"),
|
16 |
+
],
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class FusedLeakyReLUFunctionBackward(Function):
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, grad_output, out, bias, negative_slope, scale):
|
23 |
+
ctx.save_for_backward(out)
|
24 |
+
ctx.negative_slope = negative_slope
|
25 |
+
ctx.scale = scale
|
26 |
+
|
27 |
+
empty = grad_output.new_empty(0)
|
28 |
+
|
29 |
+
grad_input = fused.fused_bias_act(
|
30 |
+
grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
|
31 |
+
)
|
32 |
+
|
33 |
+
dim = [0]
|
34 |
+
|
35 |
+
if grad_input.ndim > 2:
|
36 |
+
dim += list(range(2, grad_input.ndim))
|
37 |
+
|
38 |
+
if bias:
|
39 |
+
grad_bias = grad_input.sum(dim).detach()
|
40 |
+
|
41 |
+
else:
|
42 |
+
grad_bias = empty
|
43 |
+
|
44 |
+
return grad_input, grad_bias
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def backward(ctx, gradgrad_input, gradgrad_bias):
|
48 |
+
out, = ctx.saved_tensors
|
49 |
+
gradgrad_out = fused.fused_bias_act(
|
50 |
+
gradgrad_input.contiguous(),
|
51 |
+
gradgrad_bias,
|
52 |
+
out,
|
53 |
+
3,
|
54 |
+
1,
|
55 |
+
ctx.negative_slope,
|
56 |
+
ctx.scale,
|
57 |
+
)
|
58 |
+
|
59 |
+
return gradgrad_out, None, None, None, None
|
60 |
+
|
61 |
+
|
62 |
+
class FusedLeakyReLUFunction(Function):
|
63 |
+
@staticmethod
|
64 |
+
def forward(ctx, input, bias, negative_slope, scale):
|
65 |
+
empty = input.new_empty(0)
|
66 |
+
|
67 |
+
ctx.bias = bias is not None
|
68 |
+
|
69 |
+
if bias is None:
|
70 |
+
bias = empty
|
71 |
+
|
72 |
+
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
73 |
+
ctx.save_for_backward(out)
|
74 |
+
ctx.negative_slope = negative_slope
|
75 |
+
ctx.scale = scale
|
76 |
+
|
77 |
+
return out
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def backward(ctx, grad_output):
|
81 |
+
out, = ctx.saved_tensors
|
82 |
+
|
83 |
+
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
84 |
+
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
|
85 |
+
)
|
86 |
+
|
87 |
+
if not ctx.bias:
|
88 |
+
grad_bias = None
|
89 |
+
|
90 |
+
return grad_input, grad_bias, None, None
|
91 |
+
|
92 |
+
|
93 |
+
class FusedLeakyReLU(nn.Module):
|
94 |
+
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
if bias:
|
98 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
99 |
+
|
100 |
+
else:
|
101 |
+
self.bias = None
|
102 |
+
|
103 |
+
self.negative_slope = negative_slope
|
104 |
+
self.scale = scale
|
105 |
+
|
106 |
+
def forward(self, input):
|
107 |
+
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
108 |
+
|
109 |
+
|
110 |
+
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
|
111 |
+
if input.device.type == "cpu":
|
112 |
+
if bias is not None:
|
113 |
+
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
114 |
+
return (
|
115 |
+
F.leaky_relu(
|
116 |
+
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
|
117 |
+
)
|
118 |
+
* scale
|
119 |
+
)
|
120 |
+
|
121 |
+
else:
|
122 |
+
return F.leaky_relu(input, negative_slope=0.2) * scale
|
123 |
+
|
124 |
+
else:
|
125 |
+
return FusedLeakyReLUFunction.apply(
|
126 |
+
input.contiguous(), bias, negative_slope, scale
|
127 |
+
)
|
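On CPU the function above falls back to stock PyTorch ops, which also documents what the CUDA kernel computes: add the bias, apply leaky ReLU, multiply by sqrt(2). A quick reference sketch of that computation on CPU:

# Hedged sketch: what fused_leaky_relu computes, written with stock PyTorch ops.
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 16, 16)     # CPU tensor -> the fallback branch above
bias = torch.zeros(8)

ref = F.leaky_relu(x + bias.view(1, 8, 1, 1), negative_slope=0.2) * (2 ** 0.5)
# fused_leaky_relu(x, bias) on CPU returns exactly this tensor.
print(ref.shape)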
op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor &input,
                             const torch::Tensor &bias,
                             const torch::Tensor &refer, int act, int grad,
                             float alpha, float scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(bias);

  at::DeviceGuard guard(input.device());

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
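This binding is compiled on the fly with PyTorch's JIT extension loader, the same pattern upfirdn2d.py uses further down in this diff; nothing is built ahead of time. A sketch of that loading step, assuming it runs from inside the op/ directory (the extension name "fused" matches how the calls above refer to it):

# Sketch: JIT-compile the fused_bias_act extension from its C++/CUDA sources.
import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)
# exposes fused.fused_bias_act(input, bias, refer, act, grad, alpha, scale)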
op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,105 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,31 @@
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
op/upfirdn2d.py
ADDED
@@ -0,0 +1,209 @@
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        out = upfirdn2d_native(input, kernel, *up, *down, *pad)

    else:
        out = UpFirDn2d.apply(input, kernel, up, down, pad)

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
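As a usage sketch, StyleGAN2's blur layers call upfirdn2d with the normalized outer product of the [1, 3, 3, 1] kernel; with a 4-tap kernel, pad=(2, 1) keeps the spatial size unchanged. Run on a CPU tensor this exercises the upfirdn2d_native path above (shapes are illustrative):

# Sketch: 4x4 binomial blur through upfirdn2d at constant resolution.
import torch
from op.upfirdn2d import upfirdn2d

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 32, 32)                        # CPU tensor -> native fallback
out = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))
print(out.shape)                                     # torch.Size([1, 3, 32, 32])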
op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,369 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
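Once this kernel builds, a quick consistency check against the pure-PyTorch upfirdn2d_native fallback catches most integration problems. A sketch, assuming a CUDA device is available and the repo's op package is importable:

# Sketch: compare the CUDA upfirdn2d path with the CPU reference on a 2x upsample.
import torch
from op.upfirdn2d import upfirdn2d

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(2, 3, 17, 23)                         # odd sizes stress the padding math
cpu_out = upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1))
gpu_out = upfirdn2d(x.cuda(), kernel.cuda(), up=2, down=1, pad=(2, 1)).cpu()
print(torch.allclose(cpu_out, gpu_out, atol=1e-4))    # expected: True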
ppl.py
ADDED
@@ -0,0 +1,130 @@
import argparse

import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm

import lpips
from model import Generator


def normalize(x):
    return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))


def slerp(a, b, t):
    a = normalize(a)
    b = normalize(b)
    d = (a * b).sum(-1, keepdim=True)
    p = t * torch.acos(d)
    c = normalize(b - d * a)
    d = a * torch.cos(p) + c * torch.sin(p)

    return normalize(d)


def lerp(a, b, t):
    return a + (b - a) * t


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")

    parser.add_argument(
        "--space", choices=["z", "w"], help="space that PPL calculated with"
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the models"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=5000,
        help="number of the samples for calculating PPL",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
    )
    parser.add_argument(
        "--crop", action="store_true", help="apply center crop to the images"
    )
    parser.add_argument(
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
    )

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    batch_sizes = [args.batch] * n_batch + [resid]

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            inputs = torch.randn([batch * 2, latent_dim], device=device)
            if args.sampling == "full":
                lerp_t = torch.rand(batch, device=device)
            else:
                lerp_t = torch.zeros(batch, device=device)

            if args.space == "w":
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)

            image, _ = g([latent_e], input_is_latent=True, noise=noise)

            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]

            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode="bilinear", align_corners=False
                )

            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to("cpu").numpy())

    distances = np.concatenate(distances, 0)

    lo = np.percentile(distances, 1, interpolation="lower")
    hi = np.percentile(distances, 99, interpolation="higher")
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print("ppl:", filtered_dist.mean())
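A typical invocation, with the checkpoint path as a placeholder, measures PPL in W space with full-path sampling:

python ppl.py --space w --sampling full --size 256 --batch 16 path/to/checkpoint.pt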
prepare_data.py
ADDED
@@ -0,0 +1,101 @@
import argparse
from io import BytesIO
import multiprocessing
from functools import partial

from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn


def resize_and_convert(img, size, resample, quality=100):
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="jpeg", quality=quality)
    val = buffer.getvalue()

    return val


def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample, quality))

    return imgs


def resize_worker(img_file, sizes, resample):
    i, file = img_file
    img = Image.open(file)
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, out


def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")

                with env.begin(write=True) as txn:
                    txn.put(key, img)

            total += 1

        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(total).encode("utf-8"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
    parser.add_argument(
        "--size",
        type=str,
        default="128,256,512,1024",
        help="resolutions of images for the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing dataset",
    )
    parser.add_argument(
        "--resample",
        type=str,
        default="lanczos",
        help="resampling methods for resizing images",
    )
    parser.add_argument("path", type=str, help="path to the image dataset")

    args = parser.parse_args()

    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(",")]

    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
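Typical invocation (paths are placeholders); note that torchvision's ImageFolder expects the images to sit inside at least one subdirectory of the given path:

python prepare_data.py --out data.lmdb --size 128,256,512,1024 --n_worker 8 path/to/images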
projector.py
ADDED
@@ -0,0 +1,248 @@
import argparse
import math
import os

import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

import lpips
from model import Generator


def noise_regularize(noises):
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            if size <= 8:
                break

            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2

    return loss


def noise_normalize_(noises):
    for noise in noises:
        mean = noise.mean()
        std = noise.std()

        noise.data.add_(-mean).div_(std)


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp


def latent_noise(latent, strength):
    noise = torch.randn_like(latent) * strength

    return latent + noise


def make_image(tensor):
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Image projector to the generator latent spaces"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--lr_rampup",
        type=float,
        default=0.05,
        help="duration of the learning rate warmup",
    )
    parser.add_argument(
        "--lr_rampdown",
        type=float,
        default=0.25,
        help="duration of the learning rate decay",
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument(
        "--noise", type=float, default=0.05, help="strength of the noise level"
    )
    parser.add_argument(
        "--noise_ramp",
        type=float,
        default=0.75,
        help="duration of the noise level decay",
    )
    parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
    parser.add_argument(
        "--noise_regularize",
        type=float,
        default=1e5,
        help="weight of the noise regularization",
    )
    parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
    parser.add_argument(
        "--w_plus",
        action="store_true",
        help="allow to use distinct latent codes to each layers",
    )
    parser.add_argument(
        "files", metavar="FILES", nargs="+", help="path to image files to be projected"
    )

    args = parser.parse_args()

    n_mean_latent = 10000

    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i : i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)
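Typical invocation (checkpoint and image paths are placeholders); adding --w_plus optimizes a separate latent code per generator layer instead of a single 512-d vector:

python projector.py --ckpt path/to/checkpoint.pt --size 256 --step 1000 --w_plus path/to/image1.png path/to/image2.png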
sample/.gitignore
ADDED
@@ -0,0 +1 @@
*.png
swagan.py
ADDED
@@ -0,0 +1,440 @@
import math
import random
import functools
import operator

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
from model import (
    ModulatedConv2d,
    StyledConv,
    ConstantInput,
    PixelNorm,
    Upsample,
    Downsample,
    Blur,
    EqualLinear,
    ConvLayer,
)


def get_haar_wavelet(in_channels):
    haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
    haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2)
    haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0]

    haar_wav_ll = haar_wav_l.T * haar_wav_l
    haar_wav_lh = haar_wav_h.T * haar_wav_l
    haar_wav_hl = haar_wav_l.T * haar_wav_h
    haar_wav_hh = haar_wav_h.T * haar_wav_h

    return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh


def dwt_init(x):
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


def iwt_init(x):
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # print([in_batch, in_channel, in_height, in_width])
    out_batch, out_channel, out_height, out_width = (
        in_batch,
        int(in_channel / (r ** 2)),
        r * in_height,
        r * in_width,
    )
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel : out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2

    h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h


class HaarTransform(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        ll, lh, hl, hh = get_haar_wavelet(in_channels)

        self.register_buffer("ll", ll)
        self.register_buffer("lh", lh)
        self.register_buffer("hl", hl)
        self.register_buffer("hh", hh)

    def forward(self, input):
        ll = upfirdn2d(input, self.ll, down=2)
        lh = upfirdn2d(input, self.lh, down=2)
        hl = upfirdn2d(input, self.hl, down=2)
        hh = upfirdn2d(input, self.hh, down=2)

        return torch.cat((ll, lh, hl, hh), 1)


class InverseHaarTransform(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        ll, lh, hl, hh = get_haar_wavelet(in_channels)

        self.register_buffer("ll", ll)
        self.register_buffer("lh", -lh)
        self.register_buffer("hl", -hl)
        self.register_buffer("hh", hh)

    def forward(self, input):
        ll, lh, hl, hh = input.chunk(4, 1)
        ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
        lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
        hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
        hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))

        return ll + lh + hl + hh


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.iwt = InverseHaarTransform(3)
            self.upsample = Upsample(blur_kernel)
            self.dwt = HaarTransform(3)

        self.conv = ModulatedConv2d(in_channel, 3 * 4, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3 * 4, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.iwt(skip)
            skip = self.upsample(skip)
            skip = self.dwt(skip)

            out = out + skip

        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2)) - 1
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.iwt = InverseHaarTransform(3)

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = self.iwt(skip)

        if return_latents:
            return image, latent

        else:
            return image, None


class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        return out


class FromRGB(nn.Module):
    def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.downsample = downsample

        if downsample:
            self.iwt = InverseHaarTransform(3)
            self.downsample = Downsample(blur_kernel)
            self.dwt = HaarTransform(3)

        self.conv = ConvLayer(3 * 4, out_channel, 3)

    def forward(self, input, skip=None):
        if self.downsample:
            input = self.iwt(input)
            input = self.downsample(input)
            input = self.dwt(input)

        out = self.conv(input)

        if skip is not None:
            out = out + skip

        return input, out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.dwt = HaarTransform(3)

        self.from_rgbs = nn.ModuleList()
        self.convs = nn.ModuleList()

        log_size = int(math.log(size, 2)) - 1

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size))
            self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.from_rgbs.append(FromRGB(channels[4]))

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        input = self.dwt(input)
        out = None

        for from_rgb, conv in zip(self.from_rgbs, self.convs):
            input, out = from_rgb(input, out)
            out = conv(out)

        _, out = self.from_rgbs[-1](input, out)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
431 |
+
stddev = stddev.repeat(group, 1, height, width)
|
432 |
+
out = torch.cat([out, stddev], 1)
|
433 |
+
|
434 |
+
out = self.final_conv(out)
|
435 |
+
|
436 |
+
out = out.view(batch, -1)
|
437 |
+
out = self.final_linear(out)
|
438 |
+
|
439 |
+
return out
|
440 |
+
|
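For orientation (this is not part of swagan.py itself), here is a minimal sketch of how the classes above might be exercised once the module and its CUDA ops are built; the 256-pixel size, 512-dimensional latent, and 8-layer mapping network simply mirror the defaults used by train.py below and are illustrative assumptions:

    import torch
    from swagan import Generator, Discriminator

    g = Generator(256, 512, 8, channel_multiplier=2)
    d = Discriminator(256, channel_multiplier=2)

    z = torch.randn(4, 512)
    img, _ = g([z])    # forward() takes a list of style codes and returns (image, latents or None)
    score = d(img)     # img: (4, 3, 256, 256), roughly in [-1, 1]; score: (4, 1)
    print(img.shape, score.shape)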
train.py
ADDED
@@ -0,0 +1,531 @@
import argparse
import math
import random
import os

import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm

try:
    import wandb

except ImportError:
    wandb = None


from dataset import MultiResolutionDataset
from distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)
from op import conv2d_gradfix
from non_leaking import augment, AdaptiveAugment


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

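# Exponential moving average (EMA) of the generator weights:
#   g_ema <- decay * g_ema + (1 - decay) * g
# The training loop calls this with accum = 0.5 ** (32 / (10 * 1000)), i.e. a
# half-life of roughly 10k images at a batch size of 32.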
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)


def sample_data(loader):
    while True:
        for batch in loader:
            yield batch

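# Non-saturating logistic GAN losses in softplus form:
#   L_D = E[softplus(-D(real))] + E[softplus(D(fake))]
#   L_G = E[softplus(-D(fake))]        (g_nonsaturating_loss below)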
def d_logistic_loss(real_pred, fake_pred):
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)

    return real_loss.mean() + fake_loss.mean()

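# R1 gradient penalty on real images: E[ || grad_x D(x) ||^2 ], weighted by
# args.r1 / 2 in the training loop.  no_weight_gradients() skips the
# convolution weight gradients during this double-backward pass to save memory.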
def d_r1_loss(real_pred, real_img):
    with conv2d_gradfix.no_weight_gradients():
        grad_real, = autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    return grad_penalty


def g_nonsaturating_loss(fake_pred):
    loss = F.softplus(-fake_pred).mean()

    return loss

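# Path length regularization (StyleGAN2): penalizes deviations of
# || J_w^T y ||, estimated with a random image-space direction y, from its
# running mean, encouraging a smooth, roughly isometric W -> image mapping.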
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    grad, = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths


def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)

    return noises

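# Style mixing: with probability `prob`, return two latent codes so the
# generator crosses over between them at a random injection index.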
def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)

    else:
        return [make_noise(batch, latent_dim, 1, device)]


def set_grad_none(model, targets):
    for n, p in model.named_parameters():
        if n in targets:
            p.grad = None

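# Training loop: alternates a discriminator step, lazy R1 regularization every
# args.d_reg_every iterations, a generator step, and lazy path length
# regularization every args.g_reg_every iterations, then updates the EMA copy
# g_ema and handles logging, sampling, and checkpointing.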
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

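        # Lazy R1 regularization: only applied every args.d_reg_every steps,
        # so the penalty is scaled by args.d_reg_every to keep its effective
        # weight.  The `0 * real_pred[0]` term merely keeps the discriminator
        # output in the graph so every parameter receives a gradient.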
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

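        # Lazy path length regularization, run on a smaller batch
        # (args.path_batch_shrink) to save memory and scaled by args.g_reg_every
        # for the same reason as the lazy R1 penalty above.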
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )

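# Command-line entry point: parses arguments, builds the generator/discriminator
# pair (StyleGAN2 or SWAGAN), their optimizers and the LMDB data loader,
# optionally resumes from a checkpoint, and wraps the models in
# DistributedDataParallel when launched with more than one process.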
if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="StyleGAN2 trainer")

    parser.add_argument("path", type=str, help="path to the lmdb dataset")
    parser.add_argument('--arch', type=str, default='stylegan2', help='model architecture (stylegan2 | swagan)')
    parser.add_argument(
        "--iter", type=int, default=800000, help="total training iterations"
    )
    parser.add_argument(
        "--batch", type=int, default=16, help="batch size for each GPU"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=64,
        help="number of samples generated during training",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="image size for the model"
    )
    parser.add_argument(
        "--r1", type=float, default=10, help="weight of the r1 regularization"
    )
    parser.add_argument(
        "--path_regularize",
        type=float,
        default=2,
        help="weight of the path length regularization",
    )
    parser.add_argument(
        "--path_batch_shrink",
        type=int,
        default=2,
        help="batch size reducing factor for the path length regularization (reduces memory consumption)",
    )
    parser.add_argument(
        "--d_reg_every",
        type=int,
        default=16,
        help="interval for applying r1 regularization",
    )
    parser.add_argument(
        "--g_reg_every",
        type=int,
        default=4,
        help="interval for applying path length regularization",
    )
    parser.add_argument(
        "--mixing", type=float, default=0.9, help="probability of latent code mixing"
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default=None,
        help="path to the checkpoint to resume training from",
    )
    parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor for the model. config-f = 2, else = 1",
    )
    parser.add_argument(
        "--wandb", action="store_true", help="use weights and biases logging"
    )
    parser.add_argument(
        "--local_rank", type=int, default=0, help="local rank for distributed training"
    )
    parser.add_argument(
        "--augment", action="store_true", help="apply non leaking augmentation"
    )
    parser.add_argument(
        "--augment_p",
        type=float,
        default=0,
        help="probability of applying augmentation. 0 = use adaptive augmentation",
    )
    parser.add_argument(
        "--ada_target",
        type=float,
        default=0.6,
        help="target augmentation probability for adaptive augmentation",
    )
    parser.add_argument(
        "--ada_length",
        type=int,
        default=500 * 1000,
        help="target duration to reach augmentation probability for adaptive augmentation",
    )
    parser.add_argument(
        "--ada_every",
        type=int,
        default=256,
        help="probability update interval of the adaptive augmentation",
    )

    args = parser.parse_args()

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    if args.arch == 'stylegan2':
        from model import Generator, Discriminator

    elif args.arch == 'swagan':
        from swagan import Generator, Discriminator

    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

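    # Lazy regularization changes the effective optimizer hyper-parameters, so
    # the learning rate and Adam betas are rescaled by reg_every / (reg_every + 1).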
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        try:
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])

        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
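Together with prepare_data.py, which builds the LMDB dataset this script consumes, a plausible invocation looks like the following; the dataset path, image size, and GPU count are placeholders rather than values taken from this repository:

    # single process
    python train.py --arch swagan --size 256 --batch 16 /path/to/dataset_lmdb

    # multiple GPUs via the PyTorch launcher, which sets WORLD_SIZE and passes --local_rank
    python -m torch.distributed.launch --nproc_per_node=4 train.py --arch swagan --size 256 --batch 16 /path/to/dataset_lmdb

Samples are written to sample/ every 100 iterations and checkpoints to checkpoint/ every 10,000 iterations, as the training loop above shows.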