commit
- LICENSE +0 -21
- LICENSE-FID +0 -201
- LICENSE-LPIPS +0 -24
- LICENSE-NVIDIA +0 -101
- apply_factor.py +0 -94
- calc_inception.py +0 -130
- checkpoint/.gitignore +0 -1
- closed_form_factorization.py +0 -33
- convert_weight.py +0 -301
- dataset.py +0 -40
- distributed.py +0 -126
- download.py +0 -47
- fid.py +0 -129
- generate.py +0 -84
- inception.py +0 -310
- lpips/__init__.py +0 -160
- lpips/base_model.py +0 -58
- lpips/dist_model.py +0 -284
- lpips/networks_basic.py +0 -187
- lpips/pretrained_networks.py +0 -181
- lpips/weights/v0.0/alex.pth +0 -3
- lpips/weights/v0.0/squeeze.pth +0 -3
- lpips/weights/v0.0/vgg.pth +0 -3
- lpips/weights/v0.1/alex.pth +0 -3
- lpips/weights/v0.1/squeeze.pth +0 -3
- lpips/weights/v0.1/vgg.pth +0 -3
- model.py +0 -698
- non_leaking.py +0 -465
- op/__init__.py +0 -2
- op/conv2d_gradfix.py +0 -227
- op/fused_act.py +0 -127
- op/fused_bias_act.cpp +0 -32
- op/fused_bias_act_kernel.cu +0 -105
- op/upfirdn2d.cpp +0 -31
- op/upfirdn2d.py +0 -209
- op/upfirdn2d_kernel.cu +0 -369
- ppl.py +0 -130
- prepare_data.py +0 -101
- projector.py +0 -248
- sample/.gitignore +0 -1
- swagan.py +0 -440
- train.py +0 -531
LICENSE
DELETED
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2019 Kim Seonghyeon
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
LICENSE-FID
DELETED
@@ -1,201 +0,0 @@
-Apache License
-Version 2.0, January 2004
-http://www.apache.org/licenses/
-
-TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-1. Definitions.
-
-"License" shall mean the terms and conditions for use, reproduction,
-and distribution as defined by Sections 1 through 9 of this document.
-
-"Licensor" shall mean the copyright owner or entity authorized by
-the copyright owner that is granting the License.
-
-"Legal Entity" shall mean the union of the acting entity and all
-other entities that control, are controlled by, or are under common
-control with that entity. For the purposes of this definition,
-"control" means (i) the power, direct or indirect, to cause the
-direction or management of such entity, whether by contract or
-otherwise, or (ii) ownership of fifty percent (50%) or more of the
-outstanding shares, or (iii) beneficial ownership of such entity.
-
-"You" (or "Your") shall mean an individual or Legal Entity
-exercising permissions granted by this License.
-
-"Source" form shall mean the preferred form for making modifications,
-including but not limited to software source code, documentation
-source, and configuration files.
-
-"Object" form shall mean any form resulting from mechanical
-transformation or translation of a Source form, including but
-not limited to compiled object code, generated documentation,
-and conversions to other media types.
-
-"Work" shall mean the work of authorship, whether in Source or
-Object form, made available under the License, as indicated by a
-copyright notice that is included in or attached to the work
-(an example is provided in the Appendix below).
-
-"Derivative Works" shall mean any work, whether in Source or Object
-form, that is based on (or derived from) the Work and for which the
-editorial revisions, annotations, elaborations, or other modifications
-represent, as a whole, an original work of authorship. For the purposes
-of this License, Derivative Works shall not include works that remain
-separable from, or merely link (or bind by name) to the interfaces of,
-the Work and Derivative Works thereof.
-
-"Contribution" shall mean any work of authorship, including
-the original version of the Work and any modifications or additions
-to that Work or Derivative Works thereof, that is intentionally
-submitted to Licensor for inclusion in the Work by the copyright owner
-or by an individual or Legal Entity authorized to submit on behalf of
-the copyright owner. For the purposes of this definition, "submitted"
-means any form of electronic, verbal, or written communication sent
-to the Licensor or its representatives, including but not limited to
-communication on electronic mailing lists, source code control systems,
-and issue tracking systems that are managed by, or on behalf of, the
-Licensor for the purpose of discussing and improving the Work, but
-excluding communication that is conspicuously marked or otherwise
-designated in writing by the copyright owner as "Not a Contribution."
-
-"Contributor" shall mean Licensor and any individual or Legal Entity
-on behalf of whom a Contribution has been received by Licensor and
-subsequently incorporated within the Work.
-
-2. Grant of Copyright License. Subject to the terms and conditions of
-this License, each Contributor hereby grants to You a perpetual,
-worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-copyright license to reproduce, prepare Derivative Works of,
-publicly display, publicly perform, sublicense, and distribute the
-Work and such Derivative Works in Source or Object form.
-
-3. Grant of Patent License. Subject to the terms and conditions of
-this License, each Contributor hereby grants to You a perpetual,
-worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-(except as stated in this section) patent license to make, have made,
-use, offer to sell, sell, import, and otherwise transfer the Work,
-where such license applies only to those patent claims licensable
-by such Contributor that are necessarily infringed by their
-Contribution(s) alone or by combination of their Contribution(s)
-with the Work to which such Contribution(s) was submitted. If You
-institute patent litigation against any entity (including a
-cross-claim or counterclaim in a lawsuit) alleging that the Work
-or a Contribution incorporated within the Work constitutes direct
-or contributory patent infringement, then any patent licenses
-granted to You under this License for that Work shall terminate
-as of the date such litigation is filed.
-
-4. Redistribution. You may reproduce and distribute copies of the
-Work or Derivative Works thereof in any medium, with or without
-modifications, and in Source or Object form, provided that You
-meet the following conditions:
-
-(a) You must give any other recipients of the Work or
-Derivative Works a copy of this License; and
-
-(b) You must cause any modified files to carry prominent notices
-stating that You changed the files; and
-
-(c) You must retain, in the Source form of any Derivative Works
-that You distribute, all copyright, patent, trademark, and
-attribution notices from the Source form of the Work,
-excluding those notices that do not pertain to any part of
-the Derivative Works; and
-
-(d) If the Work includes a "NOTICE" text file as part of its
-distribution, then any Derivative Works that You distribute must
-include a readable copy of the attribution notices contained
-within such NOTICE file, excluding those notices that do not
-pertain to any part of the Derivative Works, in at least one
-of the following places: within a NOTICE text file distributed
-as part of the Derivative Works; within the Source form or
-documentation, if provided along with the Derivative Works; or,
-within a display generated by the Derivative Works, if and
-wherever such third-party notices normally appear. The contents
-of the NOTICE file are for informational purposes only and
-do not modify the License. You may add Your own attribution
-notices within Derivative Works that You distribute, alongside
-or as an addendum to the NOTICE text from the Work, provided
-that such additional attribution notices cannot be construed
-as modifying the License.
-
-You may add Your own copyright statement to Your modifications and
-may provide additional or different license terms and conditions
-for use, reproduction, or distribution of Your modifications, or
-for any such Derivative Works as a whole, provided Your use,
-reproduction, and distribution of the Work otherwise complies with
-the conditions stated in this License.
-
-5. Submission of Contributions. Unless You explicitly state otherwise,
-any Contribution intentionally submitted for inclusion in the Work
-by You to the Licensor shall be under the terms and conditions of
-this License, without any additional terms or conditions.
-Notwithstanding the above, nothing herein shall supersede or modify
-the terms of any separate license agreement you may have executed
-with Licensor regarding such Contributions.
-
-6. Trademarks. This License does not grant permission to use the trade
-names, trademarks, service marks, or product names of the Licensor,
-except as required for reasonable and customary use in describing the
-origin of the Work and reproducing the content of the NOTICE file.
-
-7. Disclaimer of Warranty. Unless required by applicable law or
-agreed to in writing, Licensor provides the Work (and each
-Contributor provides its Contributions) on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-implied, including, without limitation, any warranties or conditions
-of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-PARTICULAR PURPOSE. You are solely responsible for determining the
-appropriateness of using or redistributing the Work and assume any
-risks associated with Your exercise of permissions under this License.
-
-8. Limitation of Liability. In no event and under no legal theory,
-whether in tort (including negligence), contract, or otherwise,
-unless required by applicable law (such as deliberate and grossly
-negligent acts) or agreed to in writing, shall any Contributor be
-liable to You for damages, including any direct, indirect, special,
-incidental, or consequential damages of any character arising as a
-result of this License or out of the use or inability to use the
-Work (including but not limited to damages for loss of goodwill,
-work stoppage, computer failure or malfunction, or any and all
-other commercial damages or losses), even if such Contributor
-has been advised of the possibility of such damages.
-
-9. Accepting Warranty or Additional Liability. While redistributing
-the Work or Derivative Works thereof, You may choose to offer,
-and charge a fee for, acceptance of support, warranty, indemnity,
-or other liability obligations and/or rights consistent with this
-License. However, in accepting such obligations, You may act only
-on Your own behalf and on Your sole responsibility, not on behalf
-of any other Contributor, and only if You agree to indemnify,
-defend, and hold each Contributor harmless for any liability
-incurred by, or claims asserted against, such Contributor by reason
-of your accepting any such warranty or additional liability.
-
-END OF TERMS AND CONDITIONS
-
-APPENDIX: How to apply the Apache License to your work.
-
-To apply the Apache License to your work, attach the following
-boilerplate notice, with the fields enclosed by brackets "[]"
-replaced with your own identifying information. (Don't include
-the brackets!) The text should be enclosed in the appropriate
-comment syntax for the file format. We also recommend that a
-file or class name and description of purpose be included on the
-same "printed page" as the copyright notice for easier
-identification within third-party archives.
-
-Copyright [yyyy] [name of copyright owner]
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
LICENSE-LPIPS
DELETED
@@ -1,24 +0,0 @@
-Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-* Redistributions of source code must retain the above copyright notice, this
-  list of conditions and the following disclaimer.
-
-* Redistributions in binary form must reproduce the above copyright notice,
-  this list of conditions and the following disclaimer in the documentation
-  and/or other materials provided with the distribution.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
LICENSE-NVIDIA
DELETED
@@ -1,101 +0,0 @@
-Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
-
-
-Nvidia Source Code License-NC
-
-=======================================================================
-
-1. Definitions
-
-"Licensor" means any person or entity that distributes its Work.
-
-"Software" means the original work of authorship made available under
-this License.
-
-"Work" means the Software and any additions to or derivative works of
-the Software that are made available under this License.
-
-"Nvidia Processors" means any central processing unit (CPU), graphics
-processing unit (GPU), field-programmable gate array (FPGA),
-application-specific integrated circuit (ASIC) or any combination
-thereof designed, made, sold, or provided by Nvidia or its affiliates.
-
-The terms "reproduce," "reproduction," "derivative works," and
-"distribution" have the meaning as provided under U.S. copyright law;
-provided, however, that for the purposes of this License, derivative
-works shall not include works that remain separable from, or merely
-link (or bind by name) to the interfaces of, the Work.
-
-Works, including the Software, are "made available" under this License
-by including in or with the Work either (a) a copyright notice
-referencing the applicability of this License to the Work, or (b) a
-copy of this License.
-
-2. License Grants
-
-2.1 Copyright Grant. Subject to the terms and conditions of this
-License, each Licensor grants to you a perpetual, worldwide,
-non-exclusive, royalty-free, copyright license to reproduce,
-prepare derivative works of, publicly display, publicly perform,
-sublicense and distribute its Work and any resulting derivative
-works in any form.
-
-3. Limitations
-
-3.1 Redistribution. You may reproduce or distribute the Work only
-if (a) you do so under this License, (b) you include a complete
-copy of this License with your distribution, and (c) you retain
-without modification any copyright, patent, trademark, or
-attribution notices that are present in the Work.
-
-3.2 Derivative Works. You may specify that additional or different
-terms apply to the use, reproduction, and distribution of your
-derivative works of the Work ("Your Terms") only if (a) Your Terms
-provide that the use limitation in Section 3.3 applies to your
-derivative works, and (b) you identify the specific derivative
-works that are subject to Your Terms. Notwithstanding Your Terms,
-this License (including the redistribution requirements in Section
-3.1) will continue to apply to the Work itself.
-
-3.3 Use Limitation. The Work and any derivative works thereof only
-may be used or intended for use non-commercially. The Work or
-derivative works thereof may be used or intended for use by Nvidia
-or its affiliates commercially or non-commercially. As used herein,
-"non-commercially" means for research or evaluation purposes only.
-
-3.4 Patent Claims. If you bring or threaten to bring a patent claim
-against any Licensor (including any claim, cross-claim or
-counterclaim in a lawsuit) to enforce any patents that you allege
-are infringed by any Work, then your rights under this License from
-such Licensor (including the grants in Sections 2.1 and 2.2) will
-terminate immediately.
-
-3.5 Trademarks. This License does not grant any rights to use any
-Licensor's or its affiliates' names, logos, or trademarks, except
-as necessary to reproduce the notices described in this License.
-
-3.6 Termination. If you violate any term of this License, then your
-rights under this License (including the grants in Sections 2.1 and
-2.2) will terminate immediately.
-
-4. Disclaimer of Warranty.
-
-THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
-NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
-THIS LICENSE.
-
-5. Limitation of Liability.
-
-EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
-THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
-SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
-INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
-OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
-(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
-LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
-COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
-THE POSSIBILITY OF SUCH DAMAGES.
-
-=======================================================================
apply_factor.py
DELETED
@@ -1,94 +0,0 @@
-import argparse
-
-import torch
-from torchvision import utils
-
-from model import Generator
-
-
-if __name__ == "__main__":
-    torch.set_grad_enabled(False)
-
-    parser = argparse.ArgumentParser(description="Apply closed form factorization")
-
-    parser.add_argument(
-        "-i", "--index", type=int, default=0, help="index of eigenvector"
-    )
-    parser.add_argument(
-        "-d",
-        "--degree",
-        type=float,
-        default=5,
-        help="scalar factors for moving latent vectors along eigenvector",
-    )
-    parser.add_argument(
-        "--channel_multiplier",
-        type=int,
-        default=2,
-        help='channel multiplier factor. config-f = 2, else = 1',
-    )
-    parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
-    parser.add_argument(
-        "--size", type=int, default=256, help="output image size of the generator"
-    )
-    parser.add_argument(
-        "-n", "--n_sample", type=int, default=7, help="number of samples created"
-    )
-    parser.add_argument(
-        "--truncation", type=float, default=0.7, help="truncation factor"
-    )
-    parser.add_argument(
-        "--device", type=str, default="cuda", help="device to run the model"
-    )
-    parser.add_argument(
-        "--out_prefix",
-        type=str,
-        default="factor",
-        help="filename prefix to result samples",
-    )
-    parser.add_argument(
-        "factor",
-        type=str,
-        help="name of the closed form factorization result factor file",
-    )
-
-    args = parser.parse_args()
-
-    eigvec = torch.load(args.factor)["eigvec"].to(args.device)
-    ckpt = torch.load(args.ckpt)
-    g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)
-    g.load_state_dict(ckpt["g_ema"], strict=False)
-
-    trunc = g.mean_latent(4096)
-
-    latent = torch.randn(args.n_sample, 512, device=args.device)
-    latent = g.get_latent(latent)
-
-    direction = args.degree * eigvec[:, args.index].unsqueeze(0)
-
-    img, _ = g(
-        [latent],
-        truncation=args.truncation,
-        truncation_latent=trunc,
-        input_is_latent=True,
-    )
-    img1, _ = g(
-        [latent + direction],
-        truncation=args.truncation,
-        truncation_latent=trunc,
-        input_is_latent=True,
-    )
-    img2, _ = g(
-        [latent - direction],
-        truncation=args.truncation,
-        truncation_latent=trunc,
-        input_is_latent=True,
-    )
-
-    grid = utils.save_image(
-        torch.cat([img1, img, img2], 0),
-        f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
-        normalize=True,
-        range=(-1, 1),
-        nrow=args.n_sample,
-    )
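For context, this script was the second half of a two-step workflow: closed_form_factorization.py (further down in this commit) extracts the eigenvector file, and apply_factor.py renders samples moved along one eigenvector. A typical invocation, with illustrative checkpoint and factor filenames, would have been:

    python closed_form_factorization.py --out factor.pt stylegan2-ffhq-config-f.pt
    python apply_factor.py -i 0 -d 5 --ckpt stylegan2-ffhq-config-f.pt factor.pt

This writes factor_index-0_degree-5.0.png with three rows: latents shifted by +direction, the unshifted samples, and -direction.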
calc_inception.py
DELETED
@@ -1,130 +0,0 @@
-import argparse
-import pickle
-import os
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from torchvision import transforms
-from torchvision.models import inception_v3, Inception3
-import numpy as np
-from tqdm import tqdm
-
-from inception import InceptionV3
-from dataset import MultiResolutionDataset
-
-
-class Inception3Feature(Inception3):
-    def forward(self, x):
-        if x.shape[2] != 299 or x.shape[3] != 299:
-            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)
-
-        x = self.Conv2d_1a_3x3(x)  # 299 x 299 x 3
-        x = self.Conv2d_2a_3x3(x)  # 149 x 149 x 32
-        x = self.Conv2d_2b_3x3(x)  # 147 x 147 x 32
-        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 147 x 147 x 64
-
-        x = self.Conv2d_3b_1x1(x)  # 73 x 73 x 64
-        x = self.Conv2d_4a_3x3(x)  # 73 x 73 x 80
-        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 71 x 71 x 192
-
-        x = self.Mixed_5b(x)  # 35 x 35 x 192
-        x = self.Mixed_5c(x)  # 35 x 35 x 256
-        x = self.Mixed_5d(x)  # 35 x 35 x 288
-
-        x = self.Mixed_6a(x)  # 35 x 35 x 288
-        x = self.Mixed_6b(x)  # 17 x 17 x 768
-        x = self.Mixed_6c(x)  # 17 x 17 x 768
-        x = self.Mixed_6d(x)  # 17 x 17 x 768
-        x = self.Mixed_6e(x)  # 17 x 17 x 768
-
-        x = self.Mixed_7a(x)  # 17 x 17 x 768
-        x = self.Mixed_7b(x)  # 8 x 8 x 1280
-        x = self.Mixed_7c(x)  # 8 x 8 x 2048
-
-        x = F.avg_pool2d(x, kernel_size=8)  # 8 x 8 x 2048
-
-        return x.view(x.shape[0], x.shape[1])  # 1 x 1 x 2048
-
-
-def load_patched_inception_v3():
-    # inception = inception_v3(pretrained=True)
-    # inception_feat = Inception3Feature()
-    # inception_feat.load_state_dict(inception.state_dict())
-    inception_feat = InceptionV3([3], normalize_input=False)
-
-    return inception_feat
-
-
-@torch.no_grad()
-def extract_features(loader, inception, device):
-    pbar = tqdm(loader)
-
-    feature_list = []
-
-    for img in pbar:
-        img = img.to(device)
-        feature = inception(img)[0].view(img.shape[0], -1)
-        feature_list.append(feature.to("cpu"))
-
-    features = torch.cat(feature_list, 0)
-
-    return features
-
-
-if __name__ == "__main__":
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    parser = argparse.ArgumentParser(
-        description="Calculate Inception v3 features for datasets"
-    )
-    parser.add_argument(
-        "--size",
-        type=int,
-        default=256,
-        help="image sizes used for embedding calculation",
-    )
-    parser.add_argument(
-        "--batch", default=64, type=int, help="batch size for inception networks"
-    )
-    parser.add_argument(
-        "--n_sample",
-        type=int,
-        default=50000,
-        help="number of samples used for embedding calculation",
-    )
-    parser.add_argument(
-        "--flip", action="store_true", help="apply random flipping to real images"
-    )
-    parser.add_argument("path", metavar="PATH", help="path to datset lmdb file")
-
-    args = parser.parse_args()
-
-    inception = load_patched_inception_v3()
-    inception = nn.DataParallel(inception).eval().to(device)
-
-    transform = transforms.Compose(
-        [
-            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
-        ]
-    )
-
-    dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
-    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)
-
-    features = extract_features(loader, inception, device).numpy()
-
-    features = features[: args.n_sample]
-
-    print(f"extracted {features.shape[0]} features")
-
-    mean = np.mean(features, 0)
-    cov = np.cov(features, rowvar=False)
-
-    name = os.path.splitext(os.path.basename(args.path))[0]
-
-    with open(f"inception_{name}.pkl", "wb") as f:
-        pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)
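For reference, this script precomputes the real-image statistics later consumed by fid.py; the lmdb path in the sketch below is illustrative:

    python calc_inception.py --size 256 --n_sample 50000 path/to/dataset_lmdb

It saves the feature mean and covariance to inception_<name>.pkl, where <name> is the basename of the lmdb path.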
checkpoint/.gitignore
DELETED
@@ -1 +0,0 @@
-*.pt
closed_form_factorization.py
DELETED
@@ -1,33 +0,0 @@
-import argparse
-
-import torch
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Extract factor/eigenvectors of latent spaces using closed form factorization"
-    )
-
-    parser.add_argument(
-        "--out", type=str, default="factor.pt", help="name of the result factor file"
-    )
-    parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
-
-    args = parser.parse_args()
-
-    ckpt = torch.load(args.ckpt)
-    modulate = {
-        k: v
-        for k, v in ckpt["g_ema"].items()
-        if "modulation" in k and "to_rgbs" not in k and "weight" in k
-    }
-
-    weight_mat = []
-    for k, v in modulate.items():
-        weight_mat.append(v)
-
-    W = torch.cat(weight_mat, 0)
-    eigvec = torch.svd(W).V.to("cpu")
-
-    torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)
-
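A short note on the math above, which follows the closed-form factorization (SeFa) approach: the style-modulation weights of all non-toRGB layers are stacked into one matrix $W \in \mathbb{R}^{m \times 512}$, and torch.svd computes

    W = U \Sigma V^{\top}, \qquad \text{eigvec} = V

so the $k$-th column of $V$ is the latent direction that apply_factor.py scales by --degree $d$ and applies as $w \pm d\, V_{:,k}$.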
convert_weight.py
DELETED
@@ -1,301 +0,0 @@
-import argparse
-import os
-import sys
-import pickle
-import math
-
-import torch
-import numpy as np
-from torchvision import utils
-
-from model import Generator, Discriminator
-
-
-def convert_modconv(vars, source_name, target_name, flip=False):
-    weight = vars[source_name + "/weight"].value().eval()
-    mod_weight = vars[source_name + "/mod_weight"].value().eval()
-    mod_bias = vars[source_name + "/mod_bias"].value().eval()
-    noise = vars[source_name + "/noise_strength"].value().eval()
-    bias = vars[source_name + "/bias"].value().eval()
-
-    dic = {
-        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
-        "conv.modulation.weight": mod_weight.transpose((1, 0)),
-        "conv.modulation.bias": mod_bias + 1,
-        "noise.weight": np.array([noise]),
-        "activate.bias": bias,
-    }
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    if flip:
-        dic_torch[target_name + ".conv.weight"] = torch.flip(
-            dic_torch[target_name + ".conv.weight"], [3, 4]
-        )
-
-    return dic_torch
-
-
-def convert_conv(vars, source_name, target_name, bias=True, start=0):
-    weight = vars[source_name + "/weight"].value().eval()
-
-    dic = {"weight": weight.transpose((3, 2, 0, 1))}
-
-    if bias:
-        dic["bias"] = vars[source_name + "/bias"].value().eval()
-
-    dic_torch = {}
-
-    dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
-
-    if bias:
-        dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
-
-    return dic_torch
-
-
-def convert_torgb(vars, source_name, target_name):
-    weight = vars[source_name + "/weight"].value().eval()
-    mod_weight = vars[source_name + "/mod_weight"].value().eval()
-    mod_bias = vars[source_name + "/mod_bias"].value().eval()
-    bias = vars[source_name + "/bias"].value().eval()
-
-    dic = {
-        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
-        "conv.modulation.weight": mod_weight.transpose((1, 0)),
-        "conv.modulation.bias": mod_bias + 1,
-        "bias": bias.reshape((1, 3, 1, 1)),
-    }
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    return dic_torch
-
-
-def convert_dense(vars, source_name, target_name):
-    weight = vars[source_name + "/weight"].value().eval()
-    bias = vars[source_name + "/bias"].value().eval()
-
-    dic = {"weight": weight.transpose((1, 0)), "bias": bias}
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    return dic_torch
-
-
-def update(state_dict, new):
-    for k, v in new.items():
-        if k not in state_dict:
-            raise KeyError(k + " is not found")
-
-        if v.shape != state_dict[k].shape:
-            raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
-
-        state_dict[k] = v
-
-
-def discriminator_fill_statedict(statedict, vars, size):
-    log_size = int(math.log(size, 2))
-
-    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
-
-    conv_i = 1
-
-    for i in range(log_size - 2, 0, -1):
-        reso = 4 * 2 ** i
-        update(
-            statedict,
-            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
-        )
-        update(
-            statedict,
-            convert_conv(
-                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
-            ),
-        )
-        update(
-            statedict,
-            convert_conv(
-                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
-            ),
-        )
-        conv_i += 1
-
-    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
-    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
-    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
-
-    return statedict
-
-
-def fill_statedict(state_dict, vars, size, n_mlp):
-    log_size = int(math.log(size, 2))
-
-    for i in range(n_mlp):
-        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))
-
-    update(
-        state_dict,
-        {
-            "input.input": torch.from_numpy(
-                vars["G_synthesis/4x4/Const/const"].value().eval()
-            )
-        },
-    )
-
-    update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))
-
-    for i in range(log_size - 2):
-        reso = 4 * 2 ** (i + 1)
-        update(
-            state_dict,
-            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
-        )
-
-    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))
-
-    conv_i = 0
-
-    for i in range(log_size - 2):
-        reso = 4 * 2 ** (i + 1)
-        update(
-            state_dict,
-            convert_modconv(
-                vars,
-                f"G_synthesis/{reso}x{reso}/Conv0_up",
-                f"convs.{conv_i}",
-                flip=True,
-            ),
-        )
-        update(
-            state_dict,
-            convert_modconv(
-                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
-            ),
-        )
-        conv_i += 2
-
-    for i in range(0, (log_size - 2) * 2 + 1):
-        update(
-            state_dict,
-            {
-                f"noises.noise_{i}": torch.from_numpy(
-                    vars[f"G_synthesis/noise{i}"].value().eval()
-                )
-            },
-        )
-
-    return state_dict
-
-
-if __name__ == "__main__":
-    device = "cuda"
-
-    parser = argparse.ArgumentParser(
-        description="Tensorflow to pytorch model checkpoint converter"
-    )
-    parser.add_argument(
-        "--repo",
-        type=str,
-        required=True,
-        help="path to the offical StyleGAN2 repository with dnnlib/ folder",
-    )
-    parser.add_argument(
-        "--gen", action="store_true", help="convert the generator weights"
-    )
-    parser.add_argument(
-        "--disc", action="store_true", help="convert the discriminator weights"
-    )
-    parser.add_argument(
-        "--channel_multiplier",
-        type=int,
-        default=2,
-        help="channel multiplier factor. config-f = 2, else = 1",
-    )
-    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
-
-    args = parser.parse_args()
-
-    sys.path.append(args.repo)
-
-    import dnnlib
-    from dnnlib import tflib
-
-    tflib.init_tf()
-
-    with open(args.path, "rb") as f:
-        generator, discriminator, g_ema = pickle.load(f)
-
-    size = g_ema.output_shape[2]
-
-    n_mlp = 0
-    mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
-    for layer in mapping_layers_names:
-        if layer[0].startswith('Dense'):
-            n_mlp += 1
-
-    g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
-    state_dict = g.state_dict()
-    state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
-
-    g.load_state_dict(state_dict)
-
-    latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
-
-    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
-
-    if args.gen:
-        g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
-        g_train_state = g_train.state_dict()
-        g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
-        ckpt["g"] = g_train_state
-
-    if args.disc:
-        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
-        d_state = disc.state_dict()
-        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
-        ckpt["d"] = d_state
-
-    name = os.path.splitext(os.path.basename(args.path))[0]
-    torch.save(ckpt, name + ".pt")
-
-    batch_size = {256: 16, 512: 9, 1024: 4}
-    n_sample = batch_size.get(size, 25)
-
-    g = g.to(device)
-
-    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
-
-    with torch.no_grad():
-        img_pt, _ = g(
-            [torch.from_numpy(z).to(device)],
-            truncation=0.5,
-            truncation_latent=latent_avg.to(device),
-            randomize_noise=False,
-        )
-
-    Gs_kwargs = dnnlib.EasyDict()
-    Gs_kwargs.randomize_noise = False
-    img_tf = g_ema.run(z, None, **Gs_kwargs)
-    img_tf = torch.from_numpy(img_tf).to(device)
-
-    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
-        0.0, 1.0
-    )
-
-    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
-
-    print(img_diff.abs().max())
-
-    utils.save_image(
-        img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
-    )
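For reference, the converter required a checkout of the official TensorFlow StyleGAN2 repository on the path so the pickled dnnlib objects could be loaded; the paths below are illustrative:

    python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl

This writes stylegan2-ffhq-config-f.pt and a side-by-side PNG comparing the TensorFlow and PyTorch outputs on the same fixed-seed latents, printing their maximum absolute pixel difference as a sanity check.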
dataset.py
DELETED
@@ -1,40 +0,0 @@
-from io import BytesIO
-
-import lmdb
-from PIL import Image
-from torch.utils.data import Dataset
-
-
-class MultiResolutionDataset(Dataset):
-    def __init__(self, path, transform, resolution=256):
-        self.env = lmdb.open(
-            path,
-            max_readers=32,
-            readonly=True,
-            lock=False,
-            readahead=False,
-            meminit=False,
-        )
-
-        if not self.env:
-            raise IOError('Cannot open lmdb dataset', path)
-
-        with self.env.begin(write=False) as txn:
-            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
-
-        self.resolution = resolution
-        self.transform = transform
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, index):
-        with self.env.begin(write=False) as txn:
-            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
-            img_bytes = txn.get(key)
-
-        buffer = BytesIO(img_bytes)
-        img = Image.open(buffer)
-        img = self.transform(img)
-
-        return img
distributed.py
DELETED
@@ -1,126 +0,0 @@
-import math
-import pickle
-
-import torch
-from torch import distributed as dist
-from torch.utils.data.sampler import Sampler
-
-
-def get_rank():
-    if not dist.is_available():
-        return 0
-
-    if not dist.is_initialized():
-        return 0
-
-    return dist.get_rank()
-
-
-def synchronize():
-    if not dist.is_available():
-        return
-
-    if not dist.is_initialized():
-        return
-
-    world_size = dist.get_world_size()
-
-    if world_size == 1:
-        return
-
-    dist.barrier()
-
-
-def get_world_size():
-    if not dist.is_available():
-        return 1
-
-    if not dist.is_initialized():
-        return 1
-
-    return dist.get_world_size()
-
-
-def reduce_sum(tensor):
-    if not dist.is_available():
-        return tensor
-
-    if not dist.is_initialized():
-        return tensor
-
-    tensor = tensor.clone()
-    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
-
-    return tensor
-
-
-def gather_grad(params):
-    world_size = get_world_size()
-
-    if world_size == 1:
-        return
-
-    for param in params:
-        if param.grad is not None:
-            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
-            param.grad.data.div_(world_size)
-
-
-def all_gather(data):
-    world_size = get_world_size()
-
-    if world_size == 1:
-        return [data]
-
-    buffer = pickle.dumps(data)
-    storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).to('cuda')
-
-    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
-    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
-    dist.all_gather(size_list, local_size)
-    size_list = [int(size.item()) for size in size_list]
-    max_size = max(size_list)
-
-    tensor_list = []
-    for _ in size_list:
-        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
-
-    if local_size != max_size:
-        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
-        tensor = torch.cat((tensor, padding), 0)
-
-    dist.all_gather(tensor_list, tensor)
-
-    data_list = []
-
-    for size, tensor in zip(size_list, tensor_list):
-        buffer = tensor.cpu().numpy().tobytes()[:size]
-        data_list.append(pickle.loads(buffer))
-
-    return data_list
-
-
-def reduce_loss_dict(loss_dict):
-    world_size = get_world_size()
-
-    if world_size < 2:
-        return loss_dict
-
-    with torch.no_grad():
-        keys = []
-        losses = []
-
-        for k in sorted(loss_dict.keys()):
-            keys.append(k)
-            losses.append(loss_dict[k])
-
-        losses = torch.stack(losses, 0)
-        dist.reduce(losses, dst=0)
-
-        if dist.get_rank() == 0:
-            losses /= world_size
-
-        reduced_losses = {k: v for k, v in zip(keys, losses)}
-
-    return reduced_losses
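These helpers all no-op gracefully when torch.distributed is unavailable or uninitialized, so the same training code runs on one GPU or many. A minimal sketch of how train.py-style logging code would consume reduce_loss_dict (the loss names are illustrative):

    loss_dict = {"d": d_loss, "g": g_loss}
    loss_reduced = reduce_loss_dict(loss_dict)  # averaged across ranks, valid on rank 0
    if get_rank() == 0:
        print(loss_reduced["d"].mean().item(), loss_reduced["g"].mean().item())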
download.py
DELETED
@@ -1,47 +0,0 @@
-from transformers import BertConfig
-
-model_name = 'bert-base-chinese'  # BERT model name
-model_path = 'D:/Transformers-Bert/bert-base-chinese/'  # local directory holding the user-downloaded pretrained BERT files
-config_path = 'D:/Transformers-Bert/bert-base-chinese/config.json'  # local path to the downloaded pretrained config.json
-
-# The config can be loaded in three ways: model name, model directory, or config file path
-# config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage (location is preconfigured in the code)
-# config = BertConfig.from_pretrained(model_path)
-config = BertConfig.from_pretrained(config_path)
-
-from transformers import BertTokenizer, BertModel
-model_name = 'KenjieDec/Time-Travel-Rephotograph_e4e_ffhq_encode'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights from the official S3 storage
-
-model_name = 'KenjieDec/Time-Travel-Rephotography_stylegan2-ffhq-config-f'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights from the official S3 storage
-
-model_name = 'KenjieDec/Time-Travel-Rephotography_vgg_face_dag'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)
-
-model_name = 'D:\premodel\checkpoint_b'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights from the official S3 storage
-
-model_name = 'D:\premodel\checkpoint_g'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights from the official S3 storage
-
-model_name = 'D:\premodel\checkpoint_gb'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights from the official S3 storage
-
-model_name = 'clearspandex/face-parsing'
-config = BertConfig.from_pretrained(model_name)  # automatically downloads the model config, parameters, etc. from the official S3 storage
-tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically reads the vocab.txt file from the official S3 storage
-model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights from the official S3 storage
-
fid.py
DELETED
@@ -1,129 +0,0 @@
-import argparse
-import pickle
-
-import torch
-from torch import nn
-import numpy as np
-from scipy import linalg
-from tqdm import tqdm
-
-from model import Generator
-from calc_inception import load_patched_inception_v3
-
-
-@torch.no_grad()
-def extract_feature_from_samples(
-    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
-):
-    n_batch = n_sample // batch_size
-    resid = n_sample - (n_batch * batch_size)
-    batch_sizes = [batch_size] * n_batch + [resid]
-    features = []
-
-    for batch in tqdm(batch_sizes):
-        latent = torch.randn(batch, 512, device=device)
-        img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent)
-        feat = inception(img)[0].view(img.shape[0], -1)
-        features.append(feat.to("cpu"))
-
-    features = torch.cat(features, 0)
-
-    return features
-
-
-def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
-    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
-
-    if not np.isfinite(cov_sqrt).all():
-        print("product of cov matrices is singular")
-        offset = np.eye(sample_cov.shape[0]) * eps
-        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
-
-    if np.iscomplexobj(cov_sqrt):
-        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
-            m = np.max(np.abs(cov_sqrt.imag))
-
-            raise ValueError(f"Imaginary component {m}")
-
-        cov_sqrt = cov_sqrt.real
-
-    mean_diff = sample_mean - real_mean
-    mean_norm = mean_diff @ mean_diff
-
-    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
-
-    fid = mean_norm + trace
-
-    return fid
-
-
-if __name__ == "__main__":
-    device = "cuda"
-
-    parser = argparse.ArgumentParser(description="Calculate FID scores")
-
-    parser.add_argument("--truncation", type=float, default=1, help="truncation factor")
-    parser.add_argument(
-        "--truncation_mean",
-        type=int,
-        default=4096,
-        help="number of samples to calculate mean for truncation",
-    )
-    parser.add_argument(
-        "--batch", type=int, default=64, help="batch size for the generator"
-    )
-    parser.add_argument(
-        "--n_sample",
-        type=int,
-        default=50000,
-        help="number of the samples for calculating FID",
-    )
-    parser.add_argument(
-        "--size", type=int, default=256, help="image sizes for generator"
-    )
-    parser.add_argument(
-        "--inception",
-        type=str,
-        default=None,
-        required=True,
-        help="path to precomputed inception embedding",
-    )
-    parser.add_argument(
-        "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
-    )
-
-    args = parser.parse_args()
-
-    ckpt = torch.load(args.ckpt)
-
-    g = Generator(args.size, 512, 8).to(device)
-    g.load_state_dict(ckpt["g_ema"])
-    g = nn.DataParallel(g)
-    g.eval()
-
-    if args.truncation < 1:
-        with torch.no_grad():
-            mean_latent = g.mean_latent(args.truncation_mean)
-
-    else:
-        mean_latent = None
-
-    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
-    inception.eval()
-
-    features = extract_feature_from_samples(
-        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
-    ).numpy()
-    print(f"extracted {features.shape[0]} features")
-
-    sample_mean = np.mean(features, 0)
-    sample_cov = np.cov(features, rowvar=False)
-
-    with open(args.inception, "rb") as f:
-        embeds = pickle.load(f)
-        real_mean = embeds["mean"]
-        real_cov = embeds["cov"]
-
-    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
-
-    print("fid:", fid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
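For reference, `calc_fid` above is the closed-form Fréchet distance between two Gaussians fitted to the Inception features. In the variable names of the code, with (mu_s, Sigma_s) = (`sample_mean`, `sample_cov`) and (mu_r, Sigma_r) = (`real_mean`, `real_cov`):

$$\mathrm{FID} = \lVert \mu_s - \mu_r \rVert_2^2 + \operatorname{Tr}\left(\Sigma_s + \Sigma_r - 2\,(\Sigma_s \Sigma_r)^{1/2}\right)$$

The `eps` diagonal offset is only applied when `sqrtm` returns non-finite values, i.e. when the product of the covariance matrices is numerically singular.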
generate.py
DELETED
@@ -1,84 +0,0 @@
-import argparse
-
-import torch
-from torchvision import utils
-from model import Generator
-from tqdm import tqdm
-
-
-def generate(args, g_ema, device, mean_latent):
-
-    with torch.no_grad():
-        g_ema.eval()
-        for i in tqdm(range(args.pics)):
-            sample_z = torch.randn(args.sample, args.latent, device=device)
-
-            sample, _ = g_ema(
-                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
-            )
-
-            utils.save_image(
-                sample,
-                f"sample/{str(i).zfill(6)}.png",
-                nrow=1,
-                normalize=True,
-                range=(-1, 1),
-            )
-
-
-if __name__ == "__main__":
-    device = "cuda"
-
-    parser = argparse.ArgumentParser(description="Generate samples from the generator")
-
-    parser.add_argument(
-        "--size", type=int, default=1024, help="output image size of the generator"
-    )
-    parser.add_argument(
-        "--sample",
-        type=int,
-        default=1,
-        help="number of samples to be generated for each image",
-    )
-    parser.add_argument(
-        "--pics", type=int, default=20, help="number of images to be generated"
-    )
-    parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
-    parser.add_argument(
-        "--truncation_mean",
-        type=int,
-        default=4096,
-        help="number of vectors to calculate mean for the truncation",
-    )
-    parser.add_argument(
-        "--ckpt",
-        type=str,
-        default="stylegan2-ffhq-config-f.pt",
-        help="path to the model checkpoint",
-    )
-    parser.add_argument(
-        "--channel_multiplier",
-        type=int,
-        default=2,
-        help="channel multiplier of the generator. config-f = 2, else = 1",
-    )
-
-    args = parser.parse_args()
-
-    args.latent = 512
-    args.n_mlp = 8
-
-    g_ema = Generator(
-        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
-    ).to(device)
-    checkpoint = torch.load(args.ckpt)
-
-    g_ema.load_state_dict(checkpoint["g_ema"])
-
-    if args.truncation < 1:
-        with torch.no_grad():
-            mean_latent = g_ema.mean_latent(args.truncation_mean)
-    else:
-        mean_latent = None
-
-    generate(args, g_ema, device, mean_latent)
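A typical invocation of the script above (assuming the 1024px FFHQ checkpoint named by the default `--ckpt` value is present, and that the `sample/` output directory exists) writes `sample/000000.png` through `sample/000009.png`:

    python generate.py --size 1024 --pics 10 --truncation 0.7 --ckpt stylegan2-ffhq-config-f.pt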
inception.py
DELETED
@@ -1,310 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchvision import models
-
-try:
-    from torchvision.models.utils import load_state_dict_from_url
-except ImportError:
-    from torch.utils.model_zoo import load_url as load_state_dict_from_url
-
-# Inception weights ported to Pytorch from
-# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
-FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
-
-
-class InceptionV3(nn.Module):
-    """Pretrained InceptionV3 network returning feature maps"""
-
-    # Index of default block of inception to return,
-    # corresponds to output of final average pooling
-    DEFAULT_BLOCK_INDEX = 3
-
-    # Maps feature dimensionality to their output blocks indices
-    BLOCK_INDEX_BY_DIM = {
-        64: 0,    # First max pooling features
-        192: 1,   # Second max pooling features
-        768: 2,   # Pre-aux classifier features
-        2048: 3   # Final average pooling features
-    }
-
-    def __init__(self,
-                 output_blocks=[DEFAULT_BLOCK_INDEX],
-                 resize_input=True,
-                 normalize_input=True,
-                 requires_grad=False,
-                 use_fid_inception=True):
-        """Build pretrained InceptionV3
-
-        Parameters
-        ----------
-        output_blocks : list of int
-            Indices of blocks to return features of. Possible values are:
-                - 0: corresponds to output of first max pooling
-                - 1: corresponds to output of second max pooling
-                - 2: corresponds to output which is fed to aux classifier
-                - 3: corresponds to output of final average pooling
-        resize_input : bool
-            If true, bilinearly resizes input to width and height 299 before
-            feeding input to model. As the network without fully connected
-            layers is fully convolutional, it should be able to handle inputs
-            of arbitrary size, so resizing might not be strictly needed
-        normalize_input : bool
-            If true, scales the input from range (0, 1) to the range the
-            pretrained Inception network expects, namely (-1, 1)
-        requires_grad : bool
-            If true, parameters of the model require gradients. Possibly useful
-            for finetuning the network
-        use_fid_inception : bool
-            If true, uses the pretrained Inception model used in Tensorflow's
-            FID implementation. If false, uses the pretrained Inception model
-            available in torchvision. The FID Inception model has different
-            weights and a slightly different structure from torchvision's
-            Inception model. If you want to compute FID scores, you are
-            strongly advised to set this parameter to true to get comparable
-            results.
-        """
-        super(InceptionV3, self).__init__()
-
-        self.resize_input = resize_input
-        self.normalize_input = normalize_input
-        self.output_blocks = sorted(output_blocks)
-        self.last_needed_block = max(output_blocks)
-
-        assert self.last_needed_block <= 3, \
-            'Last possible output block index is 3'
-
-        self.blocks = nn.ModuleList()
-
-        if use_fid_inception:
-            inception = fid_inception_v3()
-        else:
-            inception = models.inception_v3(pretrained=True)
-
-        # Block 0: input to maxpool1
-        block0 = [
-            inception.Conv2d_1a_3x3,
-            inception.Conv2d_2a_3x3,
-            inception.Conv2d_2b_3x3,
-            nn.MaxPool2d(kernel_size=3, stride=2)
-        ]
-        self.blocks.append(nn.Sequential(*block0))
-
-        # Block 1: maxpool1 to maxpool2
-        if self.last_needed_block >= 1:
-            block1 = [
-                inception.Conv2d_3b_1x1,
-                inception.Conv2d_4a_3x3,
-                nn.MaxPool2d(kernel_size=3, stride=2)
-            ]
-            self.blocks.append(nn.Sequential(*block1))
-
-        # Block 2: maxpool2 to aux classifier
-        if self.last_needed_block >= 2:
-            block2 = [
-                inception.Mixed_5b,
-                inception.Mixed_5c,
-                inception.Mixed_5d,
-                inception.Mixed_6a,
-                inception.Mixed_6b,
-                inception.Mixed_6c,
-                inception.Mixed_6d,
-                inception.Mixed_6e,
-            ]
-            self.blocks.append(nn.Sequential(*block2))
-
-        # Block 3: aux classifier to final avgpool
-        if self.last_needed_block >= 3:
-            block3 = [
-                inception.Mixed_7a,
-                inception.Mixed_7b,
-                inception.Mixed_7c,
-                nn.AdaptiveAvgPool2d(output_size=(1, 1))
-            ]
-            self.blocks.append(nn.Sequential(*block3))
-
-        for param in self.parameters():
-            param.requires_grad = requires_grad
-
-    def forward(self, inp):
-        """Get Inception feature maps
-
-        Parameters
-        ----------
-        inp : torch.autograd.Variable
-            Input tensor of shape Bx3xHxW. Values are expected to be in
-            range (0, 1)
-
-        Returns
-        -------
-        List of torch.autograd.Variable, corresponding to the selected output
-        block, sorted ascending by index
-        """
-        outp = []
-        x = inp
-
-        if self.resize_input:
-            x = F.interpolate(x,
-                              size=(299, 299),
-                              mode='bilinear',
-                              align_corners=False)
-
-        if self.normalize_input:
-            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
-
-        for idx, block in enumerate(self.blocks):
-            x = block(x)
-            if idx in self.output_blocks:
-                outp.append(x)
-
-            if idx == self.last_needed_block:
-                break
-
-        return outp
-
-
-def fid_inception_v3():
-    """Build pretrained Inception model for FID computation
-
-    The Inception model for FID computation uses a different set of weights
-    and has a slightly different structure than torchvision's Inception.
-
-    This method first constructs torchvision's Inception and then patches the
-    necessary parts that are different in the FID Inception model.
-    """
-    inception = models.inception_v3(num_classes=1008,
-                                    aux_logits=False,
-                                    pretrained=False)
-    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
-    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
-    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
-    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
-    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
-    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
-    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
-    inception.Mixed_7b = FIDInceptionE_1(1280)
-    inception.Mixed_7c = FIDInceptionE_2(2048)
-
-    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
-    inception.load_state_dict(state_dict)
-    return inception
-
-
-class FIDInceptionA(models.inception.InceptionA):
-    """InceptionA block patched for FID computation"""
-    def __init__(self, in_channels, pool_features):
-        super(FIDInceptionA, self).__init__(in_channels, pool_features)
-
-    def forward(self, x):
-        branch1x1 = self.branch1x1(x)
-
-        branch5x5 = self.branch5x5_1(x)
-        branch5x5 = self.branch5x5_2(branch5x5)
-
-        branch3x3dbl = self.branch3x3dbl_1(x)
-        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
-        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
-
-        # Patch: Tensorflow's average pool does not use the padded zeros in
-        # its average calculation
-        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
-                                   count_include_pad=False)
-        branch_pool = self.branch_pool(branch_pool)
-
-        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
-        return torch.cat(outputs, 1)
-
-
-class FIDInceptionC(models.inception.InceptionC):
-    """InceptionC block patched for FID computation"""
-    def __init__(self, in_channels, channels_7x7):
-        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
-
-    def forward(self, x):
-        branch1x1 = self.branch1x1(x)
-
-        branch7x7 = self.branch7x7_1(x)
-        branch7x7 = self.branch7x7_2(branch7x7)
-        branch7x7 = self.branch7x7_3(branch7x7)
-
-        branch7x7dbl = self.branch7x7dbl_1(x)
-        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
-        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
-        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
-        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
-
-        # Patch: Tensorflow's average pool does not use the padded zeros in
-        # its average calculation
-        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
-                                   count_include_pad=False)
-        branch_pool = self.branch_pool(branch_pool)
-
-        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
-        return torch.cat(outputs, 1)
-
-
-class FIDInceptionE_1(models.inception.InceptionE):
-    """First InceptionE block patched for FID computation"""
-    def __init__(self, in_channels):
-        super(FIDInceptionE_1, self).__init__(in_channels)
-
-    def forward(self, x):
-        branch1x1 = self.branch1x1(x)
-
-        branch3x3 = self.branch3x3_1(x)
-        branch3x3 = [
-            self.branch3x3_2a(branch3x3),
-            self.branch3x3_2b(branch3x3),
-        ]
-        branch3x3 = torch.cat(branch3x3, 1)
-
-        branch3x3dbl = self.branch3x3dbl_1(x)
-        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
-        branch3x3dbl = [
-            self.branch3x3dbl_3a(branch3x3dbl),
-            self.branch3x3dbl_3b(branch3x3dbl),
-        ]
-        branch3x3dbl = torch.cat(branch3x3dbl, 1)
-
-        # Patch: Tensorflow's average pool does not use the padded zeros in
-        # its average calculation
-        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
-                                   count_include_pad=False)
-        branch_pool = self.branch_pool(branch_pool)
-
-        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
-        return torch.cat(outputs, 1)
-
-
-class FIDInceptionE_2(models.inception.InceptionE):
-    """Second InceptionE block patched for FID computation"""
-    def __init__(self, in_channels):
-        super(FIDInceptionE_2, self).__init__(in_channels)
-
-    def forward(self, x):
-        branch1x1 = self.branch1x1(x)
-
-        branch3x3 = self.branch3x3_1(x)
-        branch3x3 = [
-            self.branch3x3_2a(branch3x3),
-            self.branch3x3_2b(branch3x3),
-        ]
-        branch3x3 = torch.cat(branch3x3, 1)
-
-        branch3x3dbl = self.branch3x3dbl_1(x)
-        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
-        branch3x3dbl = [
-            self.branch3x3dbl_3a(branch3x3dbl),
-            self.branch3x3dbl_3b(branch3x3dbl),
-        ]
-        branch3x3dbl = torch.cat(branch3x3dbl, 1)
-
-        # Patch: The FID Inception model uses max pooling instead of average
-        # pooling. This is likely an error in this specific Inception
-        # implementation, as other Inception models use average pooling here
-        # (which matches the description in the paper).
-        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
-        branch_pool = self.branch_pool(branch_pool)
-
-        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
-        return torch.cat(outputs, 1)
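A minimal sketch (not part of the deleted files) of how this module is consumed for FID: select the 2048-dimensional final-average-pool block, feed images in [0, 1], and flatten. The random batch is a placeholder, and instantiating the model downloads the ported FID weights on first use.

    import torch
    from inception import InceptionV3

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]  # final average pooling
    net = InceptionV3(output_blocks=[block_idx]).eval()

    imgs = torch.rand(8, 3, 256, 256)      # placeholder images in (0, 1)
    with torch.no_grad():
        feat = net(imgs)[0]                # shape (8, 2048, 1, 1)
    feat = feat.view(imgs.shape[0], -1)    # (8, 2048), ready for mean/cov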
lpips/__init__.py
DELETED
@@ -1,160 +0,0 @@
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-from skimage.measure import compare_ssim
-import torch
-from torch.autograd import Variable
-
-from lpips import dist_model
-
-class PerceptualLoss(torch.nn.Module):
-    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
-    # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
-        super(PerceptualLoss, self).__init__()
-        print('Setting up Perceptual loss...')
-        self.use_gpu = use_gpu
-        self.spatial = spatial
-        self.gpu_ids = gpu_ids
-        self.model = dist_model.DistModel()
-        self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
-        print('...[%s] initialized' % self.model.name())
-        print('...Done')
-
-    def forward(self, pred, target, normalize=False):
-        """
-        Pred and target are Variables.
-        If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
-        If normalize is False, assumes the images are already between [-1,+1]
-
-        Inputs pred and target are Nx3xHxW
-        Output pytorch Variable N long
-        """
-
-        if normalize:
-            target = 2 * target - 1
-            pred = 2 * pred - 1
-
-        return self.model.forward(target, pred)
-
-def normalize_tensor(in_feat, eps=1e-10):
-    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
-    return in_feat / (norm_factor + eps)
-
-def l2(p0, p1, range=255.):
-    return .5 * np.mean((p0 / range - p1 / range)**2)
-
-def psnr(p0, p1, peak=255.):
-    return 10 * np.log10(peak**2 / np.mean((1. * p0 - 1. * p1)**2))
-
-def dssim(p0, p1, range=255.):
-    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
-
-def rgb2lab(in_img, mean_cent=False):
-    from skimage import color
-    img_lab = color.rgb2lab(in_img)
-    if(mean_cent):
-        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
-    return img_lab
-
-def tensor2np(tensor_obj):
-    # change dimension of a tensor object into a numpy array
-    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
-
-def np2tensor(np_obj):
-    # change dimension of np array into tensor array
-    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
-
-def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
-    # image tensor to lab tensor
-    from skimage import color
-
-    img = tensor2im(image_tensor)
-    img_lab = color.rgb2lab(img)
-    if(mc_only):
-        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
-    if(to_norm and not mc_only):
-        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
-        img_lab = img_lab / 100.
-
-    return np2tensor(img_lab)
-
-def tensorlab2tensor(lab_tensor, return_inbnd=False):
-    from skimage import color
-    import warnings
-    warnings.filterwarnings("ignore")
-
-    lab = tensor2np(lab_tensor) * 100.
-    lab[:, :, 0] = lab[:, :, 0] + 50
-
-    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
-    if(return_inbnd):
-        # convert back to lab, see if we match
-        lab_back = color.rgb2lab(rgb_back.astype('uint8'))
-        mask = 1. * np.isclose(lab_back, lab, atol=2.)
-        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
-        return (im2tensor(rgb_back), mask)
-    else:
-        return im2tensor(rgb_back)
-
-def rgb2lab(input):  # NOTE: shadows the rgb2lab defined above
-    from skimage import color
-    return color.rgb2lab(input / 255.)
-
-def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
-    image_numpy = image_tensor[0].cpu().float().numpy()
-    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
-    return image_numpy.astype(imtype)
-
-def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
-    return torch.Tensor((image / factor - cent)
-                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
-
-def tensor2vec(vector_tensor):
-    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
-
-def voc_ap(rec, prec, use_07_metric=False):
-    """ ap = voc_ap(rec, prec, [use_07_metric])
-    Compute VOC AP given precision and recall.
-    If use_07_metric is true, uses the
-    VOC 07 11 point method (default:False).
-    """
-    if use_07_metric:
-        # 11 point metric
-        ap = 0.
-        for t in np.arange(0., 1.1, 0.1):
-            if np.sum(rec >= t) == 0:
-                p = 0
-            else:
-                p = np.max(prec[rec >= t])
-            ap = ap + p / 11.
-    else:
-        # correct AP calculation
-        # first append sentinel values at the end
-        mrec = np.concatenate(([0.], rec, [1.]))
-        mpre = np.concatenate(([0.], prec, [0.]))
-
-        # compute the precision envelope
-        for i in range(mpre.size - 1, 0, -1):
-            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
-
-        # to calculate area under PR curve, look for points
-        # where X axis (recall) changes value
-        i = np.where(mrec[1:] != mrec[:-1])[0]
-
-        # and sum (\Delta recall) * prec
-        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
-    return ap
-
-def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):  # NOTE: duplicate of tensor2im above
-    # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
-    image_numpy = image_tensor[0].cpu().float().numpy()
-    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
-    return image_numpy.astype(imtype)
-
-def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):  # NOTE: duplicate of im2tensor above
-    # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
-    return torch.Tensor((image / factor - cent)
-                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
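A minimal usage sketch for the `PerceptualLoss` wrapper above (assumes a CUDA device and that the LPIPS weights under `lpips/weights/` are available locally; the random tensors are placeholders):

    import torch
    from lpips import PerceptualLoss

    loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=True)

    img0 = torch.rand(1, 3, 64, 64).cuda()   # in [0, 1]
    img1 = torch.rand(1, 3, 64, 64).cuda()
    d = loss_fn(img0, img1, normalize=True)  # normalize=True rescales to [-1, 1]
    print(d.item())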
lpips/base_model.py
DELETED
@@ -1,58 +0,0 @@
-import os
-import numpy as np
-import torch
-from torch.autograd import Variable
-from pdb import set_trace as st
-from IPython import embed
-
-class BaseModel():
-    def __init__(self):
-        pass
-
-    def name(self):
-        return 'BaseModel'
-
-    def initialize(self, use_gpu=True, gpu_ids=[0]):
-        self.use_gpu = use_gpu
-        self.gpu_ids = gpu_ids
-
-    def forward(self):
-        pass
-
-    def get_image_paths(self):
-        pass
-
-    def optimize_parameters(self):
-        pass
-
-    def get_current_visuals(self):
-        return self.input
-
-    def get_current_errors(self):
-        return {}
-
-    def save(self, label):
-        pass
-
-    # helper saving function that can be used by subclasses
-    def save_network(self, network, path, network_label, epoch_label):
-        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
-        save_path = os.path.join(path, save_filename)
-        torch.save(network.state_dict(), save_path)
-
-    # helper loading function that can be used by subclasses
-    def load_network(self, network, network_label, epoch_label):
-        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
-        save_path = os.path.join(self.save_dir, save_filename)
-        print('Loading network from %s' % save_path)
-        network.load_state_dict(torch.load(save_path))
-
-    def update_learning_rate(self):
-        pass
-
-    def get_image_paths(self):  # NOTE: shadows the stub defined above
-        return self.image_paths
-
-    def save_done(self, flag=False):
-        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
-        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
lpips/dist_model.py
DELETED
@@ -1,284 +0,0 @@
-
-from __future__ import absolute_import
-
-import sys
-import numpy as np
-import torch
-from torch import nn
-import os
-from collections import OrderedDict
-from torch.autograd import Variable
-import itertools
-from .base_model import BaseModel
-from scipy.ndimage import zoom
-import fractions
-import functools
-import skimage.transform
-from tqdm import tqdm
-
-from IPython import embed
-
-from . import networks_basic as networks
-import lpips as util
-
-class DistModel(BaseModel):
-    def name(self):
-        return self.model_name
-
-    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
-                   use_gpu=True, printNet=False, spatial=False,
-                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
-        '''
-        INPUTS
-            model - ['net-lin'] for linearly calibrated network
-                    ['net'] for off-the-shelf network
-                    ['L2'] for L2 distance in Lab colorspace
-                    ['SSIM'] for ssim in RGB colorspace
-            net - ['squeeze','alex','vgg']
-            model_path - if None, will look in weights/[NET_NAME].pth
-            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
-            use_gpu - bool - whether or not to use a GPU
-            printNet - bool - whether or not to print network architecture out
-            spatial - bool - whether to output an array containing varying distances across spatial dimensions
-            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
-            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
-            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
-            is_train - bool - [True] for training mode
-            lr - float - initial learning rate
-            beta1 - float - initial momentum term for adam
-            version - 0.1 for latest, 0.0 was original (with a bug)
-            gpu_ids - int array - [0] by default, gpus to use
-        '''
-        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
-
-        self.model = model
-        self.net = net
-        self.is_train = is_train
-        self.spatial = spatial
-        self.gpu_ids = gpu_ids
-        self.model_name = '%s [%s]' % (model, net)
-
-        if(self.model == 'net-lin'):  # pretrained net + linear layer
-            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
-                                        use_dropout=True, spatial=spatial, version=version, lpips=True)
-            kw = {}
-            if not use_gpu:
-                kw['map_location'] = 'cpu'
-            if(model_path is None):
-                import inspect
-                model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth' % (version, net)))
-
-            if(not is_train):
-                print('Loading model from: %s' % model_path)
-                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
-
-        elif(self.model == 'net'):  # pretrained network
-            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
-        elif(self.model in ['L2', 'l2']):
-            self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace)  # not really a network, only for testing
-            self.model_name = 'L2'
-        elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
-            self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
-            self.model_name = 'SSIM'
-        else:
-            raise ValueError("Model [%s] not recognized." % self.model)
-
-        self.parameters = list(self.net.parameters())
-
-        if self.is_train:  # training mode
-            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
-            self.rankLoss = networks.BCERankingLoss()
-            self.parameters += list(self.rankLoss.net.parameters())
-            self.lr = lr
-            self.old_lr = lr
-            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
-        else:  # test mode
-            self.net.eval()
-
-        if(use_gpu):
-            self.net.to(gpu_ids[0])
-            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
-            if(self.is_train):
-                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])  # just put this on GPU0
-
-        if(printNet):
-            print('---------- Networks initialized -------------')
-            networks.print_network(self.net)
-            print('-----------------------------------------------')
-
-    def forward(self, in0, in1, retPerLayer=False):
-        ''' Function computes the distance between image patches in0 and in1
-        INPUTS
-            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
-        OUTPUT
-            computed distances between in0 and in1
-        '''
-
-        return self.net.forward(in0, in1, retPerLayer=retPerLayer)
-
-    # ***** TRAINING FUNCTIONS *****
-    def optimize_parameters(self):
-        self.forward_train()
-        self.optimizer_net.zero_grad()
-        self.backward_train()
-        self.optimizer_net.step()
-        self.clamp_weights()
-
-    def clamp_weights(self):
-        for module in self.net.modules():
-            if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
-                module.weight.data = torch.clamp(module.weight.data, min=0)
-
-    def set_input(self, data):
-        self.input_ref = data['ref']
-        self.input_p0 = data['p0']
-        self.input_p1 = data['p1']
-        self.input_judge = data['judge']
-
-        if(self.use_gpu):
-            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
-            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
-            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
-            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
-
-        self.var_ref = Variable(self.input_ref, requires_grad=True)
-        self.var_p0 = Variable(self.input_p0, requires_grad=True)
-        self.var_p1 = Variable(self.input_p1, requires_grad=True)
-
-    def forward_train(self):  # run forward pass
-        # print(self.net.module.scaling_layer.shift)
-        # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
-
-        self.d0 = self.forward(self.var_ref, self.var_p0)
-        self.d1 = self.forward(self.var_ref, self.var_p1)
-        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
-
-        self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
-
-        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)
-
-        return self.loss_total
-
-    def backward_train(self):
-        torch.mean(self.loss_total).backward()
-
-    def compute_accuracy(self, d0, d1, judge):
-        ''' d0, d1 are Variables, judge is a Tensor '''
-        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
-        judge_per = judge.cpu().numpy().flatten()
-        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
-
-    def get_current_errors(self):
-        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
-                               ('acc_r', self.acc_r)])
-
-        for key in retDict.keys():
-            retDict[key] = np.mean(retDict[key])
-
-        return retDict
-
-    def get_current_visuals(self):
-        zoom_factor = 256 / self.var_ref.data.size()[2]
-
-        ref_img = util.tensor2im(self.var_ref.data)
-        p0_img = util.tensor2im(self.var_p0.data)
-        p1_img = util.tensor2im(self.var_p1.data)
-
-        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
-        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
-        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
-
-        return OrderedDict([('ref', ref_img_vis),
-                            ('p0', p0_img_vis),
-                            ('p1', p1_img_vis)])
-
-    def save(self, path, label):
-        if(self.use_gpu):
-            self.save_network(self.net.module, path, '', label)
-        else:
-            self.save_network(self.net, path, '', label)
-        self.save_network(self.rankLoss.net, path, 'rank', label)
-
-    def update_learning_rate(self, nepoch_decay):
-        lrd = self.lr / nepoch_decay
-        lr = self.old_lr - lrd
-
-        for param_group in self.optimizer_net.param_groups:
-            param_group['lr'] = lr
-
-        print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
-        self.old_lr = lr
-
-def score_2afc_dataset(data_loader, func, name=''):
-    ''' Function computes Two Alternative Forced Choice (2AFC) score using
-    distance function 'func' in dataset 'data_loader'
-    INPUTS
-        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
-        func - callable distance function - calling d=func(in0,in1) should take 2
-            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
-    OUTPUTS
-        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
-        [1] - dictionary with following elements
-            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
-            gts - N array in [0,1], preferred patch selected by human evaluators
-                (closer to "0" for left patch p0, "1" for right patch p1,
-                "0.6" means 60pct people preferred right patch, 40pct preferred left)
-            scores - N array in [0,1], corresponding to what percentage function agreed with humans
-    CONSTS
-        N - number of test triplets in data_loader
-    '''
-
-    d0s = []
-    d1s = []
-    gts = []
-
-    for data in tqdm(data_loader.load_data(), desc=name):
-        d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
-        d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
-        gts += data['judge'].cpu().numpy().flatten().tolist()
-
-    d0s = np.array(d0s)
-    d1s = np.array(d1s)
-    gts = np.array(gts)
-    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
-
-    return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
-
-def score_jnd_dataset(data_loader, func, name=''):
-    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
-    INPUTS
-        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
-        func - callable distance function - calling d=func(in0,in1) should take 2
-            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
-    OUTPUTS
-        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
-        [1] - dictionary with following elements
-            ds - N array containing distances between two patches shown to human evaluator
-            sames - N array containing fraction of people who thought the two patches were identical
-    CONSTS
-        N - number of test triplets in data_loader
-    '''
-
-    ds = []
-    gts = []
-
-    for data in tqdm(data_loader.load_data(), desc=name):
-        ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
-        gts += data['same'].cpu().numpy().flatten().tolist()
-
-    sames = np.array(gts)
-    ds = np.array(ds)
-
-    sorted_inds = np.argsort(ds)
-    ds_sorted = ds[sorted_inds]
-    sames_sorted = sames[sorted_inds]
-
-    TPs = np.cumsum(sames_sorted)
-    FPs = np.cumsum(1 - sames_sorted)
-    FNs = np.sum(sames_sorted) - TPs
-
-    precs = TPs / (TPs + FPs)
-    recs = TPs / (TPs + FNs)
-    score = util.voc_ap(recs, precs)
-
-    return(score, dict(ds=ds, sames=sames))
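To make the 2AFC scoring rule in `score_2afc_dataset` concrete, here is a tiny worked example with made-up distances and human judgments (hypothetical numbers, not from the repo):

    import numpy as np

    d0s = np.array([0.2, 0.5])  # distance(ref, p0) for two triplets
    d1s = np.array([0.4, 0.3])  # distance(ref, p1)
    gts = np.array([0.1, 0.9])  # fraction of humans preferring p1

    # the metric "prefers" the patch at smaller distance; ties earn half credit
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
    print(scores.mean())  # 0.9: the metric agrees with 90% of the judgments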
lpips/networks_basic.py
DELETED
@@ -1,187 +0,0 @@
-
-from __future__ import absolute_import
-
-import sys
-import torch
-import torch.nn as nn
-import torch.nn.init as init
-from torch.autograd import Variable
-import numpy as np
-from pdb import set_trace as st
-from skimage import color
-from IPython import embed
-from . import pretrained_networks as pn
-
-import lpips as util
-
-def spatial_average(in_tens, keepdim=True):
-    return in_tens.mean([2, 3], keepdim=keepdim)
-
-def upsample(in_tens, out_H=64):  # assumes scale factor is same for H and W
-    in_H = in_tens.shape[2]
-    scale_factor = 1. * out_H / in_H
-
-    return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
-
-# Learned perceptual metric
-class PNetLin(nn.Module):
-    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
-        super(PNetLin, self).__init__()
-
-        self.pnet_type = pnet_type
-        self.pnet_tune = pnet_tune
-        self.pnet_rand = pnet_rand
-        self.spatial = spatial
-        self.lpips = lpips
-        self.version = version
-        self.scaling_layer = ScalingLayer()
-
-        if(self.pnet_type in ['vgg', 'vgg16']):
-            net_type = pn.vgg16
-            self.chns = [64, 128, 256, 512, 512]
-        elif(self.pnet_type == 'alex'):
-            net_type = pn.alexnet
-            self.chns = [64, 192, 384, 256, 256]
-        elif(self.pnet_type == 'squeeze'):
-            net_type = pn.squeezenet
-            self.chns = [64, 128, 256, 384, 384, 512, 512]
-        self.L = len(self.chns)
-
-        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
-
-        if(lpips):
-            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
-            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
-            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
-            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
-            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
-            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
-            if(self.pnet_type == 'squeeze'):  # 7 layers for squeezenet
-                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
-                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
-                self.lins += [self.lin5, self.lin6]
-
-    def forward(self, in0, in1, retPerLayer=False):
-        # v0.0 - original release had a bug, where input was not scaled
-        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
-        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
-        feats0, feats1, diffs = {}, {}, {}
-
-        for kk in range(self.L):
-            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
-            diffs[kk] = (feats0[kk] - feats1[kk])**2
-
-        if(self.lpips):
-            if(self.spatial):
-                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
-            else:
-                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
-        else:
-            if(self.spatial):
-                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
-            else:
-                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
-
-        val = res[0]
-        for l in range(1, self.L):
-            val += res[l]
-
-        if(retPerLayer):
-            return (val, res)
-        else:
-            return val
-
-class ScalingLayer(nn.Module):
-    def __init__(self):
-        super(ScalingLayer, self).__init__()
-        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
-        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
-
-    def forward(self, inp):
-        return (inp - self.shift) / self.scale
-
-
-class NetLinLayer(nn.Module):
-    ''' A single linear layer which does a 1x1 conv '''
-    def __init__(self, chn_in, chn_out=1, use_dropout=False):
-        super(NetLinLayer, self).__init__()
-
-        layers = [nn.Dropout(), ] if(use_dropout) else []
-        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
-        self.model = nn.Sequential(*layers)
-
-
-class Dist2LogitLayer(nn.Module):
-    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
-    def __init__(self, chn_mid=32, use_sigmoid=True):
-        super(Dist2LogitLayer, self).__init__()
-
-        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
-        layers += [nn.LeakyReLU(0.2, True), ]
-        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
-        layers += [nn.LeakyReLU(0.2, True), ]
-        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
-        if(use_sigmoid):
-            layers += [nn.Sigmoid(), ]
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, d0, d1, eps=0.1):
-        return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
-
-class BCERankingLoss(nn.Module):
-    def __init__(self, chn_mid=32):
-        super(BCERankingLoss, self).__init__()
-        self.net = Dist2LogitLayer(chn_mid=chn_mid)
-        # self.parameters = list(self.net.parameters())
-        self.loss = torch.nn.BCELoss()
-
-    def forward(self, d0, d1, judge):
-        per = (judge + 1.) / 2.
-        self.logit = self.net.forward(d0, d1)
-        return self.loss(self.logit, per)
-
-# L2, DSSIM metrics
-class FakeNet(nn.Module):
-    def __init__(self, use_gpu=True, colorspace='Lab'):
-        super(FakeNet, self).__init__()
-        self.use_gpu = use_gpu
-        self.colorspace = colorspace
-
-class L2(FakeNet):
-
-    def forward(self, in0, in1, retPerLayer=None):
-        assert(in0.size()[0] == 1)  # currently only supports batchSize 1
-
-        if(self.colorspace == 'RGB'):
-            (N, C, X, Y) = in0.size()
-            value = torch.mean(torch.mean(torch.mean((in0 - in1)**2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3).view(N)
-            return value
-        elif(self.colorspace == 'Lab'):
-            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
-                            util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
-            ret_var = Variable(torch.Tensor((value,)))
-            if(self.use_gpu):
-                ret_var = ret_var.cuda()
-            return ret_var
-
-class DSSIM(FakeNet):
-
-    def forward(self, in0, in1, retPerLayer=None):
-        assert(in0.size()[0] == 1)  # currently only supports batchSize 1
-
-        if(self.colorspace == 'RGB'):
-            value = util.dssim(1. * util.tensor2im(in0.data), 1. * util.tensor2im(in1.data), range=255.).astype('float')
-        elif(self.colorspace == 'Lab'):
-            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
-                               util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
-        ret_var = Variable(torch.Tensor((value,)))
-        if(self.use_gpu):
-            ret_var = ret_var.cuda()
-        return ret_var
-
-def print_network(net):
-    num_params = 0
-    for param in net.parameters():
-        num_params += param.numel()
-    print('Network', net)
-    print('Total number of parameters: %d' % num_params)
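For orientation, `PNetLin.forward` above implements the LPIPS distance from the LPIPS paper: per-channel unit-normalize the features of each layer $l$, take squared differences, weight channels with the learned 1x1 convolutions (`lin0` ... `lin4`), and average over space:

$$d(x, x_0) = \sum_l \frac{1}{H_l W_l} \sum_{h,w} \big\lVert w_l \odot (\hat{y}^{\,l}_{hw} - \hat{y}^{\,l}_{0,hw}) \big\rVert_2^2$$

With `lpips=False` the learned weights $w_l$ are dropped and the squared differences are simply summed over channels.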
lpips/pretrained_networks.py
DELETED
@@ -1,181 +0,0 @@
-from collections import namedtuple
-import torch
-from torchvision import models as tv
-from IPython import embed
-
-class squeezenet(torch.nn.Module):
-    def __init__(self, requires_grad=False, pretrained=True):
-        super(squeezenet, self).__init__()
-        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
-        self.slice1 = torch.nn.Sequential()
-        self.slice2 = torch.nn.Sequential()
-        self.slice3 = torch.nn.Sequential()
-        self.slice4 = torch.nn.Sequential()
-        self.slice5 = torch.nn.Sequential()
-        self.slice6 = torch.nn.Sequential()
-        self.slice7 = torch.nn.Sequential()
-        self.N_slices = 7
-        for x in range(2):
-            self.slice1.add_module(str(x), pretrained_features[x])
-        for x in range(2, 5):
-            self.slice2.add_module(str(x), pretrained_features[x])
-        for x in range(5, 8):
-            self.slice3.add_module(str(x), pretrained_features[x])
-        for x in range(8, 10):
-            self.slice4.add_module(str(x), pretrained_features[x])
-        for x in range(10, 11):
-            self.slice5.add_module(str(x), pretrained_features[x])
-        for x in range(11, 12):
-            self.slice6.add_module(str(x), pretrained_features[x])
-        for x in range(12, 13):
-            self.slice7.add_module(str(x), pretrained_features[x])
-        if not requires_grad:
-            for param in self.parameters():
-                param.requires_grad = False
-
-    def forward(self, X):
-        h = self.slice1(X)
-        h_relu1 = h
-        h = self.slice2(h)
-        h_relu2 = h
-        h = self.slice3(h)
-        h_relu3 = h
-        h = self.slice4(h)
-        h_relu4 = h
-        h = self.slice5(h)
-        h_relu5 = h
-        h = self.slice6(h)
-        h_relu6 = h
-        h = self.slice7(h)
-        h_relu7 = h
-        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
-        out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
-
-        return out
-
-
-class alexnet(torch.nn.Module):
-    def __init__(self, requires_grad=False, pretrained=True):
-        super(alexnet, self).__init__()
-        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
-        self.slice1 = torch.nn.Sequential()
-        self.slice2 = torch.nn.Sequential()
-        self.slice3 = torch.nn.Sequential()
-        self.slice4 = torch.nn.Sequential()
-        self.slice5 = torch.nn.Sequential()
-        self.N_slices = 5
-        for x in range(2):
-            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
-        for x in range(2, 5):
-            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
-        for x in range(5, 8):
-            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
-        for x in range(8, 10):
-            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
-        for x in range(10, 12):
-            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
-        if not requires_grad:
-            for param in self.parameters():
-                param.requires_grad = False
-
-    def forward(self, X):
-        h = self.slice1(X)
-        h_relu1 = h
-        h = self.slice2(h)
-        h_relu2 = h
-        h = self.slice3(h)
-        h_relu3 = h
-        h = self.slice4(h)
-        h_relu4 = h
-        h = self.slice5(h)
-        h_relu5 = h
-        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
-        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
-
-        return out
-
-class vgg16(torch.nn.Module):
-    def __init__(self, requires_grad=False, pretrained=True):
-        super(vgg16, self).__init__()
-        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
-        self.slice1 = torch.nn.Sequential()
-        self.slice2 = torch.nn.Sequential()
-        self.slice3 = torch.nn.Sequential()
-        self.slice4 = torch.nn.Sequential()
-        self.slice5 = torch.nn.Sequential()
-        self.N_slices = 5
-        for x in range(4):
-            self.slice1.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(4, 9):
-            self.slice2.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(9, 16):
-            self.slice3.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(16, 23):
-            self.slice4.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(23, 30):
-            self.slice5.add_module(str(x), vgg_pretrained_features[x])
-        if not requires_grad:
-            for param in self.parameters():
-                param.requires_grad = False
-
-    def forward(self, X):
-        h = self.slice1(X)
-        h_relu1_2 = h
-        h = self.slice2(h)
-        h_relu2_2 = h
-        h = self.slice3(h)
-        h_relu3_3 = h
-        h = self.slice4(h)
-        h_relu4_3 = h
-        h = self.slice5(h)
-        h_relu5_3 = h
-        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
-        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
-
-        return out
-
-
-
-class resnet(torch.nn.Module):
-    def __init__(self, requires_grad=False, pretrained=True, num=18):
-        super(resnet, self).__init__()
-        if(num == 18):
-            self.net = tv.resnet18(pretrained=pretrained)
-        elif(num == 34):
-            self.net = tv.resnet34(pretrained=pretrained)
-        elif(num == 50):
-            self.net = tv.resnet50(pretrained=pretrained)
-        elif(num == 101):
-            self.net = tv.resnet101(pretrained=pretrained)
-        elif(num == 152):
-            self.net = tv.resnet152(pretrained=pretrained)
-        self.N_slices = 5
-
-        self.conv1 = self.net.conv1
-        self.bn1 = self.net.bn1
-        self.relu = self.net.relu
-        self.maxpool = self.net.maxpool
-        self.layer1 = self.net.layer1
-        self.layer2 = self.net.layer2
-        self.layer3 = self.net.layer3
-        self.layer4 = self.net.layer4
-
-    def forward(self, X):
-        h = self.conv1(X)
-        h = self.bn1(h)
-        h = self.relu(h)
-        h_relu1 = h
-        h = self.maxpool(h)
-        h = self.layer1(h)
-        h_conv2 = h
-        h = self.layer2(h)
-        h_conv3 = h
-        h = self.layer3(h)
-        h_conv4 = h
-        h = self.layer4(h)
-        h_conv5 = h
-
-        outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
-        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
-
-        return out
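A quick sketch (random stand-in input; torchvision weights download on first use) of how these slice wrappers are consumed: each forward pass returns a namedtuple of intermediate activations.

    import torch
    from lpips.pretrained_networks import vgg16

    net = vgg16(pretrained=True, requires_grad=False).eval()
    outs = net(torch.rand(1, 3, 64, 64))
    for name, t in zip(outs._fields, outs):
        print(name, tuple(t.shape))  # relu1_2 ... relu5_3 feature maps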
lpips/weights/v0.0/alex.pth
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
-size 5455
lpips/weights/v0.0/squeeze.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
|
3 |
-
size 10057
|
|
|
|
|
|
|
|
lpips/weights/v0.0/vgg.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
|
3 |
-
size 6735
|
|
|
|
|
|
|
|
lpips/weights/v0.1/alex.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
|
3 |
-
size 6009
|
|
|
|
|
|
|
|
lpips/weights/v0.1/squeeze.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
|
3 |
-
size 10811
|
|
|
|
|
|
|
|
lpips/weights/v0.1/vgg.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
-
size 7289
|
model.py
DELETED
@@ -1,698 +0,0 @@
-import math
-import random
-import functools
-import operator
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.autograd import Function
-
-from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
-
-
-class PixelNorm(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input):
-        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
-
-
-def make_kernel(k):
-    k = torch.tensor(k, dtype=torch.float32)
-
-    if k.ndim == 1:
-        k = k[None, :] * k[:, None]
-
-    k /= k.sum()
-
-    return k
-
-
-class Upsample(nn.Module):
-    def __init__(self, kernel, factor=2):
-        super().__init__()
-
-        self.factor = factor
-        kernel = make_kernel(kernel) * (factor ** 2)
-        self.register_buffer("kernel", kernel)
-
-        p = kernel.shape[0] - factor
-
-        pad0 = (p + 1) // 2 + factor - 1
-        pad1 = p // 2
-
-        self.pad = (pad0, pad1)
-
-    def forward(self, input):
-        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
-
-        return out
-
-
-class Downsample(nn.Module):
-    def __init__(self, kernel, factor=2):
-        super().__init__()
-
-        self.factor = factor
-        kernel = make_kernel(kernel)
-        self.register_buffer("kernel", kernel)
-
-        p = kernel.shape[0] - factor
-
-        pad0 = (p + 1) // 2
-        pad1 = p // 2
-
-        self.pad = (pad0, pad1)
-
-    def forward(self, input):
-        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
-
-        return out
-
-
-class Blur(nn.Module):
-    def __init__(self, kernel, pad, upsample_factor=1):
-        super().__init__()
-
-        kernel = make_kernel(kernel)
-
-        if upsample_factor > 1:
-            kernel = kernel * (upsample_factor ** 2)
-
-        self.register_buffer("kernel", kernel)
-
-        self.pad = pad
-
-    def forward(self, input):
-        out = upfirdn2d(input, self.kernel, pad=self.pad)
-
-        return out
-
-
-class EqualConv2d(nn.Module):
-    def __init__(
-        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
-    ):
-        super().__init__()
-
-        self.weight = nn.Parameter(
-            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
-        )
-        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
-
-        self.stride = stride
-        self.padding = padding
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(out_channel))
-
-        else:
-            self.bias = None
-
-    def forward(self, input):
-        out = conv2d_gradfix.conv2d(
-            input,
-            self.weight * self.scale,
-            bias=self.bias,
-            stride=self.stride,
-            padding=self.padding,
-        )
-
-        return out
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
-            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
-        )
-
-
-class EqualLinear(nn.Module):
-    def __init__(
-        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
-    ):
-        super().__init__()
-
-        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
-
-        else:
-            self.bias = None
-
-        self.activation = activation
-
-        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
-        self.lr_mul = lr_mul
-
-    def forward(self, input):
-        if self.activation:
-            out = F.linear(input, self.weight * self.scale)
-            out = fused_leaky_relu(out, self.bias * self.lr_mul)
-
-        else:
-            out = F.linear(
-                input, self.weight * self.scale, bias=self.bias * self.lr_mul
-            )
-
-        return out
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
-        )
-
-
-class ModulatedConv2d(nn.Module):
-    def __init__(
-        self,
-        in_channel,
-        out_channel,
-        kernel_size,
-        style_dim,
-        demodulate=True,
-        upsample=False,
-        downsample=False,
-        blur_kernel=[1, 3, 3, 1],
-        fused=True,
-    ):
-        super().__init__()
-
-        self.eps = 1e-8
-        self.kernel_size = kernel_size
-        self.in_channel = in_channel
-        self.out_channel = out_channel
-        self.upsample = upsample
-        self.downsample = downsample
-
-        if upsample:
-            factor = 2
-            p = (len(blur_kernel) - factor) - (kernel_size - 1)
-            pad0 = (p + 1) // 2 + factor - 1
-            pad1 = p // 2 + 1
-
-            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
-
-        if downsample:
-            factor = 2
-            p = (len(blur_kernel) - factor) + (kernel_size - 1)
-            pad0 = (p + 1) // 2
-            pad1 = p // 2
-
-            self.blur = Blur(blur_kernel, pad=(pad0, pad1))
-
-        fan_in = in_channel * kernel_size ** 2
-        self.scale = 1 / math.sqrt(fan_in)
-        self.padding = kernel_size // 2
-
-        self.weight = nn.Parameter(
-            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
-        )
-
-        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
-
-        self.demodulate = demodulate
-        self.fused = fused
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
-            f"upsample={self.upsample}, downsample={self.downsample})"
-        )
-
-    def forward(self, input, style):
-        batch, in_channel, height, width = input.shape
-
-        if not self.fused:
-            weight = self.scale * self.weight.squeeze(0)
-            style = self.modulation(style)
-
-            if self.demodulate:
-                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
-                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
-
-            input = input * style.reshape(batch, in_channel, 1, 1)
-
-            if self.upsample:
-                weight = weight.transpose(0, 1)
-                out = conv2d_gradfix.conv_transpose2d(
-                    input, weight, padding=0, stride=2
-                )
-                out = self.blur(out)
-
-            elif self.downsample:
-                input = self.blur(input)
-                out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
-
-            else:
-                out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
-
-            if self.demodulate:
-                out = out * dcoefs.view(batch, -1, 1, 1)
-
-            return out
-
-        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
-        weight = self.scale * self.weight * style
-
-        if self.demodulate:
-            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
-            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
-
-        weight = weight.view(
-            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
-        )
-
-        if self.upsample:
-            input = input.view(1, batch * in_channel, height, width)
-            weight = weight.view(
-                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
-            )
-            weight = weight.transpose(1, 2).reshape(
-                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
-            )
-            out = conv2d_gradfix.conv_transpose2d(
-                input, weight, padding=0, stride=2, groups=batch
-            )
-            _, _, height, width = out.shape
-            out = out.view(batch, self.out_channel, height, width)
-            out = self.blur(out)
-
-        elif self.downsample:
-            input = self.blur(input)
-            _, _, height, width = input.shape
-            input = input.view(1, batch * in_channel, height, width)
-            out = conv2d_gradfix.conv2d(
-                input, weight, padding=0, stride=2, groups=batch
-            )
-            _, _, height, width = out.shape
-            out = out.view(batch, self.out_channel, height, width)
-
-        else:
-            input = input.view(1, batch * in_channel, height, width)
-            out = conv2d_gradfix.conv2d(
-                input, weight, padding=self.padding, groups=batch
-            )
-            _, _, height, width = out.shape
-            out = out.view(batch, self.out_channel, height, width)
-
-        return out
-
-
-class NoiseInjection(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        self.weight = nn.Parameter(torch.zeros(1))
-
-    def forward(self, image, noise=None):
-        if noise is None:
-            batch, _, height, width = image.shape
-            noise = image.new_empty(batch, 1, height, width).normal_()
-
-        return image + self.weight * noise
-
-
-class ConstantInput(nn.Module):
-    def __init__(self, channel, size=4):
-        super().__init__()
-
-        self.input = nn.Parameter(torch.randn(1, channel, size, size))
-
-    def forward(self, input):
-        batch = input.shape[0]
-        out = self.input.repeat(batch, 1, 1, 1)
-
-        return out
-
-
-class StyledConv(nn.Module):
-    def __init__(
-        self,
-        in_channel,
-        out_channel,
-        kernel_size,
-        style_dim,
-        upsample=False,
-        blur_kernel=[1, 3, 3, 1],
-        demodulate=True,
-    ):
-        super().__init__()
-
-        self.conv = ModulatedConv2d(
-            in_channel,
-            out_channel,
-            kernel_size,
-            style_dim,
-            upsample=upsample,
-            blur_kernel=blur_kernel,
-            demodulate=demodulate,
-        )
-
-        self.noise = NoiseInjection()
-        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
-        # self.activate = ScaledLeakyReLU(0.2)
-        self.activate = FusedLeakyReLU(out_channel)
-
-    def forward(self, input, style, noise=None):
-        out = self.conv(input, style)
-        out = self.noise(out, noise=noise)
-        # out = out + self.bias
-        out = self.activate(out)
-
-        return out
-
-
-class ToRGB(nn.Module):
-    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
-        super().__init__()
-
-        if upsample:
-            self.upsample = Upsample(blur_kernel)
-
-        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
-        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
-
-    def forward(self, input, style, skip=None):
-        out = self.conv(input, style)
-        out = out + self.bias
-
-        if skip is not None:
-            skip = self.upsample(skip)
-
-            out = out + skip
-
-        return out
-
-
-class Generator(nn.Module):
-    def __init__(
-        self,
-        size,
-        style_dim,
-        n_mlp,
-        channel_multiplier=2,
-        blur_kernel=[1, 3, 3, 1],
-        lr_mlp=0.01,
-    ):
-        super().__init__()
-
-        self.size = size
-
-        self.style_dim = style_dim
-
-        layers = [PixelNorm()]
-
-        for i in range(n_mlp):
-            layers.append(
-                EqualLinear(
-                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
-                )
-            )
-
-        self.style = nn.Sequential(*layers)
-
-        self.channels = {
-            4: 512,
-            8: 512,
-            16: 512,
-            32: 512,
-            64: 256 * channel_multiplier,
-            128: 128 * channel_multiplier,
-            256: 64 * channel_multiplier,
-            512: 32 * channel_multiplier,
-            1024: 16 * channel_multiplier,
-        }
-
-        self.input = ConstantInput(self.channels[4])
-        self.conv1 = StyledConv(
-            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
-        )
-        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
-
-        self.log_size = int(math.log(size, 2))
-        self.num_layers = (self.log_size - 2) * 2 + 1
-
-        self.convs = nn.ModuleList()
-        self.upsamples = nn.ModuleList()
-        self.to_rgbs = nn.ModuleList()
-        self.noises = nn.Module()
-
-        in_channel = self.channels[4]
-
-        for layer_idx in range(self.num_layers):
-            res = (layer_idx + 5) // 2
-            shape = [1, 1, 2 ** res, 2 ** res]
-            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
-
-        for i in range(3, self.log_size + 1):
-            out_channel = self.channels[2 ** i]
-
-            self.convs.append(
-                StyledConv(
-                    in_channel,
-                    out_channel,
-                    3,
-                    style_dim,
-                    upsample=True,
-                    blur_kernel=blur_kernel,
-                )
-            )
-
-            self.convs.append(
-                StyledConv(
-                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
-                )
-            )
-
-            self.to_rgbs.append(ToRGB(out_channel, style_dim))
-
-            in_channel = out_channel
-
-        self.n_latent = self.log_size * 2 - 2
-
-    def make_noise(self):
-        device = self.input.input.device
-
-        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
-
-        for i in range(3, self.log_size + 1):
-            for _ in range(2):
-                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
-
-        return noises
-
-    def mean_latent(self, n_latent):
-        latent_in = torch.randn(
-            n_latent, self.style_dim, device=self.input.input.device
-        )
-        latent = self.style(latent_in).mean(0, keepdim=True)
-
-        return latent
-
-    def get_latent(self, input):
-        return self.style(input)
-
-    def forward(
-        self,
-        styles,
-        return_latents=False,
-        inject_index=None,
-        truncation=1,
-        truncation_latent=None,
-        input_is_latent=False,
-        noise=None,
-        randomize_noise=True,
-    ):
-        if not input_is_latent:
-            styles = [self.style(s) for s in styles]
-
-        if noise is None:
-            if randomize_noise:
-                noise = [None] * self.num_layers
-            else:
-                noise = [
-                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
-                ]
-
-        if truncation < 1:
-            style_t = []
-
-            for style in styles:
-                style_t.append(
-                    truncation_latent + truncation * (style - truncation_latent)
-                )
-
-            styles = style_t
-
-        if len(styles) < 2:
-            inject_index = self.n_latent
-
-            if styles[0].ndim < 3:
-                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
-
-            else:
-                latent = styles[0]
-
-        else:
-            if inject_index is None:
-                inject_index = random.randint(1, self.n_latent - 1)
-
-            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
-            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
-
-            latent = torch.cat([latent, latent2], 1)
-
-        out = self.input(latent)
-        out = self.conv1(out, latent[:, 0], noise=noise[0])
-
-        skip = self.to_rgb1(out, latent[:, 1])
-
-        i = 1
-        for conv1, conv2, noise1, noise2, to_rgb in zip(
-            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
-        ):
-            out = conv1(out, latent[:, i], noise=noise1)
-            out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
-
-            i += 2
-
-        image = skip
-
-        if return_latents:
-            return image, latent
-
-        else:
-            return image, None
-
-
-class ConvLayer(nn.Sequential):
-    def __init__(
-        self,
-        in_channel,
-        out_channel,
-        kernel_size,
-        downsample=False,
-        blur_kernel=[1, 3, 3, 1],
-        bias=True,
-        activate=True,
-    ):
-        layers = []
-
-        if downsample:
-            factor = 2
-            p = (len(blur_kernel) - factor) + (kernel_size - 1)
-            pad0 = (p + 1) // 2
-            pad1 = p // 2
-
-            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
-
-            stride = 2
-            self.padding = 0
-
-        else:
-            stride = 1
-            self.padding = kernel_size // 2
-
-        layers.append(
-            EqualConv2d(
-                in_channel,
-                out_channel,
-                kernel_size,
-                padding=self.padding,
-                stride=stride,
-                bias=bias and not activate,
-            )
-        )
-
-        if activate:
-            layers.append(FusedLeakyReLU(out_channel, bias=bias))
-
-        super().__init__(*layers)
-
-
-class ResBlock(nn.Module):
-    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
-        super().__init__()
-
-        self.conv1 = ConvLayer(in_channel, in_channel, 3)
-        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
-
-        self.skip = ConvLayer(
-            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
-        )
-
-    def forward(self, input):
-        out = self.conv1(input)
-        out = self.conv2(out)
-
-        skip = self.skip(input)
-        out = (out + skip) / math.sqrt(2)
-
-        return out
-
-
-class Discriminator(nn.Module):
-    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
-        super().__init__()
-
-        channels = {
-            4: 512,
-            8: 512,
-            16: 512,
-            32: 512,
-            64: 256 * channel_multiplier,
-            128: 128 * channel_multiplier,
-            256: 64 * channel_multiplier,
-            512: 32 * channel_multiplier,
-            1024: 16 * channel_multiplier,
-        }
-
-        convs = [ConvLayer(3, channels[size], 1)]
-
-        log_size = int(math.log(size, 2))
-
-        in_channel = channels[size]
-
-        for i in range(log_size, 2, -1):
-            out_channel = channels[2 ** (i - 1)]
-
-            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
-
-            in_channel = out_channel
-
-        self.convs = nn.Sequential(*convs)
-
-        self.stddev_group = 4
-        self.stddev_feat = 1
-
-        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
-        self.final_linear = nn.Sequential(
-            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
-            EqualLinear(channels[4], 1),
-        )
-
-    def forward(self, input):
-        out = self.convs(input)
-
-        batch, channel, height, width = out.shape
-        group = min(batch, self.stddev_group)
-        stddev = out.view(
-            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
-        )
-        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
-        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
-        stddev = stddev.repeat(group, 1, height, width)
-        out = torch.cat([out, stddev], 1)
-
-        out = self.final_conv(out)
-
-        out = out.view(batch, -1)
-        out = self.final_linear(out)
-
-        return out
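
For readers reconstructing the training setup, a minimal sketch of how this Generator/Discriminator pair is driven (hypothetical code, not part of the commit; it presumes the op extensions build on import, which needs a CUDA toolchain for the fused kernels):

    import torch
    from model import Generator, Discriminator

    g = Generator(size=256, style_dim=512, n_mlp=8)
    d = Discriminator(size=256)

    z = torch.randn(4, 512)
    images, _ = g([z])            # styles are passed as a list of z batches
    scores = d(images)            # (4, 1) realism logits

    # Style mixing: two latents trigger a random crossover (inject) index.
    mixed, _ = g([z, torch.randn(4, 512)])

    # Truncation trick: pull w toward the mean latent for cleaner samples.
    mean_w = g.mean_latent(4096)
    truncated, _ = g([z], truncation=0.7, truncation_latent=mean_w)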
non_leaking.py
DELETED
@@ -1,465 +0,0 @@
-import math
-
-import torch
-from torch import autograd
-from torch.nn import functional as F
-import numpy as np
-
-from distributed import reduce_sum
-from op import upfirdn2d
-
-
-class AdaptiveAugment:
-    def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
-        self.ada_aug_target = ada_aug_target
-        self.ada_aug_len = ada_aug_len
-        self.update_every = update_every
-
-        self.ada_update = 0
-        self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
-        self.r_t_stat = 0
-        self.ada_aug_p = 0
-
-    @torch.no_grad()
-    def tune(self, real_pred):
-        self.ada_aug_buf += torch.tensor(
-            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
-            device=real_pred.device,
-        )
-        self.ada_update += 1
-
-        if self.ada_update % self.update_every == 0:
-            self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
-            pred_signs, n_pred = self.ada_aug_buf.tolist()
-
-            self.r_t_stat = pred_signs / n_pred
-
-            if self.r_t_stat > self.ada_aug_target:
-                sign = 1
-
-            else:
-                sign = -1
-
-            self.ada_aug_p += sign * n_pred / self.ada_aug_len
-            self.ada_aug_p = min(1, max(0, self.ada_aug_p))
-            self.ada_aug_buf.mul_(0)
-            self.ada_update = 0
-
-        return self.ada_aug_p
-
-
-SYM6 = (
-    0.015404109327027373,
-    0.0034907120842174702,
-    -0.11799011114819057,
-    -0.048311742585633,
-    0.4910559419267466,
-    0.787641141030194,
-    0.3379294217276218,
-    -0.07263752278646252,
-    -0.021060292512300564,
-    0.04472490177066578,
-    0.0017677118642428036,
-    -0.007800708325034148,
-)
-
-
-def translate_mat(t_x, t_y, device="cpu"):
-    batch = t_x.shape[0]
-
-    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
-    translate = torch.stack((t_x, t_y), 1)
-    mat[:, :2, 2] = translate
-
-    return mat
-
-
-def rotate_mat(theta, device="cpu"):
-    batch = theta.shape[0]
-
-    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
-    sin_t = torch.sin(theta)
-    cos_t = torch.cos(theta)
-    rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
-    mat[:, :2, :2] = rot
-
-    return mat
-
-
-def scale_mat(s_x, s_y, device="cpu"):
-    batch = s_x.shape[0]
-
-    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
-    mat[:, 0, 0] = s_x
-    mat[:, 1, 1] = s_y
-
-    return mat
-
-
-def translate3d_mat(t_x, t_y, t_z):
-    batch = t_x.shape[0]
-
-    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
-    translate = torch.stack((t_x, t_y, t_z), 1)
-    mat[:, :3, 3] = translate
-
-    return mat
-
-
-def rotate3d_mat(axis, theta):
-    batch = theta.shape[0]
-
-    u_x, u_y, u_z = axis
-
-    eye = torch.eye(3).unsqueeze(0)
-    cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
-    outer = torch.tensor(axis)
-    outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
-
-    sin_t = torch.sin(theta).view(-1, 1, 1)
-    cos_t = torch.cos(theta).view(-1, 1, 1)
-
-    rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
-
-    eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
-    eye_4[:, :3, :3] = rot
-
-    return eye_4
-
-
-def scale3d_mat(s_x, s_y, s_z):
-    batch = s_x.shape[0]
-
-    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
-    mat[:, 0, 0] = s_x
-    mat[:, 1, 1] = s_y
-    mat[:, 2, 2] = s_z
-
-    return mat
-
-
-def luma_flip_mat(axis, i):
-    batch = i.shape[0]
-
-    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
-    axis = torch.tensor(axis + (0,))
-    flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
-
-    return eye - flip
-
-
-def saturation_mat(axis, i):
-    batch = i.shape[0]
-
-    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
-    axis = torch.tensor(axis + (0,))
-    axis = torch.ger(axis, axis)
-    saturate = axis + (eye - axis) * i.view(-1, 1, 1)
-
-    return saturate
-
-
-def lognormal_sample(size, mean=0, std=1, device="cpu"):
-    return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
-
-
-def category_sample(size, categories, device="cpu"):
-    category = torch.tensor(categories, device=device)
-    sample = torch.randint(high=len(categories), size=(size,), device=device)
-
-    return category[sample]
-
-
-def uniform_sample(size, low, high, device="cpu"):
-    return torch.empty(size, device=device).uniform_(low, high)
-
-
-def normal_sample(size, mean=0, std=1, device="cpu"):
-    return torch.empty(size, device=device).normal_(mean, std)
-
-
-def bernoulli_sample(size, p, device="cpu"):
-    return torch.empty(size, device=device).bernoulli_(p)
-
-
-def random_mat_apply(p, transform, prev, eye, device="cpu"):
-    size = transform.shape[0]
-    select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
-    select_transform = select * transform + (1 - select) * eye
-
-    return select_transform @ prev
-
-
-def sample_affine(p, size, height, width, device="cpu"):
-    G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
-    eye = G
-
-    # flip
-    param = category_sample(size, (0, 1))
-    Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
-    G = random_mat_apply(p, Gc, G, eye, device=device)
-    # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
-
-    # 90 rotate
-    param = category_sample(size, (0, 3))
-    Gc = rotate_mat(-math.pi / 2 * param, device=device)
-    G = random_mat_apply(p, Gc, G, eye, device=device)
-    # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
-
-    # integer translate
-    param = uniform_sample((2, size), -0.125, 0.125)
-    param_height = torch.round(param[0] * height)
-    param_width = torch.round(param[1] * width)
-    Gc = translate_mat(param_width, param_height, device=device)
-    G = random_mat_apply(p, Gc, G, eye, device=device)
-    # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
-
-    # isotropic scale
-    param = lognormal_sample(size, std=0.2 * math.log(2))
-    Gc = scale_mat(param, param, device=device)
-    G = random_mat_apply(p, Gc, G, eye, device=device)
-    # print('isotropic scale', G, scale_mat(param, param), sep='\n')
-
-    p_rot = 1 - math.sqrt(1 - p)
-
-    # pre-rotate
-    param = uniform_sample(size, -math.pi, math.pi)
-    Gc = rotate_mat(-param, device=device)
-    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
-    # print('pre-rotate', G, rotate_mat(-param), sep='\n')
-
-    # anisotropic scale
-    param = lognormal_sample(size, std=0.2 * math.log(2))
-    Gc = scale_mat(param, 1 / param, device=device)
-    G = random_mat_apply(p, Gc, G, eye, device=device)
-    # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
-
-    # post-rotate
-    param = uniform_sample(size, -math.pi, math.pi)
-    Gc = rotate_mat(-param, device=device)
-    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
-    # print('post-rotate', G, rotate_mat(-param), sep='\n')
-
-    # fractional translate
-    param = normal_sample((2, size), std=0.125)
-    Gc = translate_mat(param[1] * width, param[0] * height, device=device)
-    G = random_mat_apply(p, Gc, G, eye, device=device)
-    # print('fractional translate', G, translate_mat(param, param), sep='\n')
-
-    return G
-
-
-def sample_color(p, size):
-    C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
-    eye = C
-    axis_val = 1 / math.sqrt(3)
-    axis = (axis_val, axis_val, axis_val)
-
-    # brightness
-    param = normal_sample(size, std=0.2)
-    Cc = translate3d_mat(param, param, param)
-    C = random_mat_apply(p, Cc, C, eye)
-
-    # contrast
-    param = lognormal_sample(size, std=0.5 * math.log(2))
-    Cc = scale3d_mat(param, param, param)
-    C = random_mat_apply(p, Cc, C, eye)
-
-    # luma flip
-    param = category_sample(size, (0, 1))
-    Cc = luma_flip_mat(axis, param)
-    C = random_mat_apply(p, Cc, C, eye)
-
-    # hue rotation
-    param = uniform_sample(size, -math.pi, math.pi)
-    Cc = rotate3d_mat(axis, param)
-    C = random_mat_apply(p, Cc, C, eye)
-
-    # saturation
-    param = lognormal_sample(size, std=1 * math.log(2))
-    Cc = saturation_mat(axis, param)
-    C = random_mat_apply(p, Cc, C, eye)
-
-    return C
-
-
-def make_grid(shape, x0, x1, y0, y1, device):
-    n, c, h, w = shape
-    grid = torch.empty(n, h, w, 3, device=device)
-    grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
-    grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
-    grid[:, :, :, 2] = 1
-
-    return grid
-
-
-def affine_grid(grid, mat):
-    n, h, w, _ = grid.shape
-    return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
-
-
-def get_padding(G, height, width, kernel_size):
-    device = G.device
-
-    cx = (width - 1) / 2
-    cy = (height - 1) / 2
-    cp = torch.tensor(
-        [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
-    )
-    cp = G @ cp.T
-
-    pad_k = kernel_size // 4
-
-    pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
-    pad = torch.cat((-pad, pad)).max(1).values
-    pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
-    pad = pad.max(torch.tensor([0, 0] * 2, device=device))
-    pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
-
-    pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
-
-    return pad_x1, pad_x2, pad_y1, pad_y2
-
-
-def try_sample_affine_and_pad(img, p, kernel_size, G=None):
-    batch, _, height, width = img.shape
-
-    G_try = G
-
-    if G is None:
-        G_try = torch.inverse(sample_affine(p, batch, height, width))
-
-    pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
-
-    img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
-
-    return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
-
-
-class GridSampleForward(autograd.Function):
-    @staticmethod
-    def forward(ctx, input, grid):
-        out = F.grid_sample(
-            input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        ctx.save_for_backward(input, grid)
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, grid = ctx.saved_tensors
-        grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
-
-        return grad_input, grad_grid
-
-
-class GridSampleBackward(autograd.Function):
-    @staticmethod
-    def forward(ctx, grad_output, input, grid):
-        op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
-        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
-        ctx.save_for_backward(grid)
-
-        return grad_input, grad_grid
-
-    @staticmethod
-    def backward(ctx, grad_grad_input, grad_grad_grid):
-        (grid,) = ctx.saved_tensors
-        grad_grad_output = None
-
-        if ctx.needs_input_grad[0]:
-            grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
-
-        return grad_grad_output, None, None
-
-
-grid_sample = GridSampleForward.apply
-
-
-def scale_mat_single(s_x, s_y):
-    return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
-
-
-def translate_mat_single(t_x, t_y):
-    return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
-
-
-def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
-    kernel = antialiasing_kernel
-    len_k = len(kernel)
-
-    kernel = torch.as_tensor(kernel).to(img)
-    # kernel = torch.ger(kernel, kernel).to(img)
-    kernel_flip = torch.flip(kernel, (0,))
-
-    img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
-        img, p, len_k, G
-    )
-
-    G_inv = (
-        translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
-        @ G
-    )
-    up_pad = (
-        (len_k + 2 - 1) // 2,
-        (len_k - 2) // 2,
-        (len_k + 2 - 1) // 2,
-        (len_k - 2) // 2,
-    )
-    img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
-    img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
-    G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
-    G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
-    batch_size, channel, height, width = img.shape
-    pad_k = len_k // 4
-    shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
-    G_inv = (
-        scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
-        @ G_inv
-        @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
-    )
-    grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
-    img_affine = grid_sample(img_2x, grid)
-    d_p = -pad_k * 2
-    down_pad = (
-        d_p + (len_k - 2 + 1) // 2,
-        d_p + (len_k - 2) // 2,
-        d_p + (len_k - 2 + 1) // 2,
-        d_p + (len_k - 2) // 2,
-    )
-    img_down = upfirdn2d(
-        img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
-    )
-    img_down = upfirdn2d(
-        img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
-    )
-
-    return img_down, G
-
-
-def apply_color(img, mat):
-    batch = img.shape[0]
-    img = img.permute(0, 2, 3, 1)
-    mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
-    mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
-    img = img @ mat_mul + mat_add
-    img = img.permute(0, 3, 1, 2)
-
-    return img
-
-
-def random_apply_color(img, p, C=None):
-    if C is None:
-        C = sample_color(p, img.shape[0])
-
-    img = apply_color(img, C.to(img))
-
-    return img, C
-
-
-def augment(img, p, transform_matrix=(None, None)):
-    img, G = random_apply_affine(img, p, transform_matrix[0])
-    img, C = random_apply_color(img, p, transform_matrix[1])
-
-    return img, (G, C)
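
A sketch of how augment() and AdaptiveAugment slot into a discriminator step (hypothetical single-GPU wiring, not part of the commit; reduce_sum comes from the repo's distributed helpers and passes the tensor through in single-process runs):

    import torch
    from model import Discriminator
    from non_leaking import augment, AdaptiveAugment

    device = "cuda"
    d = Discriminator(256).to(device)
    ada = AdaptiveAugment(ada_aug_target=0.6, ada_aug_len=500 * 1000,
                          update_every=8, device=device)
    p = 0.0  # current augmentation probability, tuned online

    real = torch.randn(4, 3, 256, 256, device=device)  # stand-in for a real batch
    real_aug, (G, C) = augment(real, p)   # same p must be applied to fake batches
    real_pred = d(real_aug)
    p = ada.tune(real_pred)               # r_t sign heuristic nudges p up or down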
op/__init__.py
DELETED
@@ -1,2 +0,0 @@
-from .fused_act import FusedLeakyReLU, fused_leaky_relu
-from .upfirdn2d import upfirdn2d
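
Note that model.py also does `from op import conv2d_gradfix` even though it is not re-exported here; Python resolves that as a submodule import when the name is not an attribute of the package. A small illustration (hypothetical, not part of the commit):

    import torch
    from op import conv2d_gradfix  # falls back to importing op/conv2d_gradfix.py

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    y = conv2d_gradfix.conv2d(x, w, padding=1)  # CPU input -> plain F.conv2d path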
op/conv2d_gradfix.py
DELETED
@@ -1,227 +0,0 @@
-import contextlib
-import warnings
-
-import torch
-from torch import autograd
-from torch.nn import functional as F
-
-enabled = True
-weight_gradients_disabled = False
-
-
-@contextlib.contextmanager
-def no_weight_gradients():
-    global weight_gradients_disabled
-
-    old = weight_gradients_disabled
-    weight_gradients_disabled = True
-    yield
-    weight_gradients_disabled = old
-
-
-def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if could_use_op(input):
-        return conv2d_gradfix(
-            transpose=False,
-            weight_shape=weight.shape,
-            stride=stride,
-            padding=padding,
-            output_padding=0,
-            dilation=dilation,
-            groups=groups,
-        ).apply(input, weight, bias)
-
-    return F.conv2d(
-        input=input,
-        weight=weight,
-        bias=bias,
-        stride=stride,
-        padding=padding,
-        dilation=dilation,
-        groups=groups,
-    )
-
-
-def conv_transpose2d(
-    input,
-    weight,
-    bias=None,
-    stride=1,
-    padding=0,
-    output_padding=0,
-    groups=1,
-    dilation=1,
-):
-    if could_use_op(input):
-        return conv2d_gradfix(
-            transpose=True,
-            weight_shape=weight.shape,
-            stride=stride,
-            padding=padding,
-            output_padding=output_padding,
-            groups=groups,
-            dilation=dilation,
-        ).apply(input, weight, bias)
-
-    return F.conv_transpose2d(
-        input=input,
-        weight=weight,
-        bias=bias,
-        stride=stride,
-        padding=padding,
-        output_padding=output_padding,
-        dilation=dilation,
-        groups=groups,
-    )
-
-
-def could_use_op(input):
-    if (not enabled) or (not torch.backends.cudnn.enabled):
-        return False
-
-    if input.device.type != "cuda":
-        return False
-
-    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
-        return True
-
-    warnings.warn(
-        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
-    )
-
-    return False
-
-
-def ensure_tuple(xs, ndim):
-    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
-
-    return xs
-
-
-conv2d_gradfix_cache = dict()
-
-
-def conv2d_gradfix(
-    transpose, weight_shape, stride, padding, output_padding, dilation, groups
-):
-    ndim = 2
-    weight_shape = tuple(weight_shape)
-    stride = ensure_tuple(stride, ndim)
-    padding = ensure_tuple(padding, ndim)
-    output_padding = ensure_tuple(output_padding, ndim)
-    dilation = ensure_tuple(dilation, ndim)
-
-    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
-    if key in conv2d_gradfix_cache:
-        return conv2d_gradfix_cache[key]
-
-    common_kwargs = dict(
-        stride=stride, padding=padding, dilation=dilation, groups=groups
-    )
-
-    def calc_output_padding(input_shape, output_shape):
-        if transpose:
-            return [0, 0]
-
-        return [
-            input_shape[i + 2]
-            - (output_shape[i + 2] - 1) * stride[i]
-            - (1 - 2 * padding[i])
-            - dilation[i] * (weight_shape[i + 2] - 1)
-            for i in range(ndim)
-        ]
-
-    class Conv2d(autograd.Function):
-        @staticmethod
-        def forward(ctx, input, weight, bias):
-            if not transpose:
-                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
-
-            else:
-                out = F.conv_transpose2d(
-                    input=input,
-                    weight=weight,
-                    bias=bias,
-                    output_padding=output_padding,
-                    **common_kwargs,
-                )
-
-            ctx.save_for_backward(input, weight)
-
-            return out
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            input, weight = ctx.saved_tensors
-            grad_input, grad_weight, grad_bias = None, None, None
-
-            if ctx.needs_input_grad[0]:
-                p = calc_output_padding(
-                    input_shape=input.shape, output_shape=grad_output.shape
-                )
-                grad_input = conv2d_gradfix(
-                    transpose=(not transpose),
-                    weight_shape=weight_shape,
-                    output_padding=p,
-                    **common_kwargs,
-                ).apply(grad_output, weight, None)
-
-            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
-                grad_weight = Conv2dGradWeight.apply(grad_output, input)
-
-            if ctx.needs_input_grad[2]:
-                grad_bias = grad_output.sum((0, 2, 3))
-
-            return grad_input, grad_weight, grad_bias
-
-    class Conv2dGradWeight(autograd.Function):
-        @staticmethod
-        def forward(ctx, grad_output, input):
-            op = torch._C._jit_get_operation(
-                "aten::cudnn_convolution_backward_weight"
-                if not transpose
-                else "aten::cudnn_convolution_transpose_backward_weight"
-            )
-            flags = [
-                torch.backends.cudnn.benchmark,
-                torch.backends.cudnn.deterministic,
-                torch.backends.cudnn.allow_tf32,
-            ]
-            grad_weight = op(
-                weight_shape,
-                grad_output,
-                input,
-                padding,
-                stride,
-                dilation,
-                groups,
-                *flags,
-            )
-            ctx.save_for_backward(grad_output, input)
-
-            return grad_weight
-
-        @staticmethod
-        def backward(ctx, grad_grad_weight):
-            grad_output, input = ctx.saved_tensors
-            grad_grad_output, grad_grad_input = None, None
-
-            if ctx.needs_input_grad[0]:
-                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
-
-            if ctx.needs_input_grad[1]:
-                p = calc_output_padding(
-                    input_shape=input.shape, output_shape=grad_output.shape
-                )
-                grad_grad_input = conv2d_gradfix(
-                    transpose=(not transpose),
-                    weight_shape=weight_shape,
-                    output_padding=p,
-                    **common_kwargs,
-                ).apply(grad_output, grad_grad_weight, None)
-
-            return grad_grad_output, grad_grad_input
-
-    conv2d_gradfix_cache[key] = Conv2d
-
-    return Conv2d
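
A sketch of the intended call pattern (hypothetical, not part of the commit; the custom double-backward path only engages on CUDA under PyTorch 1.7/1.8, and everything silently falls back to torch.nn.functional otherwise):

    import torch
    from op import conv2d_gradfix

    x = torch.randn(2, 3, 16, 16, requires_grad=True)
    w = torch.randn(8, 3, 3, 3, requires_grad=True)
    y = conv2d_gradfix.conv2d(x, w, padding=1)   # drop-in for F.conv2d

    # During regularization passes the weight gradient can be suppressed:
    with conv2d_gradfix.no_weight_gradients():
        (g,) = torch.autograd.grad(y.sum(), x, create_graph=True)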
op/fused_act.py
DELETED
@@ -1,127 +0,0 @@
-import os
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
-
-
-module_path = os.path.dirname(__file__)
-fused = load(
-    "fused",
-    sources=[
-        os.path.join(module_path, "fused_bias_act.cpp"),
-        os.path.join(module_path, "fused_bias_act_kernel.cu"),
-    ],
-)
-
-
-class FusedLeakyReLUFunctionBackward(Function):
-    @staticmethod
-    def forward(ctx, grad_output, out, bias, negative_slope, scale):
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        empty = grad_output.new_empty(0)
-
-        grad_input = fused.fused_bias_act(
-            grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
-        )
-
-        dim = [0]
-
-        if grad_input.ndim > 2:
-            dim += list(range(2, grad_input.ndim))
-
-        if bias:
-            grad_bias = grad_input.sum(dim).detach()
-
-        else:
-            grad_bias = empty
-
-        return grad_input, grad_bias
-
-    @staticmethod
-    def backward(ctx, gradgrad_input, gradgrad_bias):
-        out, = ctx.saved_tensors
-        gradgrad_out = fused.fused_bias_act(
-            gradgrad_input.contiguous(),
-            gradgrad_bias,
-            out,
-            3,
-            1,
-            ctx.negative_slope,
-            ctx.scale,
-        )
-
-        return gradgrad_out, None, None, None, None
-
-
-class FusedLeakyReLUFunction(Function):
-    @staticmethod
-    def forward(ctx, input, bias, negative_slope, scale):
-        empty = input.new_empty(0)
-
-        ctx.bias = bias is not None
-
-        if bias is None:
-            bias = empty
-
-        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        out, = ctx.saved_tensors
-
-        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
-            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
-        )
-
-        if not ctx.bias:
-            grad_bias = None
-
-        return grad_input, grad_bias, None, None
-
-
-class FusedLeakyReLU(nn.Module):
-    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
-        super().__init__()
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(channel))
-
-        else:
-            self.bias = None
-
-        self.negative_slope = negative_slope
-        self.scale = scale
-
-    def forward(self, input):
-        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-
-
-def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
-    if input.device.type == "cpu":
-        if bias is not None:
-            rest_dim = [1] * (input.ndim - bias.ndim - 1)
-            return (
-                F.leaky_relu(
-                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
-                )
-                * scale
-            )
-
-        else:
-            return F.leaky_relu(input, negative_slope=0.2) * scale
-
-    else:
-        return FusedLeakyReLUFunction.apply(
-            input.contiguous(), bias, negative_slope, scale
-        )
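
The CPU branch pins down the semantics: add the per-channel bias, apply a leaky ReLU, then rescale by √2 to keep activation variance roughly constant. A reference restatement (sketch, not part of the commit; note the CPU branch above hardcodes negative_slope=0.2 rather than using the argument):

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 4, 4)
    bias = torch.randn(8)

    # Equivalent to fused_leaky_relu(x, bias) on CPU with the defaults:
    ref = F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=0.2) * math.sqrt(2)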
op/fused_bias_act.cpp
DELETED
@@ -1,32 +0,0 @@
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor &input,
                             const torch::Tensor &bias,
                             const torch::Tensor &refer, int act, int grad,
                             float alpha, float scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(bias);

  at::DeviceGuard guard(input.device());

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
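The fused module that fused_act.py calls into is compiled just-in-time from the two sources above; a sketch of that build, assuming the same torch.utils.cpp_extension.load pattern visible in op/upfirdn2d.py later in this diff (the module name "fused" is inferred from the fused.fused_bias_act call site):

import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)
# exposes fused.fused_bias_act(input, bias, refer, act, grad, alpha, scale)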
op/fused_bias_act_kernel.cu
DELETED
@@ -1,105 +0,0 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>


#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
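The switch key act * 10 + grad packs the activation id and derivative order into one code: act = 3 selects leaky ReLU (the only mode the Python wrapper requests, via the hard-coded 3 above), and grad 0/1/2 select the forward value, the first derivative gated by the saved reference tensor, and the identically zero second derivative. A scalar restatement of those cases, for orientation:

def leaky_relu_case(x, ref, alpha, scale, grad):
    # mirrors cases 30, 31 and 32 of the CUDA switch (act == 3)
    if grad == 0:
        y = x if x > 0.0 else x * alpha    # forward
    elif grad == 1:
        y = x if ref > 0.0 else x * alpha  # backward: slope picked by saved output
    else:
        y = 0.0                            # piecewise-linear, so zero 2nd derivative
    return y * scale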
op/upfirdn2d.cpp
DELETED
@@ -1,31 +0,0 @@
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
op/upfirdn2d.py
DELETED
@@ -1,209 +0,0 @@
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        out = upfirdn2d_native(input, kernel, *up, *down, *pad)

    else:
        out = UpFirDn2d.apply(input, kernel, up, down, pad)

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
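A sketch of exercising this op on CPU, where upfirdn2d dispatches to upfirdn2d_native above (note that merely importing the module triggers the JIT CUDA build); the 2x2 kernel is illustrative:

import torch
from op.upfirdn2d import upfirdn2d_native

x = torch.randn(1, 3, 8, 8)
k = torch.ones(2, 2) / 4  # illustrative averaging kernel
# up=1, down=2, zero padding: a filtered 2x downsample
out = upfirdn2d_native(x, k, 1, 1, 2, 2, 0, 0, 0, 0)
assert out.shape == (1, 3, 4, 4)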
op/upfirdn2d_kernel.cu
DELETED
@@ -1,369 +0,0 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
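Both the Python wrapper and this operator size the output with the same arithmetic, out = (in * up + pad0 + pad1 - kernel + down) // down; worked through for a 2x upsample with a 4-tap blur kernel and e.g. pad = (2, 1), a configuration that also lands in the specialized mode-3 tile above:

def upfirdn2d_out_size(in_size, up, down, pad0, pad1, kernel):
    # same formula as p.out_h / p.out_w above
    return (in_size * up + pad0 + pad1 - kernel + down) // down

assert upfirdn2d_out_size(16, up=2, down=1, pad0=2, pad1=1, kernel=4) == 32  # 16 -> 32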
ppl.py
DELETED
@@ -1,130 +0,0 @@
import argparse

import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm

import lpips
from model import Generator


def normalize(x):
    return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))


def slerp(a, b, t):
    a = normalize(a)
    b = normalize(b)
    d = (a * b).sum(-1, keepdim=True)
    p = t * torch.acos(d)
    c = normalize(b - d * a)
    d = a * torch.cos(p) + c * torch.sin(p)

    return normalize(d)


def lerp(a, b, t):
    return a + (b - a) * t


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")

    parser.add_argument(
        "--space", choices=["z", "w"], help="space that PPL calculated with"
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the models"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=5000,
        help="number of the samples for calculating PPL",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
    )
    parser.add_argument(
        "--crop", action="store_true", help="apply center crop to the images"
    )
    parser.add_argument(
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
    )

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    batch_sizes = [args.batch] * n_batch + [resid]

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            inputs = torch.randn([batch * 2, latent_dim], device=device)
            if args.sampling == "full":
                lerp_t = torch.rand(batch, device=device)
            else:
                lerp_t = torch.zeros(batch, device=device)

            if args.space == "w":
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)

            image, _ = g([latent_e], input_is_latent=True, noise=noise)

            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]

            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode="bilinear", align_corners=False
                )

            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to("cpu").numpy())

    distances = np.concatenate(distances, 0)

    lo = np.percentile(distances, 1, interpolation="lower")
    hi = np.percentile(distances, 99, interpolation="higher")
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print("ppl:", filtered_dist.mean())
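A typical invocation (the checkpoint path is hypothetical; --space w is the usual setting when reporting StyleGAN2 PPL):

python ppl.py --space w --batch 16 --n_sample 5000 --size 256 path/to/checkpoint.pt

Note that slerp is defined but unused: interpolation is done with lerp, and the z-space branch is effectively unimplemented in this version, since latent_e is only assigned when --space w is given.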
prepare_data.py
DELETED
@@ -1,101 +0,0 @@
import argparse
from io import BytesIO
import multiprocessing
from functools import partial

from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn


def resize_and_convert(img, size, resample, quality=100):
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="jpeg", quality=quality)
    val = buffer.getvalue()

    return val


def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample, quality))

    return imgs


def resize_worker(img_file, sizes, resample):
    i, file = img_file
    img = Image.open(file)
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, out


def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")

                with env.begin(write=True) as txn:
                    txn.put(key, img)

            total += 1

        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(total).encode("utf-8"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
    parser.add_argument(
        "--size",
        type=str,
        default="128,256,512,1024",
        help="resolutions of images for the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing dataset",
    )
    parser.add_argument(
        "--resample",
        type=str,
        default="lanczos",
        help="resampling methods for resizing images",
    )
    parser.add_argument("path", type=str, help="path to the image dataset")

    args = parser.parse_args()

    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(",")]

    print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
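Typical usage (paths hypothetical); --size takes a comma-separated list, and each image is stored once per resolution under keys such as 256-00042:

python prepare_data.py --out data/dataset.lmdb --n_worker 8 --size 128,256,512,1024 path/to/image_folder

The input directory must follow the torchvision.datasets.ImageFolder layout, i.e. images grouped under at least one subdirectory.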
projector.py
DELETED
@@ -1,248 +0,0 @@
import argparse
import math
import os

import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

import lpips
from model import Generator


def noise_regularize(noises):
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            if size <= 8:
                break

            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2

    return loss


def noise_normalize_(noises):
    for noise in noises:
        mean = noise.mean()
        std = noise.std()

        noise.data.add_(-mean).div_(std)


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp


def latent_noise(latent, strength):
    noise = torch.randn_like(latent) * strength

    return latent + noise


def make_image(tensor):
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Image projector to the generator latent spaces"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--lr_rampup",
        type=float,
        default=0.05,
        help="duration of the learning rate warmup",
    )
    parser.add_argument(
        "--lr_rampdown",
        type=float,
        default=0.25,
        help="duration of the learning rate decay",
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument(
        "--noise", type=float, default=0.05, help="strength of the noise level"
    )
    parser.add_argument(
        "--noise_ramp",
        type=float,
        default=0.75,
        help="duration of the noise level decay",
    )
    parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
    parser.add_argument(
        "--noise_regularize",
        type=float,
        default=1e5,
        help="weight of the noise regularization",
    )
    parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
    parser.add_argument(
        "--w_plus",
        action="store_true",
        help="allow to use distinct latent codes to each layers",
    )
    parser.add_argument(
        "files", metavar="FILES", nargs="+", help="path to image files to be projected"
    )

    args = parser.parse_args()

    n_mean_latent = 10000

    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i : i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)
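Typical usage (paths hypothetical); with --w_plus each layer receives its own latent code, matching the unsqueeze/repeat branch above:

python projector.py --ckpt checkpoint/stylegan2.pt --size 256 --step 1000 --w_plus img1.png img2.png

It writes one <name>-project.png per input plus a single .pt file, named after the first input, holding the projected images, latents, and per-layer noise.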
sample/.gitignore
DELETED
@@ -1 +0,0 @@
*.png
swagan.py
DELETED
@@ -1,440 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import random
|
3 |
-
import functools
|
4 |
-
import operator
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from torch import nn
|
8 |
-
from torch.nn import functional as F
|
9 |
-
from torch.autograd import Function
|
10 |
-
|
11 |
-
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
|
12 |
-
from model import (
|
13 |
-
ModulatedConv2d,
|
14 |
-
StyledConv,
|
15 |
-
ConstantInput,
|
16 |
-
PixelNorm,
|
17 |
-
Upsample,
|
18 |
-
Downsample,
|
19 |
-
Blur,
|
20 |
-
EqualLinear,
|
21 |
-
ConvLayer,
|
22 |
-
)
|
23 |
-
|
24 |
-
|
25 |
-
def get_haar_wavelet(in_channels):
|
26 |
-
haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
|
27 |
-
haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2)
|
28 |
-
haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0]
|
29 |
-
|
30 |
-
haar_wav_ll = haar_wav_l.T * haar_wav_l
|
31 |
-
haar_wav_lh = haar_wav_h.T * haar_wav_l
|
32 |
-
haar_wav_hl = haar_wav_l.T * haar_wav_h
|
33 |
-
haar_wav_hh = haar_wav_h.T * haar_wav_h
|
34 |
-
|
35 |
-
return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh
|
36 |
-
|
37 |
-
|
38 |
-
def dwt_init(x):
|
39 |
-
x01 = x[:, :, 0::2, :] / 2
|
40 |
-
x02 = x[:, :, 1::2, :] / 2
|
41 |
-
x1 = x01[:, :, :, 0::2]
|
42 |
-
x2 = x02[:, :, :, 0::2]
|
43 |
-
x3 = x01[:, :, :, 1::2]
|
44 |
-
x4 = x02[:, :, :, 1::2]
|
45 |
-
x_LL = x1 + x2 + x3 + x4
|
46 |
-
x_HL = -x1 - x2 + x3 + x4
|
47 |
-
x_LH = -x1 + x2 - x3 + x4
|
48 |
-
x_HH = x1 - x2 - x3 + x4
|
49 |
-
|
50 |
-
return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
|
51 |
-
|
52 |
-
|
53 |
-
def iwt_init(x):
|
54 |
-
r = 2
|
55 |
-
in_batch, in_channel, in_height, in_width = x.size()
|
56 |
-
# print([in_batch, in_channel, in_height, in_width])
|
57 |
-
out_batch, out_channel, out_height, out_width = (
|
58 |
-
in_batch,
|
59 |
-
int(in_channel / (r ** 2)),
|
60 |
-
r * in_height,
|
61 |
-
r * in_width,
|
62 |
-
)
|
63 |
-
x1 = x[:, 0:out_channel, :, :] / 2
|
64 |
-
x2 = x[:, out_channel : out_channel * 2, :, :] / 2
|
65 |
-
x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
|
66 |
-
x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2
|
67 |
-
|
68 |
-
h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
|
69 |
-
|
70 |
-
h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
|
71 |
-
h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
|
72 |
-
h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
|
73 |
-
h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
|
74 |
-
|
75 |
-
return h
|
76 |
-
|
77 |
-
|
78 |
-
class HaarTransform(nn.Module):
|
79 |
-
def __init__(self, in_channels):
|
80 |
-
super().__init__()
|
81 |
-
|
82 |
-
ll, lh, hl, hh = get_haar_wavelet(in_channels)
|
83 |
-
|
84 |
-
self.register_buffer("ll", ll)
|
85 |
-
self.register_buffer("lh", lh)
|
86 |
-
self.register_buffer("hl", hl)
|
87 |
-
self.register_buffer("hh", hh)
|
88 |
-
|
89 |
-
def forward(self, input):
|
90 |
-
ll = upfirdn2d(input, self.ll, down=2)
|
91 |
-
lh = upfirdn2d(input, self.lh, down=2)
|
92 |
-
hl = upfirdn2d(input, self.hl, down=2)
|
93 |
-
hh = upfirdn2d(input, self.hh, down=2)
|
94 |
-
|
95 |
-
return torch.cat((ll, lh, hl, hh), 1)
|
96 |
-
|
97 |
-
|
98 |
-
class InverseHaarTransform(nn.Module):
|
99 |
-
def __init__(self, in_channels):
|
100 |
-
super().__init__()
|
101 |
-
|
102 |
-
ll, lh, hl, hh = get_haar_wavelet(in_channels)
|
103 |
-
|
104 |
-
self.register_buffer("ll", ll)
|
105 |
-
self.register_buffer("lh", -lh)
|
106 |
-
self.register_buffer("hl", -hl)
|
107 |
-
self.register_buffer("hh", hh)
|
108 |
-
|
109 |
-
def forward(self, input):
|
110 |
-
ll, lh, hl, hh = input.chunk(4, 1)
|
111 |
-
ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
|
112 |
-
lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
|
113 |
-
hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
|
114 |
-
hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))
|
115 |
-
|
116 |
-
return ll + lh + hl + hh
|
117 |
-
|
118 |
-
|
119 |
-
class ToRGB(nn.Module):
|
120 |
-
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
121 |
-
super().__init__()
|
122 |
-
|
123 |
-
if upsample:
|
124 |
-
self.iwt = InverseHaarTransform(3)
|
125 |
-
self.upsample = Upsample(blur_kernel)
|
126 |
-
self.dwt = HaarTransform(3)
|
127 |
-
|
128 |
-
self.conv = ModulatedConv2d(in_channel, 3 * 4, 1, style_dim, demodulate=False)
|
129 |
-
self.bias = nn.Parameter(torch.zeros(1, 3 * 4, 1, 1))
|
130 |
-
|
131 |
-
def forward(self, input, style, skip=None):
|
132 |
-
out = self.conv(input, style)
|
133 |
-
out = out + self.bias
|
134 |
-
|
135 |
-
if skip is not None:
|
136 |
-
skip = self.iwt(skip)
|
137 |
-
skip = self.upsample(skip)
|
138 |
-
skip = self.dwt(skip)
|
139 |
-
|
140 |
-
out = out + skip
|
141 |
-
|
142 |
-
return out
|
143 |
-
|
144 |
-
|
145 |
-
class Generator(nn.Module):
|
146 |
-
def __init__(
|
147 |
-
self,
|
148 |
-
size,
|
149 |
-
style_dim,
|
150 |
-
n_mlp,
|
151 |
-
channel_multiplier=2,
|
152 |
-
blur_kernel=[1, 3, 3, 1],
|
153 |
-
lr_mlp=0.01,
|
154 |
-
):
|
155 |
-
super().__init__()
|
156 |
-
|
157 |
-
self.size = size
|
158 |
-
|
159 |
-
self.style_dim = style_dim
|
160 |
-
|
161 |
-
layers = [PixelNorm()]
|
162 |
-
|
163 |
-
for i in range(n_mlp):
|
164 |
-
layers.append(
|
165 |
-
EqualLinear(
|
166 |
-
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
167 |
-
)
|
168 |
-
)
|
169 |
-
|
170 |
-
self.style = nn.Sequential(*layers)
|
171 |
-
|
172 |
-
self.channels = {
|
173 |
-
4: 512,
|
174 |
-
8: 512,
|
175 |
-
16: 512,
|
176 |
-
32: 512,
|
177 |
-
64: 256 * channel_multiplier,
|
178 |
-
128: 128 * channel_multiplier,
|
179 |
-
256: 64 * channel_multiplier,
|
180 |
-
512: 32 * channel_multiplier,
|
181 |
-
1024: 16 * channel_multiplier,
|
182 |
-
}
|
183 |
-
|
184 |
-
self.input = ConstantInput(self.channels[4])
|
185 |
-
self.conv1 = StyledConv(
|
186 |
-
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
187 |
-
)
|
188 |
-
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
189 |
-
|
190 |
-
self.log_size = int(math.log(size, 2)) - 1
|
191 |
-
self.num_layers = (self.log_size - 2) * 2 + 1
|
192 |
-
|
193 |
-
self.convs = nn.ModuleList()
|
194 |
-
self.upsamples = nn.ModuleList()
|
195 |
-
self.to_rgbs = nn.ModuleList()
|
196 |
-
self.noises = nn.Module()
|
197 |
-
|
198 |
-
in_channel = self.channels[4]
|
199 |
-
|
200 |
-
for layer_idx in range(self.num_layers):
|
201 |
-
res = (layer_idx + 5) // 2
|
202 |
-
shape = [1, 1, 2 ** res, 2 ** res]
|
203 |
-
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
|
204 |
-
|
205 |
-
for i in range(3, self.log_size + 1):
|
206 |
-
out_channel = self.channels[2 ** i]
|
207 |
-
|
208 |
-
self.convs.append(
|
209 |
-
StyledConv(
|
210 |
-
in_channel,
|
211 |
-
out_channel,
|
212 |
-
3,
|
213 |
-
style_dim,
|
214 |
-
upsample=True,
|
215 |
-
blur_kernel=blur_kernel,
|
216 |
-
)
|
217 |
-
)
|
218 |
-
|
219 |
-
self.convs.append(
|
220 |
-
StyledConv(
|
221 |
-
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
222 |
-
)
|
223 |
-
)
|
224 |
-
|
225 |
-
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
226 |
-
|
227 |
-
in_channel = out_channel
|
228 |
-
|
229 |
-
self.iwt = InverseHaarTransform(3)
|
230 |
-
|
231 |
-
self.n_latent = self.log_size * 2 - 2
|
232 |
-
|
233 |
-
def make_noise(self):
|
234 |
-
device = self.input.input.device
|
235 |
-
|
236 |
-
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
237 |
-
|
238 |
-
for i in range(3, self.log_size + 1):
|
239 |
-
for _ in range(2):
|
240 |
-
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
241 |
-
|
242 |
-
return noises
|
243 |
-
|
244 |
-
def mean_latent(self, n_latent):
|
245 |
-
latent_in = torch.randn(
|
246 |
-
n_latent, self.style_dim, device=self.input.input.device
|
247 |
-
)
|
248 |
-
latent = self.style(latent_in).mean(0, keepdim=True)
|
249 |
-
|
250 |
-
return latent
|
251 |
-
|
252 |
-
def get_latent(self, input):
|
253 |
-
return self.style(input)
|
254 |
-
|
255 |
-
def forward(
|
256 |
-
self,
|
257 |
-
styles,
|
258 |
-
return_latents=False,
|
259 |
-
inject_index=None,
|
260 |
-
truncation=1,
|
261 |
-
truncation_latent=None,
|
262 |
-
input_is_latent=False,
|
263 |
-
noise=None,
|
264 |
-
randomize_noise=True,
|
265 |
-
):
|
266 |
-
if not input_is_latent:
|
267 |
-
styles = [self.style(s) for s in styles]
|
268 |
-
|
269 |
-
if noise is None:
-            if randomize_noise:
-                noise = [None] * self.num_layers
-            else:
-                noise = [
-                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
-                ]
-
-        if truncation < 1:
-            style_t = []
-
-            for style in styles:
-                style_t.append(
-                    truncation_latent + truncation * (style - truncation_latent)
-                )
-
-            styles = style_t
-
-        if len(styles) < 2:
-            inject_index = self.n_latent
-
-            if styles[0].ndim < 3:
-                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
-
-            else:
-                latent = styles[0]
-
-        else:
-            if inject_index is None:
-                inject_index = random.randint(1, self.n_latent - 1)
-
-            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
-            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
-
-            latent = torch.cat([latent, latent2], 1)
-
-        out = self.input(latent)
-        out = self.conv1(out, latent[:, 0], noise=noise[0])
-
-        skip = self.to_rgb1(out, latent[:, 1])
-
-        i = 1
-        for conv1, conv2, noise1, noise2, to_rgb in zip(
-            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
-        ):
-            out = conv1(out, latent[:, i], noise=noise1)
-            out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
-
-            i += 2
-
-        image = self.iwt(skip)
-
-        if return_latents:
-            return image, latent
-
-        else:
-            return image, None
-
-
-class ConvBlock(nn.Module):
-    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
-        super().__init__()
-
-        self.conv1 = ConvLayer(in_channel, in_channel, 3)
-        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
-
-    def forward(self, input):
-        out = self.conv1(input)
-        out = self.conv2(out)
-
-        return out
-
-
-class FromRGB(nn.Module):
-    def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
-        super().__init__()
-
-        self.downsample = downsample
-
-        if downsample:
-            self.iwt = InverseHaarTransform(3)
-            self.downsample = Downsample(blur_kernel)
-            self.dwt = HaarTransform(3)
-
-        self.conv = ConvLayer(3 * 4, out_channel, 3)
-
-    def forward(self, input, skip=None):
-        if self.downsample:
-            input = self.iwt(input)
-            input = self.downsample(input)
-            input = self.dwt(input)
-
-        out = self.conv(input)
-
-        if skip is not None:
-            out = out + skip
-
-        return input, out
-
-
-class Discriminator(nn.Module):
-    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
-        super().__init__()
-
-        channels = {
-            4: 512,
-            8: 512,
-            16: 512,
-            32: 512,
-            64: 256 * channel_multiplier,
-            128: 128 * channel_multiplier,
-            256: 64 * channel_multiplier,
-            512: 32 * channel_multiplier,
-            1024: 16 * channel_multiplier,
-        }
-
-        self.dwt = HaarTransform(3)
-
-        self.from_rgbs = nn.ModuleList()
-        self.convs = nn.ModuleList()
-
-        log_size = int(math.log(size, 2)) - 1
-
-        in_channel = channels[size]
-
-        for i in range(log_size, 2, -1):
-            out_channel = channels[2 ** (i - 1)]
-
-            self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size))
-            self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel))
-
-            in_channel = out_channel
-
-        self.from_rgbs.append(FromRGB(channels[4]))
-
-        self.stddev_group = 4
-        self.stddev_feat = 1
-
-        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
-        self.final_linear = nn.Sequential(
-            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
-            EqualLinear(channels[4], 1),
-        )
-
-    def forward(self, input):
-        input = self.dwt(input)
-        out = None
-
-        for from_rgb, conv in zip(self.from_rgbs, self.convs):
-            input, out = from_rgb(input, out)
-            out = conv(out)
-
-        _, out = self.from_rgbs[-1](input, out)
-
-        batch, channel, height, width = out.shape
-        group = min(batch, self.stddev_group)
-        stddev = out.view(
-            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
-        )
-        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
-        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
-        stddev = stddev.repeat(group, 1, height, width)
-        out = torch.cat([out, stddev], 1)
-
-        out = self.final_conv(out)
-
-        out = out.view(batch, -1)
-        out = self.final_linear(out)
-
-        return out
-
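For reference, the truncation trick applied near the top of the deleted forward pass above is a plain linear interpolation of each style code toward a mean latent. A minimal, self-contained sketch of just that step (the `truncate_styles` helper name and the toy batch size and style dimension are assumptions for illustration; only the interpolation formula mirrors the deleted code):

```python
import torch

def truncate_styles(styles, truncation_latent, truncation=0.7):
    # Pull each W-space code toward the mean latent: truncation=1 leaves the
    # codes unchanged, smaller values trade sample diversity for fidelity.
    return [truncation_latent + truncation * (s - truncation_latent) for s in styles]

styles = [torch.randn(4, 512)]        # one style code per sample (toy shapes)
truncation_latent = torch.randn(512)  # stand-in for the mean latent vector
print(truncate_styles(styles, truncation_latent)[0].shape)  # torch.Size([4, 512])
```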
train.py DELETED
@@ -1,531 +0,0 @@
-import argparse
-import math
-import random
-import os
-
-import numpy as np
-import torch
-from torch import nn, autograd, optim
-from torch.nn import functional as F
-from torch.utils import data
-import torch.distributed as dist
-from torchvision import transforms, utils
-from tqdm import tqdm
-
-try:
-    import wandb
-
-except ImportError:
-    wandb = None
-
-
-from dataset import MultiResolutionDataset
-from distributed import (
-    get_rank,
-    synchronize,
-    reduce_loss_dict,
-    reduce_sum,
-    get_world_size,
-)
-from op import conv2d_gradfix
-from non_leaking import augment, AdaptiveAugment
-
-
-def data_sampler(dataset, shuffle, distributed):
-    if distributed:
-        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
-
-    if shuffle:
-        return data.RandomSampler(dataset)
-
-    else:
-        return data.SequentialSampler(dataset)
-
-
-def requires_grad(model, flag=True):
-    for p in model.parameters():
-        p.requires_grad = flag
-
-
-def accumulate(model1, model2, decay=0.999):
-    par1 = dict(model1.named_parameters())
-    par2 = dict(model2.named_parameters())
-
-    for k in par1.keys():
-        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
-
-
-def sample_data(loader):
-    while True:
-        for batch in loader:
-            yield batch
-
-
-def d_logistic_loss(real_pred, fake_pred):
-    real_loss = F.softplus(-real_pred)
-    fake_loss = F.softplus(fake_pred)
-
-    return real_loss.mean() + fake_loss.mean()
-
-
-def d_r1_loss(real_pred, real_img):
-    with conv2d_gradfix.no_weight_gradients():
-        grad_real, = autograd.grad(
-            outputs=real_pred.sum(), inputs=real_img, create_graph=True
-        )
-    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
-
-    return grad_penalty
-
-
-def g_nonsaturating_loss(fake_pred):
-    loss = F.softplus(-fake_pred).mean()
-
-    return loss
-
-
-def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
-    noise = torch.randn_like(fake_img) / math.sqrt(
-        fake_img.shape[2] * fake_img.shape[3]
-    )
-    grad, = autograd.grad(
-        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
-    )
-    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
-
-    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
-
-    path_penalty = (path_lengths - path_mean).pow(2).mean()
-
-    return path_penalty, path_mean.detach(), path_lengths
-
-
-def make_noise(batch, latent_dim, n_noise, device):
-    if n_noise == 1:
-        return torch.randn(batch, latent_dim, device=device)
-
-    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
-
-    return noises
-
-
-def mixing_noise(batch, latent_dim, prob, device):
-    if prob > 0 and random.random() < prob:
-        return make_noise(batch, latent_dim, 2, device)
-
-    else:
-        return [make_noise(batch, latent_dim, 1, device)]
-
-
-def set_grad_none(model, targets):
-    for n, p in model.named_parameters():
-        if n in targets:
-            p.grad = None
-
-
-def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
-    loader = sample_data(loader)
-
-    pbar = range(args.iter)
-
-    if get_rank() == 0:
-        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
-
-    mean_path_length = 0
-
-    d_loss_val = 0
-    r1_loss = torch.tensor(0.0, device=device)
-    g_loss_val = 0
-    path_loss = torch.tensor(0.0, device=device)
-    path_lengths = torch.tensor(0.0, device=device)
-    mean_path_length_avg = 0
-    loss_dict = {}
-
-    if args.distributed:
-        g_module = generator.module
-        d_module = discriminator.module
-
-    else:
-        g_module = generator
-        d_module = discriminator
-
-    accum = 0.5 ** (32 / (10 * 1000))
-    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
-    r_t_stat = 0
-
-    if args.augment and args.augment_p == 0:
-        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)
-
-    sample_z = torch.randn(args.n_sample, args.latent, device=device)
-
-    for idx in pbar:
-        i = idx + args.start_iter
-
-        if i > args.iter:
-            print("Done!")
-
-            break
-
-        real_img = next(loader)
-        real_img = real_img.to(device)
-
-        requires_grad(generator, False)
-        requires_grad(discriminator, True)
-
-        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
-        fake_img, _ = generator(noise)
-
-        if args.augment:
-            real_img_aug, _ = augment(real_img, ada_aug_p)
-            fake_img, _ = augment(fake_img, ada_aug_p)
-
-        else:
-            real_img_aug = real_img
-
-        fake_pred = discriminator(fake_img)
-        real_pred = discriminator(real_img_aug)
-        d_loss = d_logistic_loss(real_pred, fake_pred)
-
-        loss_dict["d"] = d_loss
-        loss_dict["real_score"] = real_pred.mean()
-        loss_dict["fake_score"] = fake_pred.mean()
-
-        discriminator.zero_grad()
-        d_loss.backward()
-        d_optim.step()
-
-        if args.augment and args.augment_p == 0:
-            ada_aug_p = ada_augment.tune(real_pred)
-            r_t_stat = ada_augment.r_t_stat
-
-        d_regularize = i % args.d_reg_every == 0
-
-        if d_regularize:
-            real_img.requires_grad = True
-
-            if args.augment:
-                real_img_aug, _ = augment(real_img, ada_aug_p)
-
-            else:
-                real_img_aug = real_img
-
-            real_pred = discriminator(real_img_aug)
-            r1_loss = d_r1_loss(real_pred, real_img)
-
-            discriminator.zero_grad()
-            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
-
-            d_optim.step()
-
-        loss_dict["r1"] = r1_loss
-
-        requires_grad(generator, True)
-        requires_grad(discriminator, False)
-
-        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
-        fake_img, _ = generator(noise)
-
-        if args.augment:
-            fake_img, _ = augment(fake_img, ada_aug_p)
-
-        fake_pred = discriminator(fake_img)
-        g_loss = g_nonsaturating_loss(fake_pred)
-
-        loss_dict["g"] = g_loss
-
-        generator.zero_grad()
-        g_loss.backward()
-        g_optim.step()
-
-        g_regularize = i % args.g_reg_every == 0
-
-        if g_regularize:
-            path_batch_size = max(1, args.batch // args.path_batch_shrink)
-            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
-            fake_img, latents = generator(noise, return_latents=True)
-
-            path_loss, mean_path_length, path_lengths = g_path_regularize(
-                fake_img, latents, mean_path_length
-            )
-
-            generator.zero_grad()
-            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
-
-            if args.path_batch_shrink:
-                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
-
-            weighted_path_loss.backward()
-
-            g_optim.step()
-
-            mean_path_length_avg = (
-                reduce_sum(mean_path_length).item() / get_world_size()
-            )
-
-        loss_dict["path"] = path_loss
-        loss_dict["path_length"] = path_lengths.mean()
-
-        accumulate(g_ema, g_module, accum)
-
-        loss_reduced = reduce_loss_dict(loss_dict)
-
-        d_loss_val = loss_reduced["d"].mean().item()
-        g_loss_val = loss_reduced["g"].mean().item()
-        r1_val = loss_reduced["r1"].mean().item()
-        path_loss_val = loss_reduced["path"].mean().item()
-        real_score_val = loss_reduced["real_score"].mean().item()
-        fake_score_val = loss_reduced["fake_score"].mean().item()
-        path_length_val = loss_reduced["path_length"].mean().item()
-
-        if get_rank() == 0:
-            pbar.set_description(
-                (
-                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
-                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
-                    f"augment: {ada_aug_p:.4f}"
-                )
-            )
-
-            if wandb and args.wandb:
-                wandb.log(
-                    {
-                        "Generator": g_loss_val,
-                        "Discriminator": d_loss_val,
-                        "Augment": ada_aug_p,
-                        "Rt": r_t_stat,
-                        "R1": r1_val,
-                        "Path Length Regularization": path_loss_val,
-                        "Mean Path Length": mean_path_length,
-                        "Real Score": real_score_val,
-                        "Fake Score": fake_score_val,
-                        "Path Length": path_length_val,
-                    }
-                )
-
-            if i % 100 == 0:
-                with torch.no_grad():
-                    g_ema.eval()
-                    sample, _ = g_ema([sample_z])
-                    utils.save_image(
-                        sample,
-                        f"sample/{str(i).zfill(6)}.png",
-                        nrow=int(args.n_sample ** 0.5),
-                        normalize=True,
-                        range=(-1, 1),
-                    )
-
-            if i % 10000 == 0:
-                torch.save(
-                    {
-                        "g": g_module.state_dict(),
-                        "d": d_module.state_dict(),
-                        "g_ema": g_ema.state_dict(),
-                        "g_optim": g_optim.state_dict(),
-                        "d_optim": d_optim.state_dict(),
-                        "args": args,
-                        "ada_aug_p": ada_aug_p,
-                    },
-                    f"checkpoint/{str(i).zfill(6)}.pt",
-                )
-
-
-if __name__ == "__main__":
-    device = "cuda"
-
-    parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
-
-    parser.add_argument("path", type=str, help="path to the lmdb dataset")
-    parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)')
-    parser.add_argument(
-        "--iter", type=int, default=800000, help="total training iterations"
-    )
-    parser.add_argument(
-        "--batch", type=int, default=16, help="batch size for each GPU"
-    )
-    parser.add_argument(
-        "--n_sample",
-        type=int,
-        default=64,
-        help="number of samples generated during training",
-    )
-    parser.add_argument(
-        "--size", type=int, default=256, help="image size for the model"
-    )
-    parser.add_argument(
-        "--r1", type=float, default=10, help="weight of the r1 regularization"
-    )
-    parser.add_argument(
-        "--path_regularize",
-        type=float,
-        default=2,
-        help="weight of the path length regularization",
-    )
-    parser.add_argument(
-        "--path_batch_shrink",
-        type=int,
-        default=2,
-        help="batch size reducing factor for the path length regularization (reduce memory consumption)",
-    )
-    parser.add_argument(
-        "--d_reg_every",
-        type=int,
-        default=16,
-        help="interval of applying r1 regularization",
-    )
-    parser.add_argument(
-        "--g_reg_every",
-        type=int,
-        default=4,
-        help="interval of applying path length regularization",
-    )
-    parser.add_argument(
-        "--mixing", type=float, default=0.9, help="probability of latent code mixing"
-    )
-    parser.add_argument(
-        "--ckpt",
-        type=str,
-        default=None,
-        help="path to the checkpoints to resume training",
-    )
-    parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
-    parser.add_argument(
-        "--channel_multiplier",
-        type=int,
-        default=2,
-        help="channel multiplier factor for the model. config-f = 2, else = 1",
-    )
-    parser.add_argument(
-        "--wandb", action="store_true", help="use weights and biases logging"
-    )
-    parser.add_argument(
-        "--local_rank", type=int, default=0, help="local rank for distributed training"
-    )
-    parser.add_argument(
-        "--augment", action="store_true", help="apply non leaking augmentation"
-    )
-    parser.add_argument(
-        "--augment_p",
-        type=float,
-        default=0,
-        help="probability of applying augmentation. 0 = use adaptive augmentation",
-    )
-    parser.add_argument(
-        "--ada_target",
-        type=float,
-        default=0.6,
-        help="target augmentation probability for adaptive augmentation",
-    )
-    parser.add_argument(
-        "--ada_length",
-        type=int,
-        default=500 * 1000,
-        help="target duration to reach augmentation probability for adaptive augmentation",
-    )
-    parser.add_argument(
-        "--ada_every",
-        type=int,
-        default=256,
-        help="probability update interval of the adaptive augmentation",
-    )
-
-    args = parser.parse_args()
-
-    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
-    args.distributed = n_gpu > 1
-
-    if args.distributed:
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend="nccl", init_method="env://")
-        synchronize()
-
-    args.latent = 512
-    args.n_mlp = 8
-
-    args.start_iter = 0
-
-    if args.arch == 'stylegan2':
-        from model import Generator, Discriminator
-
-    elif args.arch == 'swagan':
-        from swagan import Generator, Discriminator
-
-    generator = Generator(
-        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
-    ).to(device)
-    discriminator = Discriminator(
-        args.size, channel_multiplier=args.channel_multiplier
-    ).to(device)
-    g_ema = Generator(
-        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
-    ).to(device)
-    g_ema.eval()
-    accumulate(g_ema, generator, 0)
-
-    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
-    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
-
-    g_optim = optim.Adam(
-        generator.parameters(),
-        lr=args.lr * g_reg_ratio,
-        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
-    )
-    d_optim = optim.Adam(
-        discriminator.parameters(),
-        lr=args.lr * d_reg_ratio,
-        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
-    )
-
-    if args.ckpt is not None:
-        print("load model:", args.ckpt)
-
-        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
-
-        try:
-            ckpt_name = os.path.basename(args.ckpt)
-            args.start_iter = int(os.path.splitext(ckpt_name)[0])
-
-        except ValueError:
-            pass
-
-        generator.load_state_dict(ckpt["g"])
-        discriminator.load_state_dict(ckpt["d"])
-        g_ema.load_state_dict(ckpt["g_ema"])
-
-        g_optim.load_state_dict(ckpt["g_optim"])
-        d_optim.load_state_dict(ckpt["d_optim"])
-
-    if args.distributed:
-        generator = nn.parallel.DistributedDataParallel(
-            generator,
-            device_ids=[args.local_rank],
-            output_device=args.local_rank,
-            broadcast_buffers=False,
-        )
-
-        discriminator = nn.parallel.DistributedDataParallel(
-            discriminator,
-            device_ids=[args.local_rank],
-            output_device=args.local_rank,
-            broadcast_buffers=False,
-        )
-
-    transform = transforms.Compose(
-        [
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
-        ]
-    )
-
-    dataset = MultiResolutionDataset(args.path, transform, args.size)
-    loader = data.DataLoader(
-        dataset,
-        batch_size=args.batch,
-        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
-        drop_last=True,
-    )
-
-    if get_rank() == 0 and wandb is not None and args.wandb:
-        wandb.init(project="stylegan 2")
-
-    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
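The deleted trainer's adversarial objective is the softplus form of the non-saturating logistic GAN loss. As a quick self-contained check of the two formulas from `d_logistic_loss` and `g_nonsaturating_loss` above (the toy logits and shapes are assumptions; only the loss math mirrors the deleted code):

```python
import torch
from torch.nn import functional as F

real_pred = torch.randn(16, 1)  # toy discriminator logits on real images
fake_pred = torch.randn(16, 1)  # toy discriminator logits on fake images

# Discriminator: softplus(-real) + softplus(fake)
# == -log sigmoid(real) - log(1 - sigmoid(fake))
d_loss = F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

# Generator (non-saturating form): softplus(-fake) == -log sigmoid(fake)
g_loss = F.softplus(-fake_pred).mean()

print(f"d_loss: {d_loss.item():.4f}  g_loss: {g_loss.item():.4f}")
```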