feng2022 committed on
Commit
92c1174
1 Parent(s): 494f311
LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) 2019 Kim Seonghyeon
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
LICENSE-FID DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
LICENSE-LPIPS DELETED
@@ -1,24 +0,0 @@
- Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
- All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions are met:
-
- * Redistributions of source code must retain the above copyright notice, this
-   list of conditions and the following disclaimer.
-
- * Redistributions in binary form must reproduce the above copyright notice,
-   this list of conditions and the following disclaimer in the documentation
-   and/or other materials provided with the distribution.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
LICENSE-NVIDIA DELETED
@@ -1,101 +0,0 @@
1
- Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
-
3
-
4
- Nvidia Source Code License-NC
5
-
6
- =======================================================================
7
-
8
- 1. Definitions
9
-
10
- "Licensor" means any person or entity that distributes its Work.
11
-
12
- "Software" means the original work of authorship made available under
13
- this License.
14
-
15
- "Work" means the Software and any additions to or derivative works of
16
- the Software that are made available under this License.
17
-
18
- "Nvidia Processors" means any central processing unit (CPU), graphics
19
- processing unit (GPU), field-programmable gate array (FPGA),
20
- application-specific integrated circuit (ASIC) or any combination
21
- thereof designed, made, sold, or provided by Nvidia or its affiliates.
22
-
23
- The terms "reproduce," "reproduction," "derivative works," and
24
- "distribution" have the meaning as provided under U.S. copyright law;
25
- provided, however, that for the purposes of this License, derivative
26
- works shall not include works that remain separable from, or merely
27
- link (or bind by name) to the interfaces of, the Work.
28
-
29
- Works, including the Software, are "made available" under this License
30
- by including in or with the Work either (a) a copyright notice
31
- referencing the applicability of this License to the Work, or (b) a
32
- copy of this License.
33
-
34
- 2. License Grants
35
-
36
- 2.1 Copyright Grant. Subject to the terms and conditions of this
37
- License, each Licensor grants to you a perpetual, worldwide,
38
- non-exclusive, royalty-free, copyright license to reproduce,
39
- prepare derivative works of, publicly display, publicly perform,
40
- sublicense and distribute its Work and any resulting derivative
41
- works in any form.
42
-
43
- 3. Limitations
44
-
45
- 3.1 Redistribution. You may reproduce or distribute the Work only
46
- if (a) you do so under this License, (b) you include a complete
47
- copy of this License with your distribution, and (c) you retain
48
- without modification any copyright, patent, trademark, or
49
- attribution notices that are present in the Work.
50
-
51
- 3.2 Derivative Works. You may specify that additional or different
52
- terms apply to the use, reproduction, and distribution of your
53
- derivative works of the Work ("Your Terms") only if (a) Your Terms
54
- provide that the use limitation in Section 3.3 applies to your
55
- derivative works, and (b) you identify the specific derivative
56
- works that are subject to Your Terms. Notwithstanding Your Terms,
57
- this License (including the redistribution requirements in Section
58
- 3.1) will continue to apply to the Work itself.
59
-
60
- 3.3 Use Limitation. The Work and any derivative works thereof only
61
- may be used or intended for use non-commercially. The Work or
62
- derivative works thereof may be used or intended for use by Nvidia
63
- or its affiliates commercially or non-commercially. As used herein,
64
- "non-commercially" means for research or evaluation purposes only.
65
-
66
- 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67
- against any Licensor (including any claim, cross-claim or
68
- counterclaim in a lawsuit) to enforce any patents that you allege
69
- are infringed by any Work, then your rights under this License from
70
- such Licensor (including the grants in Sections 2.1 and 2.2) will
71
- terminate immediately.
72
-
73
- 3.5 Trademarks. This License does not grant any rights to use any
74
- Licensor's or its affiliates' names, logos, or trademarks, except
75
- as necessary to reproduce the notices described in this License.
76
-
77
- 3.6 Termination. If you violate any term of this License, then your
78
- rights under this License (including the grants in Sections 2.1 and
79
- 2.2) will terminate immediately.
80
-
81
- 4. Disclaimer of Warranty.
82
-
83
- THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84
- KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86
- NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87
- THIS LICENSE.
88
-
89
- 5. Limitation of Liability.
90
-
91
- EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92
- THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93
- SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94
- INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95
- OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96
- (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97
- LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98
- COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99
- THE POSSIBILITY OF SUCH DAMAGES.
100
-
101
- =======================================================================
apply_factor.py DELETED
@@ -1,94 +0,0 @@
1
- import argparse
2
-
3
- import torch
4
- from torchvision import utils
5
-
6
- from model import Generator
7
-
8
-
9
- if __name__ == "__main__":
10
- torch.set_grad_enabled(False)
11
-
12
- parser = argparse.ArgumentParser(description="Apply closed form factorization")
13
-
14
- parser.add_argument(
15
- "-i", "--index", type=int, default=0, help="index of eigenvector"
16
- )
17
- parser.add_argument(
18
- "-d",
19
- "--degree",
20
- type=float,
21
- default=5,
22
- help="scalar factors for moving latent vectors along eigenvector",
23
- )
24
- parser.add_argument(
25
- "--channel_multiplier",
26
- type=int,
27
- default=2,
28
- help='channel multiplier factor. config-f = 2, else = 1',
29
- )
30
- parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
31
- parser.add_argument(
32
- "--size", type=int, default=256, help="output image size of the generator"
33
- )
34
- parser.add_argument(
35
- "-n", "--n_sample", type=int, default=7, help="number of samples created"
36
- )
37
- parser.add_argument(
38
- "--truncation", type=float, default=0.7, help="truncation factor"
39
- )
40
- parser.add_argument(
41
- "--device", type=str, default="cuda", help="device to run the model"
42
- )
43
- parser.add_argument(
44
- "--out_prefix",
45
- type=str,
46
- default="factor",
47
- help="filename prefix to result samples",
48
- )
49
- parser.add_argument(
50
- "factor",
51
- type=str,
52
- help="name of the closed form factorization result factor file",
53
- )
54
-
55
- args = parser.parse_args()
56
-
57
- eigvec = torch.load(args.factor)["eigvec"].to(args.device)
58
- ckpt = torch.load(args.ckpt)
59
- g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)
60
- g.load_state_dict(ckpt["g_ema"], strict=False)
61
-
62
- trunc = g.mean_latent(4096)
63
-
64
- latent = torch.randn(args.n_sample, 512, device=args.device)
65
- latent = g.get_latent(latent)
66
-
67
- direction = args.degree * eigvec[:, args.index].unsqueeze(0)
68
-
69
- img, _ = g(
70
- [latent],
71
- truncation=args.truncation,
72
- truncation_latent=trunc,
73
- input_is_latent=True,
74
- )
75
- img1, _ = g(
76
- [latent + direction],
77
- truncation=args.truncation,
78
- truncation_latent=trunc,
79
- input_is_latent=True,
80
- )
81
- img2, _ = g(
82
- [latent - direction],
83
- truncation=args.truncation,
84
- truncation_latent=trunc,
85
- input_is_latent=True,
86
- )
87
-
88
- grid = utils.save_image(
89
- torch.cat([img1, img, img2], 0),
90
- f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
91
- normalize=True,
92
- range=(-1, 1),
93
- nrow=args.n_sample,
94
- )
calc_inception.py DELETED
@@ -1,130 +0,0 @@
1
- import argparse
2
- import pickle
3
- import os
4
-
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
- from torch.utils.data import DataLoader
9
- from torchvision import transforms
10
- from torchvision.models import inception_v3, Inception3
11
- import numpy as np
12
- from tqdm import tqdm
13
-
14
- from inception import InceptionV3
15
- from dataset import MultiResolutionDataset
16
-
17
-
18
- class Inception3Feature(Inception3):
19
- def forward(self, x):
20
- if x.shape[2] != 299 or x.shape[3] != 299:
21
- x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)
22
-
23
- x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
24
- x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
25
- x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
26
- x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
27
-
28
- x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
29
- x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
30
- x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
31
-
32
- x = self.Mixed_5b(x) # 35 x 35 x 192
33
- x = self.Mixed_5c(x) # 35 x 35 x 256
34
- x = self.Mixed_5d(x) # 35 x 35 x 288
35
-
36
- x = self.Mixed_6a(x) # 35 x 35 x 288
37
- x = self.Mixed_6b(x) # 17 x 17 x 768
38
- x = self.Mixed_6c(x) # 17 x 17 x 768
39
- x = self.Mixed_6d(x) # 17 x 17 x 768
40
- x = self.Mixed_6e(x) # 17 x 17 x 768
41
-
42
- x = self.Mixed_7a(x) # 17 x 17 x 768
43
- x = self.Mixed_7b(x) # 8 x 8 x 1280
44
- x = self.Mixed_7c(x) # 8 x 8 x 2048
45
-
46
- x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
47
-
48
- return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
49
-
50
-
51
- def load_patched_inception_v3():
52
- # inception = inception_v3(pretrained=True)
53
- # inception_feat = Inception3Feature()
54
- # inception_feat.load_state_dict(inception.state_dict())
55
- inception_feat = InceptionV3([3], normalize_input=False)
56
-
57
- return inception_feat
58
-
59
-
60
- @torch.no_grad()
61
- def extract_features(loader, inception, device):
62
- pbar = tqdm(loader)
63
-
64
- feature_list = []
65
-
66
- for img in pbar:
67
- img = img.to(device)
68
- feature = inception(img)[0].view(img.shape[0], -1)
69
- feature_list.append(feature.to("cpu"))
70
-
71
- features = torch.cat(feature_list, 0)
72
-
73
- return features
74
-
75
-
76
- if __name__ == "__main__":
77
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
-
79
- parser = argparse.ArgumentParser(
80
- description="Calculate Inception v3 features for datasets"
81
- )
82
- parser.add_argument(
83
- "--size",
84
- type=int,
85
- default=256,
86
- help="image sizes used for embedding calculation",
87
- )
88
- parser.add_argument(
89
- "--batch", default=64, type=int, help="batch size for inception networks"
90
- )
91
- parser.add_argument(
92
- "--n_sample",
93
- type=int,
94
- default=50000,
95
- help="number of samples used for embedding calculation",
96
- )
97
- parser.add_argument(
98
- "--flip", action="store_true", help="apply random flipping to real images"
99
- )
100
- parser.add_argument("path", metavar="PATH", help="path to datset lmdb file")
101
-
102
- args = parser.parse_args()
103
-
104
- inception = load_patched_inception_v3()
105
- inception = nn.DataParallel(inception).eval().to(device)
106
-
107
- transform = transforms.Compose(
108
- [
109
- transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
110
- transforms.ToTensor(),
111
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
112
- ]
113
- )
114
-
115
- dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
116
- loader = DataLoader(dset, batch_size=args.batch, num_workers=4)
117
-
118
- features = extract_features(loader, inception, device).numpy()
119
-
120
- features = features[: args.n_sample]
121
-
122
- print(f"extracted {features.shape[0]} features")
123
-
124
- mean = np.mean(features, 0)
125
- cov = np.cov(features, rowvar=False)
126
-
127
- name = os.path.splitext(os.path.basename(args.path))[0]
128
-
129
- with open(f"inception_{name}.pkl", "wb") as f:
130
- pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)
checkpoint/.gitignore DELETED
@@ -1 +0,0 @@
- *.pt
closed_form_factorization.py DELETED
@@ -1,33 +0,0 @@
- import argparse
-
- import torch
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         description="Extract factor/eigenvectors of latent spaces using closed form factorization"
-     )
-
-     parser.add_argument(
-         "--out", type=str, default="factor.pt", help="name of the result factor file"
-     )
-     parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
-
-     args = parser.parse_args()
-
-     ckpt = torch.load(args.ckpt)
-     modulate = {
-         k: v
-         for k, v in ckpt["g_ema"].items()
-         if "modulation" in k and "to_rgbs" not in k and "weight" in k
-     }
-
-     weight_mat = []
-     for k, v in modulate.items():
-         weight_mat.append(v)
-
-     W = torch.cat(weight_mat, 0)
-     eigvec = torch.svd(W).V.to("cpu")
-
-     torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)
-
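Note: the eigenvectors saved by this script are what apply_factor.py (earlier in this commit) loads as "eigvec". A minimal sketch of the same usage, assuming a factor.pt produced by the script above and an already-loaded Generator g; the index 0 and degree 5.0 are arbitrary illustration values:

    import torch

    eigvec = torch.load("factor.pt")["eigvec"]      # right singular vectors of the stacked modulation weights
    latent = g.get_latent(torch.randn(1, 512))      # map a random z to w space
    direction = 5.0 * eigvec[:, 0].unsqueeze(0)     # pick one eigenvector and scale it
    img_plus, _ = g([latent + direction], input_is_latent=True)
    img_minus, _ = g([latent - direction], input_is_latent=True)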
convert_weight.py DELETED
@@ -1,301 +0,0 @@
1
- import argparse
2
- import os
3
- import sys
4
- import pickle
5
- import math
6
-
7
- import torch
8
- import numpy as np
9
- from torchvision import utils
10
-
11
- from model import Generator, Discriminator
12
-
13
-
14
- def convert_modconv(vars, source_name, target_name, flip=False):
15
- weight = vars[source_name + "/weight"].value().eval()
16
- mod_weight = vars[source_name + "/mod_weight"].value().eval()
17
- mod_bias = vars[source_name + "/mod_bias"].value().eval()
18
- noise = vars[source_name + "/noise_strength"].value().eval()
19
- bias = vars[source_name + "/bias"].value().eval()
20
-
21
- dic = {
22
- "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
23
- "conv.modulation.weight": mod_weight.transpose((1, 0)),
24
- "conv.modulation.bias": mod_bias + 1,
25
- "noise.weight": np.array([noise]),
26
- "activate.bias": bias,
27
- }
28
-
29
- dic_torch = {}
30
-
31
- for k, v in dic.items():
32
- dic_torch[target_name + "." + k] = torch.from_numpy(v)
33
-
34
- if flip:
35
- dic_torch[target_name + ".conv.weight"] = torch.flip(
36
- dic_torch[target_name + ".conv.weight"], [3, 4]
37
- )
38
-
39
- return dic_torch
40
-
41
-
42
- def convert_conv(vars, source_name, target_name, bias=True, start=0):
43
- weight = vars[source_name + "/weight"].value().eval()
44
-
45
- dic = {"weight": weight.transpose((3, 2, 0, 1))}
46
-
47
- if bias:
48
- dic["bias"] = vars[source_name + "/bias"].value().eval()
49
-
50
- dic_torch = {}
51
-
52
- dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
53
-
54
- if bias:
55
- dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
56
-
57
- return dic_torch
58
-
59
-
60
- def convert_torgb(vars, source_name, target_name):
61
- weight = vars[source_name + "/weight"].value().eval()
62
- mod_weight = vars[source_name + "/mod_weight"].value().eval()
63
- mod_bias = vars[source_name + "/mod_bias"].value().eval()
64
- bias = vars[source_name + "/bias"].value().eval()
65
-
66
- dic = {
67
- "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
68
- "conv.modulation.weight": mod_weight.transpose((1, 0)),
69
- "conv.modulation.bias": mod_bias + 1,
70
- "bias": bias.reshape((1, 3, 1, 1)),
71
- }
72
-
73
- dic_torch = {}
74
-
75
- for k, v in dic.items():
76
- dic_torch[target_name + "." + k] = torch.from_numpy(v)
77
-
78
- return dic_torch
79
-
80
-
81
- def convert_dense(vars, source_name, target_name):
82
- weight = vars[source_name + "/weight"].value().eval()
83
- bias = vars[source_name + "/bias"].value().eval()
84
-
85
- dic = {"weight": weight.transpose((1, 0)), "bias": bias}
86
-
87
- dic_torch = {}
88
-
89
- for k, v in dic.items():
90
- dic_torch[target_name + "." + k] = torch.from_numpy(v)
91
-
92
- return dic_torch
93
-
94
-
95
- def update(state_dict, new):
96
- for k, v in new.items():
97
- if k not in state_dict:
98
- raise KeyError(k + " is not found")
99
-
100
- if v.shape != state_dict[k].shape:
101
- raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
102
-
103
- state_dict[k] = v
104
-
105
-
106
- def discriminator_fill_statedict(statedict, vars, size):
107
- log_size = int(math.log(size, 2))
108
-
109
- update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
110
-
111
- conv_i = 1
112
-
113
- for i in range(log_size - 2, 0, -1):
114
- reso = 4 * 2 ** i
115
- update(
116
- statedict,
117
- convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
118
- )
119
- update(
120
- statedict,
121
- convert_conv(
122
- vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
123
- ),
124
- )
125
- update(
126
- statedict,
127
- convert_conv(
128
- vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
129
- ),
130
- )
131
- conv_i += 1
132
-
133
- update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
134
- update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
135
- update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
136
-
137
- return statedict
138
-
139
-
140
- def fill_statedict(state_dict, vars, size, n_mlp):
141
- log_size = int(math.log(size, 2))
142
-
143
- for i in range(n_mlp):
144
- update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))
145
-
146
- update(
147
- state_dict,
148
- {
149
- "input.input": torch.from_numpy(
150
- vars["G_synthesis/4x4/Const/const"].value().eval()
151
- )
152
- },
153
- )
154
-
155
- update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))
156
-
157
- for i in range(log_size - 2):
158
- reso = 4 * 2 ** (i + 1)
159
- update(
160
- state_dict,
161
- convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
162
- )
163
-
164
- update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))
165
-
166
- conv_i = 0
167
-
168
- for i in range(log_size - 2):
169
- reso = 4 * 2 ** (i + 1)
170
- update(
171
- state_dict,
172
- convert_modconv(
173
- vars,
174
- f"G_synthesis/{reso}x{reso}/Conv0_up",
175
- f"convs.{conv_i}",
176
- flip=True,
177
- ),
178
- )
179
- update(
180
- state_dict,
181
- convert_modconv(
182
- vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
183
- ),
184
- )
185
- conv_i += 2
186
-
187
- for i in range(0, (log_size - 2) * 2 + 1):
188
- update(
189
- state_dict,
190
- {
191
- f"noises.noise_{i}": torch.from_numpy(
192
- vars[f"G_synthesis/noise{i}"].value().eval()
193
- )
194
- },
195
- )
196
-
197
- return state_dict
198
-
199
-
200
- if __name__ == "__main__":
201
- device = "cuda"
202
-
203
- parser = argparse.ArgumentParser(
204
- description="Tensorflow to pytorch model checkpoint converter"
205
- )
206
- parser.add_argument(
207
- "--repo",
208
- type=str,
209
- required=True,
210
- help="path to the offical StyleGAN2 repository with dnnlib/ folder",
211
- )
212
- parser.add_argument(
213
- "--gen", action="store_true", help="convert the generator weights"
214
- )
215
- parser.add_argument(
216
- "--disc", action="store_true", help="convert the discriminator weights"
217
- )
218
- parser.add_argument(
219
- "--channel_multiplier",
220
- type=int,
221
- default=2,
222
- help="channel multiplier factor. config-f = 2, else = 1",
223
- )
224
- parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
225
-
226
- args = parser.parse_args()
227
-
228
- sys.path.append(args.repo)
229
-
230
- import dnnlib
231
- from dnnlib import tflib
232
-
233
- tflib.init_tf()
234
-
235
- with open(args.path, "rb") as f:
236
- generator, discriminator, g_ema = pickle.load(f)
237
-
238
- size = g_ema.output_shape[2]
239
-
240
- n_mlp = 0
241
- mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
242
- for layer in mapping_layers_names:
243
- if layer[0].startswith('Dense'):
244
- n_mlp += 1
245
-
246
- g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
247
- state_dict = g.state_dict()
248
- state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
249
-
250
- g.load_state_dict(state_dict)
251
-
252
- latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
253
-
254
- ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
255
-
256
- if args.gen:
257
- g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
258
- g_train_state = g_train.state_dict()
259
- g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
260
- ckpt["g"] = g_train_state
261
-
262
- if args.disc:
263
- disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
264
- d_state = disc.state_dict()
265
- d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
266
- ckpt["d"] = d_state
267
-
268
- name = os.path.splitext(os.path.basename(args.path))[0]
269
- torch.save(ckpt, name + ".pt")
270
-
271
- batch_size = {256: 16, 512: 9, 1024: 4}
272
- n_sample = batch_size.get(size, 25)
273
-
274
- g = g.to(device)
275
-
276
- z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
277
-
278
- with torch.no_grad():
279
- img_pt, _ = g(
280
- [torch.from_numpy(z).to(device)],
281
- truncation=0.5,
282
- truncation_latent=latent_avg.to(device),
283
- randomize_noise=False,
284
- )
285
-
286
- Gs_kwargs = dnnlib.EasyDict()
287
- Gs_kwargs.randomize_noise = False
288
- img_tf = g_ema.run(z, None, **Gs_kwargs)
289
- img_tf = torch.from_numpy(img_tf).to(device)
290
-
291
- img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
292
- 0.0, 1.0
293
- )
294
-
295
- img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
296
-
297
- print(img_diff.abs().max())
298
-
299
- utils.save_image(
300
- img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
301
- )
dataset.py DELETED
@@ -1,40 +0,0 @@
- from io import BytesIO
-
- import lmdb
- from PIL import Image
- from torch.utils.data import Dataset
-
-
- class MultiResolutionDataset(Dataset):
-     def __init__(self, path, transform, resolution=256):
-         self.env = lmdb.open(
-             path,
-             max_readers=32,
-             readonly=True,
-             lock=False,
-             readahead=False,
-             meminit=False,
-         )
-
-         if not self.env:
-             raise IOError('Cannot open lmdb dataset', path)
-
-         with self.env.begin(write=False) as txn:
-             self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
-
-         self.resolution = resolution
-         self.transform = transform
-
-     def __len__(self):
-         return self.length
-
-     def __getitem__(self, index):
-         with self.env.begin(write=False) as txn:
-             key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
-             img_bytes = txn.get(key)
-
-         buffer = BytesIO(img_bytes)
-         img = Image.open(buffer)
-         img = self.transform(img)
-
-         return img
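For context, MultiResolutionDataset is consumed elsewhere in this commit (see calc_inception.py above) roughly as in the sketch below; the lmdb path is a placeholder and the normalization mirrors the one used there:

    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale images to [-1, 1]
    ])
    dset = MultiResolutionDataset("path/to/dataset_lmdb", transform=transform, resolution=256)
    loader = DataLoader(dset, batch_size=64, num_workers=4)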
distributed.py DELETED
@@ -1,126 +0,0 @@
1
- import math
2
- import pickle
3
-
4
- import torch
5
- from torch import distributed as dist
6
- from torch.utils.data.sampler import Sampler
7
-
8
-
9
- def get_rank():
10
- if not dist.is_available():
11
- return 0
12
-
13
- if not dist.is_initialized():
14
- return 0
15
-
16
- return dist.get_rank()
17
-
18
-
19
- def synchronize():
20
- if not dist.is_available():
21
- return
22
-
23
- if not dist.is_initialized():
24
- return
25
-
26
- world_size = dist.get_world_size()
27
-
28
- if world_size == 1:
29
- return
30
-
31
- dist.barrier()
32
-
33
-
34
- def get_world_size():
35
- if not dist.is_available():
36
- return 1
37
-
38
- if not dist.is_initialized():
39
- return 1
40
-
41
- return dist.get_world_size()
42
-
43
-
44
- def reduce_sum(tensor):
45
- if not dist.is_available():
46
- return tensor
47
-
48
- if not dist.is_initialized():
49
- return tensor
50
-
51
- tensor = tensor.clone()
52
- dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53
-
54
- return tensor
55
-
56
-
57
- def gather_grad(params):
58
- world_size = get_world_size()
59
-
60
- if world_size == 1:
61
- return
62
-
63
- for param in params:
64
- if param.grad is not None:
65
- dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66
- param.grad.data.div_(world_size)
67
-
68
-
69
- def all_gather(data):
70
- world_size = get_world_size()
71
-
72
- if world_size == 1:
73
- return [data]
74
-
75
- buffer = pickle.dumps(data)
76
- storage = torch.ByteStorage.from_buffer(buffer)
77
- tensor = torch.ByteTensor(storage).to('cuda')
78
-
79
- local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80
- size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81
- dist.all_gather(size_list, local_size)
82
- size_list = [int(size.item()) for size in size_list]
83
- max_size = max(size_list)
84
-
85
- tensor_list = []
86
- for _ in size_list:
87
- tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88
-
89
- if local_size != max_size:
90
- padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91
- tensor = torch.cat((tensor, padding), 0)
92
-
93
- dist.all_gather(tensor_list, tensor)
94
-
95
- data_list = []
96
-
97
- for size, tensor in zip(size_list, tensor_list):
98
- buffer = tensor.cpu().numpy().tobytes()[:size]
99
- data_list.append(pickle.loads(buffer))
100
-
101
- return data_list
102
-
103
-
104
- def reduce_loss_dict(loss_dict):
105
- world_size = get_world_size()
106
-
107
- if world_size < 2:
108
- return loss_dict
109
-
110
- with torch.no_grad():
111
- keys = []
112
- losses = []
113
-
114
- for k in sorted(loss_dict.keys()):
115
- keys.append(k)
116
- losses.append(loss_dict[k])
117
-
118
- losses = torch.stack(losses, 0)
119
- dist.reduce(losses, dst=0)
120
-
121
- if dist.get_rank() == 0:
122
- losses /= world_size
123
-
124
- reduced_losses = {k: v for k, v in zip(keys, losses)}
125
-
126
- return reduced_losses
download.py DELETED
@@ -1,47 +0,0 @@
- from transformers import BertConfig
-
- model_name = 'bert-base-chinese' # name of the BERT variant
- model_path = 'D:/Transformers-Bert/bert-base-chinese/' # local directory holding the downloaded pretrained BERT files
- config_path = 'D:/Transformers-Bert/bert-base-chinese/config.json' # path to config.json of the downloaded pretrained BERT files
-
- # The config can be loaded in three ways: by model name, by model directory, or by config file path
- # config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub (locations are preconfigured in the library)
- # config = BertConfig.from_pretrained(model_path)
- config = BertConfig.from_pretrained(config_path)
-
- from transformers import BertTokenizer,BertModel
- model_name = 'KenjieDec/Time-Travel-Rephotograph_e4e_ffhq_encode'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name) # automatically downloads the model weights from the official hub
-
- model_name = 'KenjieDec/Time-Travel-Rephotography_stylegan2-ffhq-config-f'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name) # automatically downloads the model weights from the official hub
-
- model_name = 'KenjieDec/Time-Travel-Rephotography_vgg_face_dag'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name)
-
- model_name = 'D:\premodel\checkpoint_b'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name) # automatically downloads the model weights from the official hub
-
- model_name = 'D:\premodel\checkpoint_g'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name) # automatically downloads the model weights from the official hub
-
- model_name = 'D:\premodel\checkpoint_gb'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name) # automatically downloads the model weights from the official hub
-
- model_name = 'clearspandex/face-parsing'
- config = BertConfig.from_pretrained(model_name) # automatically downloads the model configuration from the official hub
- tokenizer = BertTokenizer.from_pretrained(model_name) # automatically reads the vocab.txt file from the official hub
- model = BertModel.from_pretrained(model_name) # automatically downloads the model weights from the official hub
-
fid.py DELETED
@@ -1,129 +0,0 @@
1
- import argparse
2
- import pickle
3
-
4
- import torch
5
- from torch import nn
6
- import numpy as np
7
- from scipy import linalg
8
- from tqdm import tqdm
9
-
10
- from model import Generator
11
- from calc_inception import load_patched_inception_v3
12
-
13
-
14
- @torch.no_grad()
15
- def extract_feature_from_samples(
16
- generator, inception, truncation, truncation_latent, batch_size, n_sample, device
17
- ):
18
- n_batch = n_sample // batch_size
19
- resid = n_sample - (n_batch * batch_size)
20
- batch_sizes = [batch_size] * n_batch + [resid]
21
- features = []
22
-
23
- for batch in tqdm(batch_sizes):
24
- latent = torch.randn(batch, 512, device=device)
25
- img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent)
26
- feat = inception(img)[0].view(img.shape[0], -1)
27
- features.append(feat.to("cpu"))
28
-
29
- features = torch.cat(features, 0)
30
-
31
- return features
32
-
33
-
34
- def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
35
- cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
36
-
37
- if not np.isfinite(cov_sqrt).all():
38
- print("product of cov matrices is singular")
39
- offset = np.eye(sample_cov.shape[0]) * eps
40
- cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
41
-
42
- if np.iscomplexobj(cov_sqrt):
43
- if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
44
- m = np.max(np.abs(cov_sqrt.imag))
45
-
46
- raise ValueError(f"Imaginary component {m}")
47
-
48
- cov_sqrt = cov_sqrt.real
49
-
50
- mean_diff = sample_mean - real_mean
51
- mean_norm = mean_diff @ mean_diff
52
-
53
- trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
54
-
55
- fid = mean_norm + trace
56
-
57
- return fid
58
-
59
-
60
- if __name__ == "__main__":
61
- device = "cuda"
62
-
63
- parser = argparse.ArgumentParser(description="Calculate FID scores")
64
-
65
- parser.add_argument("--truncation", type=float, default=1, help="truncation factor")
66
- parser.add_argument(
67
- "--truncation_mean",
68
- type=int,
69
- default=4096,
70
- help="number of samples to calculate mean for truncation",
71
- )
72
- parser.add_argument(
73
- "--batch", type=int, default=64, help="batch size for the generator"
74
- )
75
- parser.add_argument(
76
- "--n_sample",
77
- type=int,
78
- default=50000,
79
- help="number of the samples for calculating FID",
80
- )
81
- parser.add_argument(
82
- "--size", type=int, default=256, help="image sizes for generator"
83
- )
84
- parser.add_argument(
85
- "--inception",
86
- type=str,
87
- default=None,
88
- required=True,
89
- help="path to precomputed inception embedding",
90
- )
91
- parser.add_argument(
92
- "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
93
- )
94
-
95
- args = parser.parse_args()
96
-
97
- ckpt = torch.load(args.ckpt)
98
-
99
- g = Generator(args.size, 512, 8).to(device)
100
- g.load_state_dict(ckpt["g_ema"])
101
- g = nn.DataParallel(g)
102
- g.eval()
103
-
104
- if args.truncation < 1:
105
- with torch.no_grad():
106
- mean_latent = g.mean_latent(args.truncation_mean)
107
-
108
- else:
109
- mean_latent = None
110
-
111
- inception = nn.DataParallel(load_patched_inception_v3()).to(device)
112
- inception.eval()
113
-
114
- features = extract_feature_from_samples(
115
- g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
116
- ).numpy()
117
- print(f"extracted {features.shape[0]} features")
118
-
119
- sample_mean = np.mean(features, 0)
120
- sample_cov = np.cov(features, rowvar=False)
121
-
122
- with open(args.inception, "rb") as f:
123
- embeds = pickle.load(f)
124
- real_mean = embeds["mean"]
125
- real_cov = embeds["cov"]
126
-
127
- fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
128
-
129
- print("fid:", fid)
generate.py DELETED
@@ -1,84 +0,0 @@
- import argparse
-
- import torch
- from torchvision import utils
- from model import Generator
- from tqdm import tqdm
-
-
- def generate(args, g_ema, device, mean_latent):
-
-     with torch.no_grad():
-         g_ema.eval()
-         for i in tqdm(range(args.pics)):
-             sample_z = torch.randn(args.sample, args.latent, device=device)
-
-             sample, _ = g_ema(
-                 [sample_z], truncation=args.truncation, truncation_latent=mean_latent
-             )
-
-             utils.save_image(
-                 sample,
-                 f"sample/{str(i).zfill(6)}.png",
-                 nrow=1,
-                 normalize=True,
-                 range=(-1, 1),
-             )
-
-
- if __name__ == "__main__":
-     device = "cuda"
-
-     parser = argparse.ArgumentParser(description="Generate samples from the generator")
-
-     parser.add_argument(
-         "--size", type=int, default=1024, help="output image size of the generator"
-     )
-     parser.add_argument(
-         "--sample",
-         type=int,
-         default=1,
-         help="number of samples to be generated for each image",
-     )
-     parser.add_argument(
-         "--pics", type=int, default=20, help="number of images to be generated"
-     )
-     parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
-     parser.add_argument(
-         "--truncation_mean",
-         type=int,
-         default=4096,
-         help="number of vectors to calculate mean for the truncation",
-     )
-     parser.add_argument(
-         "--ckpt",
-         type=str,
-         default="stylegan2-ffhq-config-f.pt",
-         help="path to the model checkpoint",
-     )
-     parser.add_argument(
-         "--channel_multiplier",
-         type=int,
-         default=2,
-         help="channel multiplier of the generator. config-f = 2, else = 1",
-     )
-
-     args = parser.parse_args()
-
-     args.latent = 512
-     args.n_mlp = 8
-
-     g_ema = Generator(
-         args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
-     ).to(device)
-     checkpoint = torch.load(args.ckpt)
-
-     g_ema.load_state_dict(checkpoint["g_ema"])
-
-     if args.truncation < 1:
-         with torch.no_grad():
-             mean_latent = g_ema.mean_latent(args.truncation_mean)
-     else:
-         mean_latent = None
-
-     generate(args, g_ema, device, mean_latent)
inception.py DELETED
@@ -1,310 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torchvision import models
5
-
6
- try:
7
- from torchvision.models.utils import load_state_dict_from_url
8
- except ImportError:
9
- from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
-
11
- # Inception weights ported to Pytorch from
12
- # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
- FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
-
15
-
16
- class InceptionV3(nn.Module):
17
- """Pretrained InceptionV3 network returning feature maps"""
18
-
19
- # Index of default block of inception to return,
20
- # corresponds to output of final average pooling
21
- DEFAULT_BLOCK_INDEX = 3
22
-
23
- # Maps feature dimensionality to their output blocks indices
24
- BLOCK_INDEX_BY_DIM = {
25
- 64: 0, # First max pooling features
26
- 192: 1, # Second max pooling featurs
27
- 768: 2, # Pre-aux classifier features
28
- 2048: 3 # Final average pooling features
29
- }
30
-
31
- def __init__(self,
32
- output_blocks=[DEFAULT_BLOCK_INDEX],
33
- resize_input=True,
34
- normalize_input=True,
35
- requires_grad=False,
36
- use_fid_inception=True):
37
- """Build pretrained InceptionV3
38
-
39
- Parameters
40
- ----------
41
- output_blocks : list of int
42
- Indices of blocks to return features of. Possible values are:
43
- - 0: corresponds to output of first max pooling
44
- - 1: corresponds to output of second max pooling
45
- - 2: corresponds to output which is fed to aux classifier
46
- - 3: corresponds to output of final average pooling
47
- resize_input : bool
48
- If true, bilinearly resizes input to width and height 299 before
49
- feeding input to model. As the network without fully connected
50
- layers is fully convolutional, it should be able to handle inputs
51
- of arbitrary size, so resizing might not be strictly needed
52
- normalize_input : bool
53
- If true, scales the input from range (0, 1) to the range the
54
- pretrained Inception network expects, namely (-1, 1)
55
- requires_grad : bool
56
- If true, parameters of the model require gradients. Possibly useful
57
- for finetuning the network
58
- use_fid_inception : bool
59
- If true, uses the pretrained Inception model used in Tensorflow's
60
- FID implementation. If false, uses the pretrained Inception model
61
- available in torchvision. The FID Inception model has different
62
- weights and a slightly different structure from torchvision's
63
- Inception model. If you want to compute FID scores, you are
64
- strongly advised to set this parameter to true to get comparable
65
- results.
66
- """
67
- super(InceptionV3, self).__init__()
68
-
69
- self.resize_input = resize_input
70
- self.normalize_input = normalize_input
71
- self.output_blocks = sorted(output_blocks)
72
- self.last_needed_block = max(output_blocks)
73
-
74
- assert self.last_needed_block <= 3, \
75
- 'Last possible output block index is 3'
76
-
77
- self.blocks = nn.ModuleList()
78
-
79
- if use_fid_inception:
80
- inception = fid_inception_v3()
81
- else:
82
- inception = models.inception_v3(pretrained=True)
83
-
84
- # Block 0: input to maxpool1
85
- block0 = [
86
- inception.Conv2d_1a_3x3,
87
- inception.Conv2d_2a_3x3,
88
- inception.Conv2d_2b_3x3,
89
- nn.MaxPool2d(kernel_size=3, stride=2)
90
- ]
91
- self.blocks.append(nn.Sequential(*block0))
92
-
93
- # Block 1: maxpool1 to maxpool2
94
- if self.last_needed_block >= 1:
95
- block1 = [
96
- inception.Conv2d_3b_1x1,
97
- inception.Conv2d_4a_3x3,
98
- nn.MaxPool2d(kernel_size=3, stride=2)
99
- ]
100
- self.blocks.append(nn.Sequential(*block1))
101
-
102
- # Block 2: maxpool2 to aux classifier
103
- if self.last_needed_block >= 2:
104
- block2 = [
105
- inception.Mixed_5b,
106
- inception.Mixed_5c,
107
- inception.Mixed_5d,
108
- inception.Mixed_6a,
109
- inception.Mixed_6b,
110
- inception.Mixed_6c,
111
- inception.Mixed_6d,
112
- inception.Mixed_6e,
113
- ]
114
- self.blocks.append(nn.Sequential(*block2))
115
-
116
- # Block 3: aux classifier to final avgpool
117
- if self.last_needed_block >= 3:
118
- block3 = [
119
- inception.Mixed_7a,
120
- inception.Mixed_7b,
121
- inception.Mixed_7c,
122
- nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
- ]
124
- self.blocks.append(nn.Sequential(*block3))
125
-
126
- for param in self.parameters():
127
- param.requires_grad = requires_grad
128
-
129
- def forward(self, inp):
130
- """Get Inception feature maps
131
-
132
- Parameters
133
- ----------
134
- inp : torch.autograd.Variable
135
- Input tensor of shape Bx3xHxW. Values are expected to be in
136
- range (0, 1)
137
-
138
- Returns
139
- -------
140
- List of torch.autograd.Variable, corresponding to the selected output
141
- block, sorted ascending by index
142
- """
143
- outp = []
144
- x = inp
145
-
146
- if self.resize_input:
147
- x = F.interpolate(x,
148
- size=(299, 299),
149
- mode='bilinear',
150
- align_corners=False)
151
-
152
- if self.normalize_input:
153
- x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
-
155
- for idx, block in enumerate(self.blocks):
156
- x = block(x)
157
- if idx in self.output_blocks:
158
- outp.append(x)
159
-
160
- if idx == self.last_needed_block:
161
- break
162
-
163
- return outp
164
-
165
-
166
- def fid_inception_v3():
167
- """Build pretrained Inception model for FID computation
168
-
169
- The Inception model for FID computation uses a different set of weights
170
- and has a slightly different structure than torchvision's Inception.
171
-
172
- This method first constructs torchvision's Inception and then patches the
173
- necessary parts that are different in the FID Inception model.
174
- """
175
- inception = models.inception_v3(num_classes=1008,
176
- aux_logits=False,
177
- pretrained=False)
178
- inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
179
- inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
180
- inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
181
- inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
182
- inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
183
- inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
184
- inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
185
- inception.Mixed_7b = FIDInceptionE_1(1280)
186
- inception.Mixed_7c = FIDInceptionE_2(2048)
187
-
188
- state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
189
- inception.load_state_dict(state_dict)
190
- return inception
191
-
192
-
193
- class FIDInceptionA(models.inception.InceptionA):
194
- """InceptionA block patched for FID computation"""
195
- def __init__(self, in_channels, pool_features):
196
- super(FIDInceptionA, self).__init__(in_channels, pool_features)
197
-
198
- def forward(self, x):
199
- branch1x1 = self.branch1x1(x)
200
-
201
- branch5x5 = self.branch5x5_1(x)
202
- branch5x5 = self.branch5x5_2(branch5x5)
203
-
204
- branch3x3dbl = self.branch3x3dbl_1(x)
205
- branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
206
- branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
207
-
208
- # Patch: Tensorflow's average pool does not use the padded zeros in
209
- # its average calculation
210
- branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
211
- count_include_pad=False)
212
- branch_pool = self.branch_pool(branch_pool)
213
-
214
- outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
215
- return torch.cat(outputs, 1)
216
-
217
-
218
- class FIDInceptionC(models.inception.InceptionC):
219
- """InceptionC block patched for FID computation"""
220
- def __init__(self, in_channels, channels_7x7):
221
- super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
222
-
223
- def forward(self, x):
224
- branch1x1 = self.branch1x1(x)
225
-
226
- branch7x7 = self.branch7x7_1(x)
227
- branch7x7 = self.branch7x7_2(branch7x7)
228
- branch7x7 = self.branch7x7_3(branch7x7)
229
-
230
- branch7x7dbl = self.branch7x7dbl_1(x)
231
- branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
232
- branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
233
- branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
234
- branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
235
-
236
- # Patch: Tensorflow's average pool does not use the padded zeros in
237
- # its average calculation
238
- branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
239
- count_include_pad=False)
240
- branch_pool = self.branch_pool(branch_pool)
241
-
242
- outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
243
- return torch.cat(outputs, 1)
244
-
245
-
246
- class FIDInceptionE_1(models.inception.InceptionE):
247
- """First InceptionE block patched for FID computation"""
248
- def __init__(self, in_channels):
249
- super(FIDInceptionE_1, self).__init__(in_channels)
250
-
251
- def forward(self, x):
252
- branch1x1 = self.branch1x1(x)
253
-
254
- branch3x3 = self.branch3x3_1(x)
255
- branch3x3 = [
256
- self.branch3x3_2a(branch3x3),
257
- self.branch3x3_2b(branch3x3),
258
- ]
259
- branch3x3 = torch.cat(branch3x3, 1)
260
-
261
- branch3x3dbl = self.branch3x3dbl_1(x)
262
- branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
263
- branch3x3dbl = [
264
- self.branch3x3dbl_3a(branch3x3dbl),
265
- self.branch3x3dbl_3b(branch3x3dbl),
266
- ]
267
- branch3x3dbl = torch.cat(branch3x3dbl, 1)
268
-
269
- # Patch: Tensorflow's average pool does not use the padded zeros in
270
- # its average calculation
271
- branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
272
- count_include_pad=False)
273
- branch_pool = self.branch_pool(branch_pool)
274
-
275
- outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
276
- return torch.cat(outputs, 1)
277
-
278
-
279
- class FIDInceptionE_2(models.inception.InceptionE):
280
- """Second InceptionE block patched for FID computation"""
281
- def __init__(self, in_channels):
282
- super(FIDInceptionE_2, self).__init__(in_channels)
283
-
284
- def forward(self, x):
285
- branch1x1 = self.branch1x1(x)
286
-
287
- branch3x3 = self.branch3x3_1(x)
288
- branch3x3 = [
289
- self.branch3x3_2a(branch3x3),
290
- self.branch3x3_2b(branch3x3),
291
- ]
292
- branch3x3 = torch.cat(branch3x3, 1)
293
-
294
- branch3x3dbl = self.branch3x3dbl_1(x)
295
- branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
296
- branch3x3dbl = [
297
- self.branch3x3dbl_3a(branch3x3dbl),
298
- self.branch3x3dbl_3b(branch3x3dbl),
299
- ]
300
- branch3x3dbl = torch.cat(branch3x3dbl, 1)
301
-
302
- # Patch: The FID Inception model uses max pooling instead of average
303
- # pooling. This is likely an error in this specific Inception
304
- # implementation, as other Inception models use average pooling here
305
- # (which matches the description in the paper).
306
- branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
307
- branch_pool = self.branch_pool(branch_pool)
308
-
309
- outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
310
- return torch.cat(outputs, 1)
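Note: the class above is only a feature extractor; a minimal usage sketch of how it is typically driven when collecting pool features for FID follows. The import path `inception` and the random batch are illustrative assumptions, not part of the deleted file.

    import torch
    from inception import InceptionV3  # assumed module path for the wrapper defined above

    # Block index 3 is the final adaptive-average-pool output (2048 channels).
    model = InceptionV3(output_blocks=[3], use_fid_inception=True).eval()

    images = torch.rand(8, 3, 299, 299)      # dummy batch, already scaled to [0, 1]
    with torch.no_grad():
        feats = model(images)[0]             # forward returns one tensor per requested block
    feats = feats.squeeze(-1).squeeze(-1)    # (8, 2048), ready for mean/covariance estimation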
lpips/__init__.py DELETED
@@ -1,160 +0,0 @@
1
-
2
- from __future__ import absolute_import
3
- from __future__ import division
4
- from __future__ import print_function
5
-
6
- import numpy as np
7
- from skimage.measure import compare_ssim
8
- import torch
9
- from torch.autograd import Variable
10
-
11
- from lpips import dist_model
12
-
13
- class PerceptualLoss(torch.nn.Module):
14
- def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15
- # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16
- super(PerceptualLoss, self).__init__()
17
- print('Setting up Perceptual loss...')
18
- self.use_gpu = use_gpu
19
- self.spatial = spatial
20
- self.gpu_ids = gpu_ids
21
- self.model = dist_model.DistModel()
22
- self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23
- print('...[%s] initialized'%self.model.name())
24
- print('...Done')
25
-
26
- def forward(self, pred, target, normalize=False):
27
- """
28
- Pred and target are Variables.
29
- If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30
- If normalize is False, assumes the images are already between [-1,+1]
31
-
32
- Inputs pred and target are Nx3xHxW
33
- Output pytorch Variable N long
34
- """
35
-
36
- if normalize:
37
- target = 2 * target - 1
38
- pred = 2 * pred - 1
39
-
40
- return self.model.forward(target, pred)
41
-
42
- def normalize_tensor(in_feat,eps=1e-10):
43
- norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44
- return in_feat/(norm_factor+eps)
45
-
46
- def l2(p0, p1, range=255.):
47
- return .5*np.mean((p0 / range - p1 / range)**2)
48
-
49
- def psnr(p0, p1, peak=255.):
50
- return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51
-
52
- def dssim(p0, p1, range=255.):
53
- return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
54
-
55
- def rgb2lab(in_img,mean_cent=False):
56
- from skimage import color
57
- img_lab = color.rgb2lab(in_img)
58
- if(mean_cent):
59
- img_lab[:,:,0] = img_lab[:,:,0]-50
60
- return img_lab
61
-
62
- def tensor2np(tensor_obj):
63
- # change dimension of a tensor object into a numpy array
64
- return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65
-
66
- def np2tensor(np_obj):
67
- # change dimension of np array into tensor array
68
- return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69
-
70
- def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71
- # image tensor to lab tensor
72
- from skimage import color
73
-
74
- img = tensor2im(image_tensor)
75
- img_lab = color.rgb2lab(img)
76
- if(mc_only):
77
- img_lab[:,:,0] = img_lab[:,:,0]-50
78
- if(to_norm and not mc_only):
79
- img_lab[:,:,0] = img_lab[:,:,0]-50
80
- img_lab = img_lab/100.
81
-
82
- return np2tensor(img_lab)
83
-
84
- def tensorlab2tensor(lab_tensor,return_inbnd=False):
85
- from skimage import color
86
- import warnings
87
- warnings.filterwarnings("ignore")
88
-
89
- lab = tensor2np(lab_tensor)*100.
90
- lab[:,:,0] = lab[:,:,0]+50
91
-
92
- rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93
- if(return_inbnd):
94
- # convert back to lab, see if we match
95
- lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96
- mask = 1.*np.isclose(lab_back,lab,atol=2.)
97
- mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98
- return (im2tensor(rgb_back),mask)
99
- else:
100
- return im2tensor(rgb_back)
101
-
102
- def rgb2lab(input):
103
- from skimage import color
104
- return color.rgb2lab(input / 255.)
105
-
106
- def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107
- image_numpy = image_tensor[0].cpu().float().numpy()
108
- image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109
- return image_numpy.astype(imtype)
110
-
111
- def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112
- return torch.Tensor((image / factor - cent)
113
- [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114
-
115
- def tensor2vec(vector_tensor):
116
- return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117
-
118
- def voc_ap(rec, prec, use_07_metric=False):
119
- """ ap = voc_ap(rec, prec, [use_07_metric])
120
- Compute VOC AP given precision and recall.
121
- If use_07_metric is true, uses the
122
- VOC 07 11 point method (default:False).
123
- """
124
- if use_07_metric:
125
- # 11 point metric
126
- ap = 0.
127
- for t in np.arange(0., 1.1, 0.1):
128
- if np.sum(rec >= t) == 0:
129
- p = 0
130
- else:
131
- p = np.max(prec[rec >= t])
132
- ap = ap + p / 11.
133
- else:
134
- # correct AP calculation
135
- # first append sentinel values at the end
136
- mrec = np.concatenate(([0.], rec, [1.]))
137
- mpre = np.concatenate(([0.], prec, [0.]))
138
-
139
- # compute the precision envelope
140
- for i in range(mpre.size - 1, 0, -1):
141
- mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142
-
143
- # to calculate area under PR curve, look for points
144
- # where X axis (recall) changes value
145
- i = np.where(mrec[1:] != mrec[:-1])[0]
146
-
147
- # and sum (\Delta recall) * prec
148
- ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149
- return ap
150
-
151
- def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
152
- # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
153
- image_numpy = image_tensor[0].cpu().float().numpy()
154
- image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
155
- return image_numpy.astype(imtype)
156
-
157
- def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
158
- # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
159
- return torch.Tensor((image / factor - cent)
160
- [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
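Note: a short sketch of calling the PerceptualLoss wrapper defined in this file, assuming the package is importable as `lpips`; the random tensors stand in for real image batches in [0, 1].

    import torch
    import lpips  # this package

    loss_fn = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False, gpu_ids=[])

    img0 = torch.rand(4, 3, 64, 64)          # N x 3 x H x W, values in [0, 1]
    img1 = torch.rand(4, 3, 64, 64)
    d = loss_fn(img0, img1, normalize=True)  # normalize=True rescales [0, 1] -> [-1, 1] internally
    print(d.view(-1))                        # one LPIPS distance per image pair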
lpips/base_model.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- from torch.autograd import Variable
5
- from pdb import set_trace as st
6
- from IPython import embed
7
-
8
- class BaseModel():
9
- def __init__(self):
10
- pass;
11
-
12
- def name(self):
13
- return 'BaseModel'
14
-
15
- def initialize(self, use_gpu=True, gpu_ids=[0]):
16
- self.use_gpu = use_gpu
17
- self.gpu_ids = gpu_ids
18
-
19
- def forward(self):
20
- pass
21
-
22
- def get_image_paths(self):
23
- pass
24
-
25
- def optimize_parameters(self):
26
- pass
27
-
28
- def get_current_visuals(self):
29
- return self.input
30
-
31
- def get_current_errors(self):
32
- return {}
33
-
34
- def save(self, label):
35
- pass
36
-
37
- # helper saving function that can be used by subclasses
38
- def save_network(self, network, path, network_label, epoch_label):
39
- save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
- save_path = os.path.join(path, save_filename)
41
- torch.save(network.state_dict(), save_path)
42
-
43
- # helper loading function that can be used by subclasses
44
- def load_network(self, network, network_label, epoch_label):
45
- save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
- save_path = os.path.join(self.save_dir, save_filename)
47
- print('Loading network from %s'%save_path)
48
- network.load_state_dict(torch.load(save_path))
49
-
50
- def update_learning_rate():
51
- pass
52
-
53
- def get_image_paths(self):
54
- return self.image_paths
55
-
56
- def save_done(self, flag=False):
57
- np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
- np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
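Note: a hedged sketch of how a concrete model would reuse the save helper above; the subclass, path, and labels are invented for illustration only.

    import torch
    from torch import nn
    from lpips.base_model import BaseModel  # assumed import path

    class ToyModel(BaseModel):
        def initialize(self, use_gpu=False, gpu_ids=[]):
            BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
            self.net = nn.Linear(4, 1)

    m = ToyModel()
    m.initialize()
    # save_network writes '<epoch_label>_net_<network_label>.pth' under the given path
    m.save_network(m.net, '/tmp', 'toy', 'latest')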
lpips/dist_model.py DELETED
@@ -1,284 +0,0 @@
1
-
2
- from __future__ import absolute_import
3
-
4
- import sys
5
- import numpy as np
6
- import torch
7
- from torch import nn
8
- import os
9
- from collections import OrderedDict
10
- from torch.autograd import Variable
11
- import itertools
12
- from .base_model import BaseModel
13
- from scipy.ndimage import zoom
14
- import fractions
15
- import functools
16
- import skimage.transform
17
- from tqdm import tqdm
18
-
19
- from IPython import embed
20
-
21
- from . import networks_basic as networks
22
- import lpips as util
23
-
24
- class DistModel(BaseModel):
25
- def name(self):
26
- return self.model_name
27
-
28
- def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29
- use_gpu=True, printNet=False, spatial=False,
30
- is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31
- '''
32
- INPUTS
33
- model - ['net-lin'] for linearly calibrated network
34
- ['net'] for off-the-shelf network
35
- ['L2'] for L2 distance in Lab colorspace
36
- ['SSIM'] for ssim in RGB colorspace
37
- net - ['squeeze','alex','vgg']
38
- model_path - if None, will look in weights/[NET_NAME].pth
39
- colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40
- use_gpu - bool - whether or not to use a GPU
41
- printNet - bool - whether or not to print network architecture out
42
- spatial - bool - whether to output an array containing varying distances across spatial dimensions
43
- spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44
- spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45
- spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46
- is_train - bool - [True] for training mode
47
- lr - float - initial learning rate
48
- beta1 - float - initial momentum term for adam
49
- version - 0.1 for latest, 0.0 was original (with a bug)
50
- gpu_ids - int array - [0] by default, gpus to use
51
- '''
52
- BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53
-
54
- self.model = model
55
- self.net = net
56
- self.is_train = is_train
57
- self.spatial = spatial
58
- self.gpu_ids = gpu_ids
59
- self.model_name = '%s [%s]'%(model,net)
60
-
61
- if(self.model == 'net-lin'): # pretrained net + linear layer
62
- self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63
- use_dropout=True, spatial=spatial, version=version, lpips=True)
64
- kw = {}
65
- if not use_gpu:
66
- kw['map_location'] = 'cpu'
67
- if(model_path is None):
68
- import inspect
69
- model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70
-
71
- if(not is_train):
72
- print('Loading model from: %s'%model_path)
73
- self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74
-
75
- elif(self.model=='net'): # pretrained network
76
- self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77
- elif(self.model in ['L2','l2']):
78
- self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79
- self.model_name = 'L2'
80
- elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81
- self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82
- self.model_name = 'SSIM'
83
- else:
84
- raise ValueError("Model [%s] not recognized." % self.model)
85
-
86
- self.parameters = list(self.net.parameters())
87
-
88
- if self.is_train: # training mode
89
- # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90
- self.rankLoss = networks.BCERankingLoss()
91
- self.parameters += list(self.rankLoss.net.parameters())
92
- self.lr = lr
93
- self.old_lr = lr
94
- self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95
- else: # test mode
96
- self.net.eval()
97
-
98
- if(use_gpu):
99
- self.net.to(gpu_ids[0])
100
- self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101
- if(self.is_train):
102
- self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103
-
104
- if(printNet):
105
- print('---------- Networks initialized -------------')
106
- networks.print_network(self.net)
107
- print('-----------------------------------------------')
108
-
109
- def forward(self, in0, in1, retPerLayer=False):
110
- ''' Function computes the distance between image patches in0 and in1
111
- INPUTS
112
- in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113
- OUTPUT
114
- computed distances between in0 and in1
115
- '''
116
-
117
- return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118
-
119
- # ***** TRAINING FUNCTIONS *****
120
- def optimize_parameters(self):
121
- self.forward_train()
122
- self.optimizer_net.zero_grad()
123
- self.backward_train()
124
- self.optimizer_net.step()
125
- self.clamp_weights()
126
-
127
- def clamp_weights(self):
128
- for module in self.net.modules():
129
- if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130
- module.weight.data = torch.clamp(module.weight.data,min=0)
131
-
132
- def set_input(self, data):
133
- self.input_ref = data['ref']
134
- self.input_p0 = data['p0']
135
- self.input_p1 = data['p1']
136
- self.input_judge = data['judge']
137
-
138
- if(self.use_gpu):
139
- self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140
- self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141
- self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142
- self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143
-
144
- self.var_ref = Variable(self.input_ref,requires_grad=True)
145
- self.var_p0 = Variable(self.input_p0,requires_grad=True)
146
- self.var_p1 = Variable(self.input_p1,requires_grad=True)
147
-
148
- def forward_train(self): # run forward pass
149
- # print(self.net.module.scaling_layer.shift)
150
- # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151
-
152
- self.d0 = self.forward(self.var_ref, self.var_p0)
153
- self.d1 = self.forward(self.var_ref, self.var_p1)
154
- self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155
-
156
- self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157
-
158
- self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159
-
160
- return self.loss_total
161
-
162
- def backward_train(self):
163
- torch.mean(self.loss_total).backward()
164
-
165
- def compute_accuracy(self,d0,d1,judge):
166
- ''' d0, d1 are Variables, judge is a Tensor '''
167
- d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
168
- judge_per = judge.cpu().numpy().flatten()
169
- return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
170
-
171
- def get_current_errors(self):
172
- retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
173
- ('acc_r', self.acc_r)])
174
-
175
- for key in retDict.keys():
176
- retDict[key] = np.mean(retDict[key])
177
-
178
- return retDict
179
-
180
- def get_current_visuals(self):
181
- zoom_factor = 256/self.var_ref.data.size()[2]
182
-
183
- ref_img = util.tensor2im(self.var_ref.data)
184
- p0_img = util.tensor2im(self.var_p0.data)
185
- p1_img = util.tensor2im(self.var_p1.data)
186
-
187
- ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
188
- p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
189
- p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
190
-
191
- return OrderedDict([('ref', ref_img_vis),
192
- ('p0', p0_img_vis),
193
- ('p1', p1_img_vis)])
194
-
195
- def save(self, path, label):
196
- if(self.use_gpu):
197
- self.save_network(self.net.module, path, '', label)
198
- else:
199
- self.save_network(self.net, path, '', label)
200
- self.save_network(self.rankLoss.net, path, 'rank', label)
201
-
202
- def update_learning_rate(self,nepoch_decay):
203
- lrd = self.lr / nepoch_decay
204
- lr = self.old_lr - lrd
205
-
206
- for param_group in self.optimizer_net.param_groups:
207
- param_group['lr'] = lr
208
-
209
- print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
210
- self.old_lr = lr
211
-
212
- def score_2afc_dataset(data_loader, func, name=''):
213
- ''' Function computes Two Alternative Forced Choice (2AFC) score using
214
- distance function 'func' in dataset 'data_loader'
215
- INPUTS
216
- data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217
- func - callable distance function - calling d=func(in0,in1) should take 2
218
- pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219
- OUTPUTS
220
- [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221
- [1] - dictionary with following elements
222
- d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223
- gts - N array in [0,1], preferred patch selected by human evaluators
224
- (closer to "0" for left patch p0, "1" for right patch p1,
225
- "0.6" means 60pct people preferred right patch, 40pct preferred left)
226
- scores - N array in [0,1], corresponding to what percentage function agreed with humans
227
- CONSTS
228
- N - number of test triplets in data_loader
229
- '''
230
-
231
- d0s = []
232
- d1s = []
233
- gts = []
234
-
235
- for data in tqdm(data_loader.load_data(), desc=name):
236
- d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237
- d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238
- gts+=data['judge'].cpu().numpy().flatten().tolist()
239
-
240
- d0s = np.array(d0s)
241
- d1s = np.array(d1s)
242
- gts = np.array(gts)
243
- scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
244
-
245
- return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
246
-
247
- def score_jnd_dataset(data_loader, func, name=''):
248
- ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
249
- INPUTS
250
- data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
251
- func - callable distance function - calling d=func(in0,in1) should take 2
252
- pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
253
- OUTPUTS
254
- [0] - JND score in [0,1], mAP score (area under precision-recall curve)
255
- [1] - dictionary with following elements
256
- ds - N array containing distances between two patches shown to human evaluator
257
- sames - N array containing fraction of people who thought the two patches were identical
258
- CONSTS
259
- N - number of test triplets in data_loader
260
- '''
261
-
262
- ds = []
263
- gts = []
264
-
265
- for data in tqdm(data_loader.load_data(), desc=name):
266
- ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
267
- gts+=data['same'].cpu().numpy().flatten().tolist()
268
-
269
- sames = np.array(gts)
270
- ds = np.array(ds)
271
-
272
- sorted_inds = np.argsort(ds)
273
- ds_sorted = ds[sorted_inds]
274
- sames_sorted = sames[sorted_inds]
275
-
276
- TPs = np.cumsum(sames_sorted)
277
- FPs = np.cumsum(1-sames_sorted)
278
- FNs = np.sum(sames_sorted)-TPs
279
-
280
- precs = TPs/(TPs+FPs)
281
- recs = TPs/(TPs+FNs)
282
- score = util.voc_ap(recs,precs)
283
-
284
- return(score, dict(ds=ds,sames=sames))
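Note: to make the 2AFC agreement formula at the end of score_2afc_dataset concrete, here is a tiny standalone NumPy check with hand-made distances and human judgments (all numbers are invented):

    import numpy as np

    # d0s/d1s: distances from the reference patch to p0 and p1.
    # gts: fraction of humans who preferred p1 (1.0 means everyone chose p1).
    d0s = np.array([0.2, 0.7, 0.5])
    d1s = np.array([0.6, 0.3, 0.5])
    gts = np.array([0.0, 1.0, 0.5])

    # The metric "agrees" when the patch it ranks closer is the one humans preferred;
    # exact ties count as half credit.
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
    print(scores, scores.mean())  # [1.  1.  0.5] -> mean 2AFC score of about 0.833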
lpips/networks_basic.py DELETED
@@ -1,187 +0,0 @@
1
-
2
- from __future__ import absolute_import
3
-
4
- import sys
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.init as init
8
- from torch.autograd import Variable
9
- import numpy as np
10
- from pdb import set_trace as st
11
- from skimage import color
12
- from IPython import embed
13
- from . import pretrained_networks as pn
14
-
15
- import lpips as util
16
-
17
- def spatial_average(in_tens, keepdim=True):
18
- return in_tens.mean([2,3],keepdim=keepdim)
19
-
20
- def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
21
- in_H = in_tens.shape[2]
22
- scale_factor = 1.*out_H/in_H
23
-
24
- return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
25
-
26
- # Learned perceptual metric
27
- class PNetLin(nn.Module):
28
- def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29
- super(PNetLin, self).__init__()
30
-
31
- self.pnet_type = pnet_type
32
- self.pnet_tune = pnet_tune
33
- self.pnet_rand = pnet_rand
34
- self.spatial = spatial
35
- self.lpips = lpips
36
- self.version = version
37
- self.scaling_layer = ScalingLayer()
38
-
39
- if(self.pnet_type in ['vgg','vgg16']):
40
- net_type = pn.vgg16
41
- self.chns = [64,128,256,512,512]
42
- elif(self.pnet_type=='alex'):
43
- net_type = pn.alexnet
44
- self.chns = [64,192,384,256,256]
45
- elif(self.pnet_type=='squeeze'):
46
- net_type = pn.squeezenet
47
- self.chns = [64,128,256,384,384,512,512]
48
- self.L = len(self.chns)
49
-
50
- self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51
-
52
- if(lpips):
53
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58
- self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59
- if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60
- self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61
- self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62
- self.lins+=[self.lin5,self.lin6]
63
-
64
- def forward(self, in0, in1, retPerLayer=False):
65
- # v0.0 - original release had a bug, where input was not scaled
66
- in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67
- outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68
- feats0, feats1, diffs = {}, {}, {}
69
-
70
- for kk in range(self.L):
71
- feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72
- diffs[kk] = (feats0[kk]-feats1[kk])**2
73
-
74
- if(self.lpips):
75
- if(self.spatial):
76
- res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
77
- else:
78
- res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79
- else:
80
- if(self.spatial):
81
- res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
82
- else:
83
- res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84
-
85
- val = res[0]
86
- for l in range(1,self.L):
87
- val += res[l]
88
-
89
- if(retPerLayer):
90
- return (val, res)
91
- else:
92
- return val
93
-
94
- class ScalingLayer(nn.Module):
95
- def __init__(self):
96
- super(ScalingLayer, self).__init__()
97
- self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98
- self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99
-
100
- def forward(self, inp):
101
- return (inp - self.shift) / self.scale
102
-
103
-
104
- class NetLinLayer(nn.Module):
105
- ''' A single linear layer which does a 1x1 conv '''
106
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
107
- super(NetLinLayer, self).__init__()
108
-
109
- layers = [nn.Dropout(),] if(use_dropout) else []
110
- layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111
- self.model = nn.Sequential(*layers)
112
-
113
-
114
- class Dist2LogitLayer(nn.Module):
115
- ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116
- def __init__(self, chn_mid=32, use_sigmoid=True):
117
- super(Dist2LogitLayer, self).__init__()
118
-
119
- layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120
- layers += [nn.LeakyReLU(0.2,True),]
121
- layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122
- layers += [nn.LeakyReLU(0.2,True),]
123
- layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124
- if(use_sigmoid):
125
- layers += [nn.Sigmoid(),]
126
- self.model = nn.Sequential(*layers)
127
-
128
- def forward(self,d0,d1,eps=0.1):
129
- return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130
-
131
- class BCERankingLoss(nn.Module):
132
- def __init__(self, chn_mid=32):
133
- super(BCERankingLoss, self).__init__()
134
- self.net = Dist2LogitLayer(chn_mid=chn_mid)
135
- # self.parameters = list(self.net.parameters())
136
- self.loss = torch.nn.BCELoss()
137
-
138
- def forward(self, d0, d1, judge):
139
- per = (judge+1.)/2.
140
- self.logit = self.net.forward(d0,d1)
141
- return self.loss(self.logit, per)
142
-
143
- # L2, DSSIM metrics
144
- class FakeNet(nn.Module):
145
- def __init__(self, use_gpu=True, colorspace='Lab'):
146
- super(FakeNet, self).__init__()
147
- self.use_gpu = use_gpu
148
- self.colorspace=colorspace
149
-
150
- class L2(FakeNet):
151
-
152
- def forward(self, in0, in1, retPerLayer=None):
153
- assert(in0.size()[0]==1) # currently only supports batchSize 1
154
-
155
- if(self.colorspace=='RGB'):
156
- (N,C,X,Y) = in0.size()
157
- value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158
- return value
159
- elif(self.colorspace=='Lab'):
160
- value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161
- util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162
- ret_var = Variable( torch.Tensor((value,) ) )
163
- if(self.use_gpu):
164
- ret_var = ret_var.cuda()
165
- return ret_var
166
-
167
- class DSSIM(FakeNet):
168
-
169
- def forward(self, in0, in1, retPerLayer=None):
170
- assert(in0.size()[0]==1) # currently only supports batchSize 1
171
-
172
- if(self.colorspace=='RGB'):
173
- value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174
- elif(self.colorspace=='Lab'):
175
- value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176
- util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177
- ret_var = Variable( torch.Tensor((value,) ) )
178
- if(self.use_gpu):
179
- ret_var = ret_var.cuda()
180
- return ret_var
181
-
182
- def print_network(net):
183
- num_params = 0
184
- for param in net.parameters():
185
- num_params += param.numel()
186
- print('Network',net)
187
- print('Total number of parameters: %d' % num_params)
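Note: the heart of PNetLin.forward is: unit-normalize each layer's activations along the channel axis, square the difference, weight it with a 1x1 conv (a NetLinLayer), and spatially average. A toy sketch of that pipeline on random "features" (shapes and values are arbitrary placeholders):

    import torch
    from torch import nn

    def normalize_tensor(feat, eps=1e-10):
        # unit-normalize along the channel dimension, as in lpips.normalize_tensor
        norm = torch.sqrt(torch.sum(feat ** 2, dim=1, keepdim=True))
        return feat / (norm + eps)

    # Pretend these are one backbone layer's activations for two images.
    f0 = torch.randn(1, 64, 32, 32)
    f1 = torch.randn(1, 64, 32, 32)

    diff = (normalize_tensor(f0) - normalize_tensor(f1)) ** 2  # per-channel squared difference
    lin = nn.Conv2d(64, 1, 1, bias=False)                      # stands in for one NetLinLayer
    layer_dist = lin(diff).mean([2, 3], keepdim=True)          # spatial average -> (1, 1, 1, 1)
    print(layer_dist.item())  # the full LPIPS score sums this term over all layers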
lpips/pretrained_networks.py DELETED
@@ -1,181 +0,0 @@
1
- from collections import namedtuple
2
- import torch
3
- from torchvision import models as tv
4
- from IPython import embed
5
-
6
- class squeezenet(torch.nn.Module):
7
- def __init__(self, requires_grad=False, pretrained=True):
8
- super(squeezenet, self).__init__()
9
- pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
- self.slice1 = torch.nn.Sequential()
11
- self.slice2 = torch.nn.Sequential()
12
- self.slice3 = torch.nn.Sequential()
13
- self.slice4 = torch.nn.Sequential()
14
- self.slice5 = torch.nn.Sequential()
15
- self.slice6 = torch.nn.Sequential()
16
- self.slice7 = torch.nn.Sequential()
17
- self.N_slices = 7
18
- for x in range(2):
19
- self.slice1.add_module(str(x), pretrained_features[x])
20
- for x in range(2,5):
21
- self.slice2.add_module(str(x), pretrained_features[x])
22
- for x in range(5, 8):
23
- self.slice3.add_module(str(x), pretrained_features[x])
24
- for x in range(8, 10):
25
- self.slice4.add_module(str(x), pretrained_features[x])
26
- for x in range(10, 11):
27
- self.slice5.add_module(str(x), pretrained_features[x])
28
- for x in range(11, 12):
29
- self.slice6.add_module(str(x), pretrained_features[x])
30
- for x in range(12, 13):
31
- self.slice7.add_module(str(x), pretrained_features[x])
32
- if not requires_grad:
33
- for param in self.parameters():
34
- param.requires_grad = False
35
-
36
- def forward(self, X):
37
- h = self.slice1(X)
38
- h_relu1 = h
39
- h = self.slice2(h)
40
- h_relu2 = h
41
- h = self.slice3(h)
42
- h_relu3 = h
43
- h = self.slice4(h)
44
- h_relu4 = h
45
- h = self.slice5(h)
46
- h_relu5 = h
47
- h = self.slice6(h)
48
- h_relu6 = h
49
- h = self.slice7(h)
50
- h_relu7 = h
51
- vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
- out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
-
54
- return out
55
-
56
-
57
- class alexnet(torch.nn.Module):
58
- def __init__(self, requires_grad=False, pretrained=True):
59
- super(alexnet, self).__init__()
60
- alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
- self.slice1 = torch.nn.Sequential()
62
- self.slice2 = torch.nn.Sequential()
63
- self.slice3 = torch.nn.Sequential()
64
- self.slice4 = torch.nn.Sequential()
65
- self.slice5 = torch.nn.Sequential()
66
- self.N_slices = 5
67
- for x in range(2):
68
- self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
- for x in range(2, 5):
70
- self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
- for x in range(5, 8):
72
- self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
- for x in range(8, 10):
74
- self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
- for x in range(10, 12):
76
- self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
- if not requires_grad:
78
- for param in self.parameters():
79
- param.requires_grad = False
80
-
81
- def forward(self, X):
82
- h = self.slice1(X)
83
- h_relu1 = h
84
- h = self.slice2(h)
85
- h_relu2 = h
86
- h = self.slice3(h)
87
- h_relu3 = h
88
- h = self.slice4(h)
89
- h_relu4 = h
90
- h = self.slice5(h)
91
- h_relu5 = h
92
- alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
- out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
-
95
- return out
96
-
97
- class vgg16(torch.nn.Module):
98
- def __init__(self, requires_grad=False, pretrained=True):
99
- super(vgg16, self).__init__()
100
- vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
- self.slice1 = torch.nn.Sequential()
102
- self.slice2 = torch.nn.Sequential()
103
- self.slice3 = torch.nn.Sequential()
104
- self.slice4 = torch.nn.Sequential()
105
- self.slice5 = torch.nn.Sequential()
106
- self.N_slices = 5
107
- for x in range(4):
108
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
- for x in range(4, 9):
110
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
- for x in range(9, 16):
112
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
- for x in range(16, 23):
114
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
- for x in range(23, 30):
116
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
- if not requires_grad:
118
- for param in self.parameters():
119
- param.requires_grad = False
120
-
121
- def forward(self, X):
122
- h = self.slice1(X)
123
- h_relu1_2 = h
124
- h = self.slice2(h)
125
- h_relu2_2 = h
126
- h = self.slice3(h)
127
- h_relu3_3 = h
128
- h = self.slice4(h)
129
- h_relu4_3 = h
130
- h = self.slice5(h)
131
- h_relu5_3 = h
132
- vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
-
135
- return out
136
-
137
-
138
-
139
- class resnet(torch.nn.Module):
140
- def __init__(self, requires_grad=False, pretrained=True, num=18):
141
- super(resnet, self).__init__()
142
- if(num==18):
143
- self.net = tv.resnet18(pretrained=pretrained)
144
- elif(num==34):
145
- self.net = tv.resnet34(pretrained=pretrained)
146
- elif(num==50):
147
- self.net = tv.resnet50(pretrained=pretrained)
148
- elif(num==101):
149
- self.net = tv.resnet101(pretrained=pretrained)
150
- elif(num==152):
151
- self.net = tv.resnet152(pretrained=pretrained)
152
- self.N_slices = 5
153
-
154
- self.conv1 = self.net.conv1
155
- self.bn1 = self.net.bn1
156
- self.relu = self.net.relu
157
- self.maxpool = self.net.maxpool
158
- self.layer1 = self.net.layer1
159
- self.layer2 = self.net.layer2
160
- self.layer3 = self.net.layer3
161
- self.layer4 = self.net.layer4
162
-
163
- def forward(self, X):
164
- h = self.conv1(X)
165
- h = self.bn1(h)
166
- h = self.relu(h)
167
- h_relu1 = h
168
- h = self.maxpool(h)
169
- h = self.layer1(h)
170
- h_conv2 = h
171
- h = self.layer2(h)
172
- h_conv3 = h
173
- h = self.layer3(h)
174
- h_conv4 = h
175
- h = self.layer4(h)
176
- h_conv5 = h
177
-
178
- outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
- out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
-
181
- return out
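Note: every extractor in this file follows the same pattern: copy consecutive children of a torchvision `features` module into nn.Sequential slices and return the intermediate activations. A stripped-down sketch with the first two VGG16 slices (boundaries match the vgg16 class above; `pretrained=False` and the random input are just to keep the example self-contained):

    import torch
    from torch import nn
    from torchvision import models as tv

    features = tv.vgg16(pretrained=False).features

    slice1 = nn.Sequential(*[features[i] for i in range(4)])     # conv1_1 .. relu1_2
    slice2 = nn.Sequential(*[features[i] for i in range(4, 9)])  # pool1 .. relu2_2

    x = torch.randn(1, 3, 64, 64)
    h_relu1_2 = slice1(x)
    h_relu2_2 = slice2(h_relu1_2)
    print(h_relu1_2.shape, h_relu2_2.shape)  # (1, 64, 64, 64) and (1, 128, 32, 32)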
lpips/weights/v0.0/alex.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
3
- size 5455
 
 
 
 
lpips/weights/v0.0/squeeze.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
3
- size 10057
 
 
 
 
lpips/weights/v0.0/vgg.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
3
- size 6735
 
 
 
 
lpips/weights/v0.1/alex.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
3
- size 6009
 
 
 
 
lpips/weights/v0.1/squeeze.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
3
- size 10811
 
 
 
 
lpips/weights/v0.1/vgg.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
- size 7289
 
 
 
 
model.py DELETED
@@ -1,698 +0,0 @@
1
- import math
2
- import random
3
- import functools
4
- import operator
5
-
6
- import torch
7
- from torch import nn
8
- from torch.nn import functional as F
9
- from torch.autograd import Function
10
-
11
- from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
-
13
-
14
- class PixelNorm(nn.Module):
15
- def __init__(self):
16
- super().__init__()
17
-
18
- def forward(self, input):
19
- return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20
-
21
-
22
- def make_kernel(k):
23
- k = torch.tensor(k, dtype=torch.float32)
24
-
25
- if k.ndim == 1:
26
- k = k[None, :] * k[:, None]
27
-
28
- k /= k.sum()
29
-
30
- return k
31
-
32
-
33
- class Upsample(nn.Module):
34
- def __init__(self, kernel, factor=2):
35
- super().__init__()
36
-
37
- self.factor = factor
38
- kernel = make_kernel(kernel) * (factor ** 2)
39
- self.register_buffer("kernel", kernel)
40
-
41
- p = kernel.shape[0] - factor
42
-
43
- pad0 = (p + 1) // 2 + factor - 1
44
- pad1 = p // 2
45
-
46
- self.pad = (pad0, pad1)
47
-
48
- def forward(self, input):
49
- out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50
-
51
- return out
52
-
53
-
54
- class Downsample(nn.Module):
55
- def __init__(self, kernel, factor=2):
56
- super().__init__()
57
-
58
- self.factor = factor
59
- kernel = make_kernel(kernel)
60
- self.register_buffer("kernel", kernel)
61
-
62
- p = kernel.shape[0] - factor
63
-
64
- pad0 = (p + 1) // 2
65
- pad1 = p // 2
66
-
67
- self.pad = (pad0, pad1)
68
-
69
- def forward(self, input):
70
- out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71
-
72
- return out
73
-
74
-
75
- class Blur(nn.Module):
76
- def __init__(self, kernel, pad, upsample_factor=1):
77
- super().__init__()
78
-
79
- kernel = make_kernel(kernel)
80
-
81
- if upsample_factor > 1:
82
- kernel = kernel * (upsample_factor ** 2)
83
-
84
- self.register_buffer("kernel", kernel)
85
-
86
- self.pad = pad
87
-
88
- def forward(self, input):
89
- out = upfirdn2d(input, self.kernel, pad=self.pad)
90
-
91
- return out
92
-
93
-
94
- class EqualConv2d(nn.Module):
95
- def __init__(
96
- self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97
- ):
98
- super().__init__()
99
-
100
- self.weight = nn.Parameter(
101
- torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102
- )
103
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104
-
105
- self.stride = stride
106
- self.padding = padding
107
-
108
- if bias:
109
- self.bias = nn.Parameter(torch.zeros(out_channel))
110
-
111
- else:
112
- self.bias = None
113
-
114
- def forward(self, input):
115
- out = conv2d_gradfix.conv2d(
116
- input,
117
- self.weight * self.scale,
118
- bias=self.bias,
119
- stride=self.stride,
120
- padding=self.padding,
121
- )
122
-
123
- return out
124
-
125
- def __repr__(self):
126
- return (
127
- f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128
- f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129
- )
130
-
131
-
132
- class EqualLinear(nn.Module):
133
- def __init__(
134
- self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135
- ):
136
- super().__init__()
137
-
138
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139
-
140
- if bias:
141
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142
-
143
- else:
144
- self.bias = None
145
-
146
- self.activation = activation
147
-
148
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149
- self.lr_mul = lr_mul
150
-
151
- def forward(self, input):
152
- if self.activation:
153
- out = F.linear(input, self.weight * self.scale)
154
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
155
-
156
- else:
157
- out = F.linear(
158
- input, self.weight * self.scale, bias=self.bias * self.lr_mul
159
- )
160
-
161
- return out
162
-
163
- def __repr__(self):
164
- return (
165
- f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166
- )
167
-
168
-
169
- class ModulatedConv2d(nn.Module):
170
- def __init__(
171
- self,
172
- in_channel,
173
- out_channel,
174
- kernel_size,
175
- style_dim,
176
- demodulate=True,
177
- upsample=False,
178
- downsample=False,
179
- blur_kernel=[1, 3, 3, 1],
180
- fused=True,
181
- ):
182
- super().__init__()
183
-
184
- self.eps = 1e-8
185
- self.kernel_size = kernel_size
186
- self.in_channel = in_channel
187
- self.out_channel = out_channel
188
- self.upsample = upsample
189
- self.downsample = downsample
190
-
191
- if upsample:
192
- factor = 2
193
- p = (len(blur_kernel) - factor) - (kernel_size - 1)
194
- pad0 = (p + 1) // 2 + factor - 1
195
- pad1 = p // 2 + 1
196
-
197
- self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198
-
199
- if downsample:
200
- factor = 2
201
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
202
- pad0 = (p + 1) // 2
203
- pad1 = p // 2
204
-
205
- self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206
-
207
- fan_in = in_channel * kernel_size ** 2
208
- self.scale = 1 / math.sqrt(fan_in)
209
- self.padding = kernel_size // 2
210
-
211
- self.weight = nn.Parameter(
212
- torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213
- )
214
-
215
- self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
-
217
- self.demodulate = demodulate
218
- self.fused = fused
219
-
220
- def __repr__(self):
221
- return (
222
- f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223
- f"upsample={self.upsample}, downsample={self.downsample})"
224
- )
225
-
226
- def forward(self, input, style):
227
- batch, in_channel, height, width = input.shape
228
-
229
- if not self.fused:
230
- weight = self.scale * self.weight.squeeze(0)
231
- style = self.modulation(style)
232
-
233
- if self.demodulate:
234
- w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
235
- dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
236
-
237
- input = input * style.reshape(batch, in_channel, 1, 1)
238
-
239
- if self.upsample:
240
- weight = weight.transpose(0, 1)
241
- out = conv2d_gradfix.conv_transpose2d(
242
- input, weight, padding=0, stride=2
243
- )
244
- out = self.blur(out)
245
-
246
- elif self.downsample:
247
- input = self.blur(input)
248
- out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
249
-
250
- else:
251
- out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
252
-
253
- if self.demodulate:
254
- out = out * dcoefs.view(batch, -1, 1, 1)
255
-
256
- return out
257
-
258
- style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
259
- weight = self.scale * self.weight * style
260
-
261
- if self.demodulate:
262
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
263
- weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
264
-
265
- weight = weight.view(
266
- batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
267
- )
268
-
269
- if self.upsample:
270
- input = input.view(1, batch * in_channel, height, width)
271
- weight = weight.view(
272
- batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
273
- )
274
- weight = weight.transpose(1, 2).reshape(
275
- batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
276
- )
277
- out = conv2d_gradfix.conv_transpose2d(
278
- input, weight, padding=0, stride=2, groups=batch
279
- )
280
- _, _, height, width = out.shape
281
- out = out.view(batch, self.out_channel, height, width)
282
- out = self.blur(out)
283
-
284
- elif self.downsample:
285
- input = self.blur(input)
286
- _, _, height, width = input.shape
287
- input = input.view(1, batch * in_channel, height, width)
288
- out = conv2d_gradfix.conv2d(
289
- input, weight, padding=0, stride=2, groups=batch
290
- )
291
- _, _, height, width = out.shape
292
- out = out.view(batch, self.out_channel, height, width)
293
-
294
- else:
295
- input = input.view(1, batch * in_channel, height, width)
296
- out = conv2d_gradfix.conv2d(
297
- input, weight, padding=self.padding, groups=batch
298
- )
299
- _, _, height, width = out.shape
300
- out = out.view(batch, self.out_channel, height, width)
301
-
302
- return out
303
-
304
-
305
- class NoiseInjection(nn.Module):
306
- def __init__(self):
307
- super().__init__()
308
-
309
- self.weight = nn.Parameter(torch.zeros(1))
310
-
311
- def forward(self, image, noise=None):
312
- if noise is None:
313
- batch, _, height, width = image.shape
314
- noise = image.new_empty(batch, 1, height, width).normal_()
315
-
316
- return image + self.weight * noise
317
-
318
-
319
- class ConstantInput(nn.Module):
320
- def __init__(self, channel, size=4):
321
- super().__init__()
322
-
323
- self.input = nn.Parameter(torch.randn(1, channel, size, size))
324
-
325
- def forward(self, input):
326
- batch = input.shape[0]
327
- out = self.input.repeat(batch, 1, 1, 1)
328
-
329
- return out
330
-
331
-
332
- class StyledConv(nn.Module):
333
- def __init__(
334
- self,
335
- in_channel,
336
- out_channel,
337
- kernel_size,
338
- style_dim,
339
- upsample=False,
340
- blur_kernel=[1, 3, 3, 1],
341
- demodulate=True,
342
- ):
343
- super().__init__()
344
-
345
- self.conv = ModulatedConv2d(
346
- in_channel,
347
- out_channel,
348
- kernel_size,
349
- style_dim,
350
- upsample=upsample,
351
- blur_kernel=blur_kernel,
352
- demodulate=demodulate,
353
- )
354
-
355
- self.noise = NoiseInjection()
356
- # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
357
- # self.activate = ScaledLeakyReLU(0.2)
358
- self.activate = FusedLeakyReLU(out_channel)
359
-
360
- def forward(self, input, style, noise=None):
361
- out = self.conv(input, style)
362
- out = self.noise(out, noise=noise)
363
- # out = out + self.bias
364
- out = self.activate(out)
365
-
366
- return out
367
-
368
-
369
- class ToRGB(nn.Module):
370
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
371
- super().__init__()
372
-
373
- if upsample:
374
- self.upsample = Upsample(blur_kernel)
375
-
376
- self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
377
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
378
-
379
- def forward(self, input, style, skip=None):
380
- out = self.conv(input, style)
381
- out = out + self.bias
382
-
383
- if skip is not None:
384
- skip = self.upsample(skip)
385
-
386
- out = out + skip
387
-
388
- return out
389
-
390
-
391
- class Generator(nn.Module):
392
- def __init__(
393
- self,
394
- size,
395
- style_dim,
396
- n_mlp,
397
- channel_multiplier=2,
398
- blur_kernel=[1, 3, 3, 1],
399
- lr_mlp=0.01,
400
- ):
401
- super().__init__()
402
-
403
- self.size = size
404
-
405
- self.style_dim = style_dim
406
-
407
- layers = [PixelNorm()]
408
-
409
- for i in range(n_mlp):
410
- layers.append(
411
- EqualLinear(
412
- style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
413
- )
414
- )
415
-
416
- self.style = nn.Sequential(*layers)
417
-
418
- self.channels = {
419
- 4: 512,
420
- 8: 512,
421
- 16: 512,
422
- 32: 512,
423
- 64: 256 * channel_multiplier,
424
- 128: 128 * channel_multiplier,
425
- 256: 64 * channel_multiplier,
426
- 512: 32 * channel_multiplier,
427
- 1024: 16 * channel_multiplier,
428
- }
429
-
430
- self.input = ConstantInput(self.channels[4])
431
- self.conv1 = StyledConv(
432
- self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
433
- )
434
- self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
435
-
436
- self.log_size = int(math.log(size, 2))
437
- self.num_layers = (self.log_size - 2) * 2 + 1
438
-
439
- self.convs = nn.ModuleList()
440
- self.upsamples = nn.ModuleList()
441
- self.to_rgbs = nn.ModuleList()
442
- self.noises = nn.Module()
443
-
444
- in_channel = self.channels[4]
445
-
446
- for layer_idx in range(self.num_layers):
447
- res = (layer_idx + 5) // 2
448
- shape = [1, 1, 2 ** res, 2 ** res]
449
- self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
450
-
451
- for i in range(3, self.log_size + 1):
452
- out_channel = self.channels[2 ** i]
453
-
454
- self.convs.append(
455
- StyledConv(
456
- in_channel,
457
- out_channel,
458
- 3,
459
- style_dim,
460
- upsample=True,
461
- blur_kernel=blur_kernel,
462
- )
463
- )
464
-
465
- self.convs.append(
466
- StyledConv(
467
- out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
468
- )
469
- )
470
-
471
- self.to_rgbs.append(ToRGB(out_channel, style_dim))
472
-
473
- in_channel = out_channel
474
-
475
- self.n_latent = self.log_size * 2 - 2
476
-
477
- def make_noise(self):
478
- device = self.input.input.device
479
-
480
- noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
481
-
482
- for i in range(3, self.log_size + 1):
483
- for _ in range(2):
484
- noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
485
-
486
- return noises
487
-
488
- def mean_latent(self, n_latent):
489
- latent_in = torch.randn(
490
- n_latent, self.style_dim, device=self.input.input.device
491
- )
492
- latent = self.style(latent_in).mean(0, keepdim=True)
493
-
494
- return latent
495
-
496
- def get_latent(self, input):
497
- return self.style(input)
498
-
499
- def forward(
500
- self,
501
- styles,
502
- return_latents=False,
503
- inject_index=None,
504
- truncation=1,
505
- truncation_latent=None,
506
- input_is_latent=False,
507
- noise=None,
508
- randomize_noise=True,
509
- ):
510
- if not input_is_latent:
511
- styles = [self.style(s) for s in styles]
512
-
513
- if noise is None:
514
- if randomize_noise:
515
- noise = [None] * self.num_layers
516
- else:
517
- noise = [
518
- getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
519
- ]
520
-
521
- if truncation < 1:
522
- style_t = []
523
-
524
- for style in styles:
525
- style_t.append(
526
- truncation_latent + truncation * (style - truncation_latent)
527
- )
528
-
529
- styles = style_t
530
-
531
- if len(styles) < 2:
532
- inject_index = self.n_latent
533
-
534
- if styles[0].ndim < 3:
535
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
536
-
537
- else:
538
- latent = styles[0]
539
-
540
- else:
541
- if inject_index is None:
542
- inject_index = random.randint(1, self.n_latent - 1)
543
-
544
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
545
- latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
546
-
547
- latent = torch.cat([latent, latent2], 1)
548
-
549
- out = self.input(latent)
550
- out = self.conv1(out, latent[:, 0], noise=noise[0])
551
-
552
- skip = self.to_rgb1(out, latent[:, 1])
553
-
554
- i = 1
555
- for conv1, conv2, noise1, noise2, to_rgb in zip(
556
- self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
557
- ):
558
- out = conv1(out, latent[:, i], noise=noise1)
559
- out = conv2(out, latent[:, i + 1], noise=noise2)
560
- skip = to_rgb(out, latent[:, i + 2], skip)
561
-
562
- i += 2
563
-
564
- image = skip
565
-
566
- if return_latents:
567
- return image, latent
568
-
569
- else:
570
- return image, None
571
-
572
-
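Note: a hedged sketch of how the Generator above is normally sampled, including the truncation trick its forward implements (blending each style toward the mean latent). The import path and the fused CUDA ops in `op/` are assumed to be available; sizes are placeholders.

    import torch
    from model import Generator  # assumed import path for this file

    g = Generator(size=256, style_dim=512, n_mlp=8).eval()

    z = torch.randn(4, 512)            # one z per sample
    mean_w = g.mean_latent(4096)       # average of many mapped latents
    with torch.no_grad():
        # truncation=0.7 pulls each style 30% of the way toward the mean latent
        images, _ = g([z], truncation=0.7, truncation_latent=mean_w)
    print(images.shape)                # (4, 3, 256, 256)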
573
- class ConvLayer(nn.Sequential):
574
- def __init__(
575
- self,
576
- in_channel,
577
- out_channel,
578
- kernel_size,
579
- downsample=False,
580
- blur_kernel=[1, 3, 3, 1],
581
- bias=True,
582
- activate=True,
583
- ):
584
- layers = []
585
-
586
- if downsample:
587
- factor = 2
588
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
589
- pad0 = (p + 1) // 2
590
- pad1 = p // 2
591
-
592
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
593
-
594
- stride = 2
595
- self.padding = 0
596
-
597
- else:
598
- stride = 1
599
- self.padding = kernel_size // 2
600
-
601
- layers.append(
602
- EqualConv2d(
603
- in_channel,
604
- out_channel,
605
- kernel_size,
606
- padding=self.padding,
607
- stride=stride,
608
- bias=bias and not activate,
609
- )
610
- )
611
-
612
- if activate:
613
- layers.append(FusedLeakyReLU(out_channel, bias=bias))
614
-
615
- super().__init__(*layers)
616
-
617
-
618
- class ResBlock(nn.Module):
619
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
620
- super().__init__()
621
-
622
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
623
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
624
-
625
- self.skip = ConvLayer(
626
- in_channel, out_channel, 1, downsample=True, activate=False, bias=False
627
- )
628
-
629
- def forward(self, input):
630
- out = self.conv1(input)
631
- out = self.conv2(out)
632
-
633
- skip = self.skip(input)
634
- out = (out + skip) / math.sqrt(2)
635
-
636
- return out
637
-
638
-
639
- class Discriminator(nn.Module):
640
- def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
641
- super().__init__()
642
-
643
- channels = {
644
- 4: 512,
645
- 8: 512,
646
- 16: 512,
647
- 32: 512,
648
- 64: 256 * channel_multiplier,
649
- 128: 128 * channel_multiplier,
650
- 256: 64 * channel_multiplier,
651
- 512: 32 * channel_multiplier,
652
- 1024: 16 * channel_multiplier,
653
- }
654
-
655
- convs = [ConvLayer(3, channels[size], 1)]
656
-
657
- log_size = int(math.log(size, 2))
658
-
659
- in_channel = channels[size]
660
-
661
- for i in range(log_size, 2, -1):
662
- out_channel = channels[2 ** (i - 1)]
663
-
664
- convs.append(ResBlock(in_channel, out_channel, blur_kernel))
665
-
666
- in_channel = out_channel
667
-
668
- self.convs = nn.Sequential(*convs)
669
-
670
- self.stddev_group = 4
671
- self.stddev_feat = 1
672
-
673
- self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
674
- self.final_linear = nn.Sequential(
675
- EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
676
- EqualLinear(channels[4], 1),
677
- )
678
-
679
- def forward(self, input):
680
- out = self.convs(input)
681
-
682
- batch, channel, height, width = out.shape
683
- group = min(batch, self.stddev_group)
684
- stddev = out.view(
685
- group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
686
- )
687
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
688
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
689
- stddev = stddev.repeat(group, 1, height, width)
690
- out = torch.cat([out, stddev], 1)
691
-
692
- out = self.final_conv(out)
693
-
694
- out = out.view(batch, -1)
695
- out = self.final_linear(out)
696
-
697
- return out
698
-
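A minimal standalone sketch of the minibatch standard deviation statistic that Discriminator.forward computes above, in plain PyTorch; the group size, feature split, and epsilon mirror the defaults in the class (stddev_group=4, stddev_feat=1, 1e-8), and the helper name is hypothetical:

import torch

def minibatch_stddev(out, stddev_group=4, stddev_feat=1, eps=1e-8):
    # Per-group standard deviation, averaged over features and spatial dims,
    # broadcast back to the feature map and appended as one extra channel.
    batch, channel, height, width = out.shape
    group = min(batch, stddev_group)
    stddev = out.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
    stddev = torch.sqrt(stddev.var(0, unbiased=False) + eps)
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    stddev = stddev.repeat(group, 1, height, width)
    return torch.cat([out, stddev], 1)

feat = torch.randn(8, 512, 4, 4)
print(minibatch_stddev(feat).shape)  # torch.Size([8, 513, 4, 4])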
non_leaking.py DELETED
@@ -1,465 +0,0 @@
1
- import math
2
-
3
- import torch
4
- from torch import autograd
5
- from torch.nn import functional as F
6
- import numpy as np
7
-
8
- from distributed import reduce_sum
9
- from op import upfirdn2d
10
-
11
-
12
- class AdaptiveAugment:
13
- def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
14
- self.ada_aug_target = ada_aug_target
15
- self.ada_aug_len = ada_aug_len
16
- self.update_every = update_every
17
-
18
- self.ada_update = 0
19
- self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
20
- self.r_t_stat = 0
21
- self.ada_aug_p = 0
22
-
23
- @torch.no_grad()
24
- def tune(self, real_pred):
25
- self.ada_aug_buf += torch.tensor(
26
- (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
27
- device=real_pred.device,
28
- )
29
- self.ada_update += 1
30
-
31
- if self.ada_update % self.update_every == 0:
32
- self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
33
- pred_signs, n_pred = self.ada_aug_buf.tolist()
34
-
35
- self.r_t_stat = pred_signs / n_pred
36
-
37
- if self.r_t_stat > self.ada_aug_target:
38
- sign = 1
39
-
40
- else:
41
- sign = -1
42
-
43
- self.ada_aug_p += sign * n_pred / self.ada_aug_len
44
- self.ada_aug_p = min(1, max(0, self.ada_aug_p))
45
- self.ada_aug_buf.mul_(0)
46
- self.ada_update = 0
47
-
48
- return self.ada_aug_p
49
-
50
-
51
- SYM6 = (
52
- 0.015404109327027373,
53
- 0.0034907120842174702,
54
- -0.11799011114819057,
55
- -0.048311742585633,
56
- 0.4910559419267466,
57
- 0.787641141030194,
58
- 0.3379294217276218,
59
- -0.07263752278646252,
60
- -0.021060292512300564,
61
- 0.04472490177066578,
62
- 0.0017677118642428036,
63
- -0.007800708325034148,
64
- )
65
-
66
-
67
- def translate_mat(t_x, t_y, device="cpu"):
68
- batch = t_x.shape[0]
69
-
70
- mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
71
- translate = torch.stack((t_x, t_y), 1)
72
- mat[:, :2, 2] = translate
73
-
74
- return mat
75
-
76
-
77
- def rotate_mat(theta, device="cpu"):
78
- batch = theta.shape[0]
79
-
80
- mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
81
- sin_t = torch.sin(theta)
82
- cos_t = torch.cos(theta)
83
- rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
84
- mat[:, :2, :2] = rot
85
-
86
- return mat
87
-
88
-
89
- def scale_mat(s_x, s_y, device="cpu"):
90
- batch = s_x.shape[0]
91
-
92
- mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
93
- mat[:, 0, 0] = s_x
94
- mat[:, 1, 1] = s_y
95
-
96
- return mat
97
-
98
-
99
- def translate3d_mat(t_x, t_y, t_z):
100
- batch = t_x.shape[0]
101
-
102
- mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
103
- translate = torch.stack((t_x, t_y, t_z), 1)
104
- mat[:, :3, 3] = translate
105
-
106
- return mat
107
-
108
-
109
- def rotate3d_mat(axis, theta):
110
- batch = theta.shape[0]
111
-
112
- u_x, u_y, u_z = axis
113
-
114
- eye = torch.eye(3).unsqueeze(0)
115
- cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
116
- outer = torch.tensor(axis)
117
- outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
118
-
119
- sin_t = torch.sin(theta).view(-1, 1, 1)
120
- cos_t = torch.cos(theta).view(-1, 1, 1)
121
-
122
- rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
123
-
124
- eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
125
- eye_4[:, :3, :3] = rot
126
-
127
- return eye_4
128
-
129
-
130
- def scale3d_mat(s_x, s_y, s_z):
131
- batch = s_x.shape[0]
132
-
133
- mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
134
- mat[:, 0, 0] = s_x
135
- mat[:, 1, 1] = s_y
136
- mat[:, 2, 2] = s_z
137
-
138
- return mat
139
-
140
-
141
- def luma_flip_mat(axis, i):
142
- batch = i.shape[0]
143
-
144
- eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
145
- axis = torch.tensor(axis + (0,))
146
- flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
147
-
148
- return eye - flip
149
-
150
-
151
- def saturation_mat(axis, i):
152
- batch = i.shape[0]
153
-
154
- eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
155
- axis = torch.tensor(axis + (0,))
156
- axis = torch.ger(axis, axis)
157
- saturate = axis + (eye - axis) * i.view(-1, 1, 1)
158
-
159
- return saturate
160
-
161
-
162
- def lognormal_sample(size, mean=0, std=1, device="cpu"):
163
- return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
164
-
165
-
166
- def category_sample(size, categories, device="cpu"):
167
- category = torch.tensor(categories, device=device)
168
- sample = torch.randint(high=len(categories), size=(size,), device=device)
169
-
170
- return category[sample]
171
-
172
-
173
- def uniform_sample(size, low, high, device="cpu"):
174
- return torch.empty(size, device=device).uniform_(low, high)
175
-
176
-
177
- def normal_sample(size, mean=0, std=1, device="cpu"):
178
- return torch.empty(size, device=device).normal_(mean, std)
179
-
180
-
181
- def bernoulli_sample(size, p, device="cpu"):
182
- return torch.empty(size, device=device).bernoulli_(p)
183
-
184
-
185
- def random_mat_apply(p, transform, prev, eye, device="cpu"):
186
- size = transform.shape[0]
187
- select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
188
- select_transform = select * transform + (1 - select) * eye
189
-
190
- return select_transform @ prev
191
-
192
-
193
- def sample_affine(p, size, height, width, device="cpu"):
194
- G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
195
- eye = G
196
-
197
- # flip
198
- param = category_sample(size, (0, 1))
199
- Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
200
- G = random_mat_apply(p, Gc, G, eye, device=device)
201
- # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
202
-
203
- # 90 rotate
204
- param = category_sample(size, (0, 3))
205
- Gc = rotate_mat(-math.pi / 2 * param, device=device)
206
- G = random_mat_apply(p, Gc, G, eye, device=device)
207
- # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208
-
209
- # integer translate
210
- param = uniform_sample((2, size), -0.125, 0.125)
211
- param_height = torch.round(param[0] * height)
212
- param_width = torch.round(param[1] * width)
213
- Gc = translate_mat(param_width, param_height, device=device)
214
- G = random_mat_apply(p, Gc, G, eye, device=device)
215
- # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
216
-
217
- # isotropic scale
218
- param = lognormal_sample(size, std=0.2 * math.log(2))
219
- Gc = scale_mat(param, param, device=device)
220
- G = random_mat_apply(p, Gc, G, eye, device=device)
221
- # print('isotropic scale', G, scale_mat(param, param), sep='\n')
222
-
223
- p_rot = 1 - math.sqrt(1 - p)
224
-
225
- # pre-rotate
226
- param = uniform_sample(size, -math.pi, math.pi)
227
- Gc = rotate_mat(-param, device=device)
228
- G = random_mat_apply(p_rot, Gc, G, eye, device=device)
229
- # print('pre-rotate', G, rotate_mat(-param), sep='\n')
230
-
231
- # anisotropic scale
232
- param = lognormal_sample(size, std=0.2 * math.log(2))
233
- Gc = scale_mat(param, 1 / param, device=device)
234
- G = random_mat_apply(p, Gc, G, eye, device=device)
235
- # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
236
-
237
- # post-rotate
238
- param = uniform_sample(size, -math.pi, math.pi)
239
- Gc = rotate_mat(-param, device=device)
240
- G = random_mat_apply(p_rot, Gc, G, eye, device=device)
241
- # print('post-rotate', G, rotate_mat(-param), sep='\n')
242
-
243
- # fractional translate
244
- param = normal_sample((2, size), std=0.125)
245
- Gc = translate_mat(param[1] * width, param[0] * height, device=device)
246
- G = random_mat_apply(p, Gc, G, eye, device=device)
247
- # print('fractional translate', G, translate_mat(param, param), sep='\n')
248
-
249
- return G
250
-
251
-
252
- def sample_color(p, size):
253
- C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
254
- eye = C
255
- axis_val = 1 / math.sqrt(3)
256
- axis = (axis_val, axis_val, axis_val)
257
-
258
- # brightness
259
- param = normal_sample(size, std=0.2)
260
- Cc = translate3d_mat(param, param, param)
261
- C = random_mat_apply(p, Cc, C, eye)
262
-
263
- # contrast
264
- param = lognormal_sample(size, std=0.5 * math.log(2))
265
- Cc = scale3d_mat(param, param, param)
266
- C = random_mat_apply(p, Cc, C, eye)
267
-
268
- # luma flip
269
- param = category_sample(size, (0, 1))
270
- Cc = luma_flip_mat(axis, param)
271
- C = random_mat_apply(p, Cc, C, eye)
272
-
273
- # hue rotation
274
- param = uniform_sample(size, -math.pi, math.pi)
275
- Cc = rotate3d_mat(axis, param)
276
- C = random_mat_apply(p, Cc, C, eye)
277
-
278
- # saturation
279
- param = lognormal_sample(size, std=1 * math.log(2))
280
- Cc = saturation_mat(axis, param)
281
- C = random_mat_apply(p, Cc, C, eye)
282
-
283
- return C
284
-
285
-
286
- def make_grid(shape, x0, x1, y0, y1, device):
287
- n, c, h, w = shape
288
- grid = torch.empty(n, h, w, 3, device=device)
289
- grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
290
- grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
291
- grid[:, :, :, 2] = 1
292
-
293
- return grid
294
-
295
-
296
- def affine_grid(grid, mat):
297
- n, h, w, _ = grid.shape
298
- return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
299
-
300
-
301
- def get_padding(G, height, width, kernel_size):
302
- device = G.device
303
-
304
- cx = (width - 1) / 2
305
- cy = (height - 1) / 2
306
- cp = torch.tensor(
307
- [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
308
- )
309
- cp = G @ cp.T
310
-
311
- pad_k = kernel_size // 4
312
-
313
- pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
314
- pad = torch.cat((-pad, pad)).max(1).values
315
- pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
316
- pad = pad.max(torch.tensor([0, 0] * 2, device=device))
317
- pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
318
-
319
- pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
320
-
321
- return pad_x1, pad_x2, pad_y1, pad_y2
322
-
323
-
324
- def try_sample_affine_and_pad(img, p, kernel_size, G=None):
325
- batch, _, height, width = img.shape
326
-
327
- G_try = G
328
-
329
- if G is None:
330
- G_try = torch.inverse(sample_affine(p, batch, height, width))
331
-
332
- pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
333
-
334
- img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
335
-
336
- return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
337
-
338
-
339
- class GridSampleForward(autograd.Function):
340
- @staticmethod
341
- def forward(ctx, input, grid):
342
- out = F.grid_sample(
343
- input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
344
- )
345
- ctx.save_for_backward(input, grid)
346
-
347
- return out
348
-
349
- @staticmethod
350
- def backward(ctx, grad_output):
351
- input, grid = ctx.saved_tensors
352
- grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
353
-
354
- return grad_input, grad_grid
355
-
356
-
357
- class GridSampleBackward(autograd.Function):
358
- @staticmethod
359
- def forward(ctx, grad_output, input, grid):
360
- op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
361
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
362
- ctx.save_for_backward(grid)
363
-
364
- return grad_input, grad_grid
365
-
366
- @staticmethod
367
- def backward(ctx, grad_grad_input, grad_grad_grid):
368
- (grid,) = ctx.saved_tensors
369
- grad_grad_output = None
370
-
371
- if ctx.needs_input_grad[0]:
372
- grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
373
-
374
- return grad_grad_output, None, None
375
-
376
-
377
- grid_sample = GridSampleForward.apply
378
-
379
-
380
- def scale_mat_single(s_x, s_y):
381
- return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
382
-
383
-
384
- def translate_mat_single(t_x, t_y):
385
- return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
386
-
387
-
388
- def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
389
- kernel = antialiasing_kernel
390
- len_k = len(kernel)
391
-
392
- kernel = torch.as_tensor(kernel).to(img)
393
- # kernel = torch.ger(kernel, kernel).to(img)
394
- kernel_flip = torch.flip(kernel, (0,))
395
-
396
- img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
397
- img, p, len_k, G
398
- )
399
-
400
- G_inv = (
401
- translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
402
- @ G
403
- )
404
- up_pad = (
405
- (len_k + 2 - 1) // 2,
406
- (len_k - 2) // 2,
407
- (len_k + 2 - 1) // 2,
408
- (len_k - 2) // 2,
409
- )
410
- img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
411
- img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
412
- G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
413
- G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
414
- batch_size, channel, height, width = img.shape
415
- pad_k = len_k // 4
416
- shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
417
- G_inv = (
418
- scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
419
- @ G_inv
420
- @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
421
- )
422
- grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
423
- img_affine = grid_sample(img_2x, grid)
424
- d_p = -pad_k * 2
425
- down_pad = (
426
- d_p + (len_k - 2 + 1) // 2,
427
- d_p + (len_k - 2) // 2,
428
- d_p + (len_k - 2 + 1) // 2,
429
- d_p + (len_k - 2) // 2,
430
- )
431
- img_down = upfirdn2d(
432
- img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
433
- )
434
- img_down = upfirdn2d(
435
- img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
436
- )
437
-
438
- return img_down, G
439
-
440
-
441
- def apply_color(img, mat):
442
- batch = img.shape[0]
443
- img = img.permute(0, 2, 3, 1)
444
- mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
445
- mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
446
- img = img @ mat_mul + mat_add
447
- img = img.permute(0, 3, 1, 2)
448
-
449
- return img
450
-
451
-
452
- def random_apply_color(img, p, C=None):
453
- if C is None:
454
- C = sample_color(p, img.shape[0])
455
-
456
- img = apply_color(img, C.to(img))
457
-
458
- return img, C
459
-
460
-
461
- def augment(img, p, transform_matrix=(None, None)):
462
- img, G = random_apply_affine(img, p, transform_matrix[0])
463
- img, C = random_apply_color(img, p, transform_matrix[1])
464
-
465
- return img, (G, C)
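A minimal single-process sketch of the update rule behind AdaptiveAugment.tune above: the sign of the discriminator's real logits estimates the overfitting statistic r_t, and the augmentation probability p is nudged toward a target value. The target (0.6) and ramp length (500k predictions) are illustrative defaults, not part of this file, and the distributed accumulation over update_every calls is omitted:

import torch

def ada_update(ada_aug_p, real_pred, ada_aug_target=0.6, ada_aug_len=500 * 1000):
    # r_t: fraction of real logits classified as real (positive sign).
    n_pred = real_pred.shape[0]
    r_t = torch.sign(real_pred).sum().item() / n_pred
    # Overshooting the target means the discriminator is too confident on
    # reals, so the augmentation probability is increased; otherwise decreased.
    sign = 1 if r_t > ada_aug_target else -1
    ada_aug_p = ada_aug_p + sign * n_pred / ada_aug_len
    return min(1.0, max(0.0, ada_aug_p)), r_t

p = 0.0
real_pred = torch.randn(16) + 0.5  # stand-in for discriminator logits on real images
p, r_t = ada_update(p, real_pred)
print(p, r_t)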
op/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
- from .upfirdn2d import upfirdn2d
op/conv2d_gradfix.py DELETED
@@ -1,227 +0,0 @@
1
- import contextlib
2
- import warnings
3
-
4
- import torch
5
- from torch import autograd
6
- from torch.nn import functional as F
7
-
8
- enabled = True
9
- weight_gradients_disabled = False
10
-
11
-
12
- @contextlib.contextmanager
13
- def no_weight_gradients():
14
- global weight_gradients_disabled
15
-
16
- old = weight_gradients_disabled
17
- weight_gradients_disabled = True
18
- yield
19
- weight_gradients_disabled = old
20
-
21
-
22
- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
- if could_use_op(input):
24
- return conv2d_gradfix(
25
- transpose=False,
26
- weight_shape=weight.shape,
27
- stride=stride,
28
- padding=padding,
29
- output_padding=0,
30
- dilation=dilation,
31
- groups=groups,
32
- ).apply(input, weight, bias)
33
-
34
- return F.conv2d(
35
- input=input,
36
- weight=weight,
37
- bias=bias,
38
- stride=stride,
39
- padding=padding,
40
- dilation=dilation,
41
- groups=groups,
42
- )
43
-
44
-
45
- def conv_transpose2d(
46
- input,
47
- weight,
48
- bias=None,
49
- stride=1,
50
- padding=0,
51
- output_padding=0,
52
- groups=1,
53
- dilation=1,
54
- ):
55
- if could_use_op(input):
56
- return conv2d_gradfix(
57
- transpose=True,
58
- weight_shape=weight.shape,
59
- stride=stride,
60
- padding=padding,
61
- output_padding=output_padding,
62
- groups=groups,
63
- dilation=dilation,
64
- ).apply(input, weight, bias)
65
-
66
- return F.conv_transpose2d(
67
- input=input,
68
- weight=weight,
69
- bias=bias,
70
- stride=stride,
71
- padding=padding,
72
- output_padding=output_padding,
73
- dilation=dilation,
74
- groups=groups,
75
- )
76
-
77
-
78
- def could_use_op(input):
79
- if (not enabled) or (not torch.backends.cudnn.enabled):
80
- return False
81
-
82
- if input.device.type != "cuda":
83
- return False
84
-
85
- if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
- return True
87
-
88
- warnings.warn(
89
- f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
- )
91
-
92
- return False
93
-
94
-
95
- def ensure_tuple(xs, ndim):
96
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
-
98
- return xs
99
-
100
-
101
- conv2d_gradfix_cache = dict()
102
-
103
-
104
- def conv2d_gradfix(
105
- transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
- ):
107
- ndim = 2
108
- weight_shape = tuple(weight_shape)
109
- stride = ensure_tuple(stride, ndim)
110
- padding = ensure_tuple(padding, ndim)
111
- output_padding = ensure_tuple(output_padding, ndim)
112
- dilation = ensure_tuple(dilation, ndim)
113
-
114
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
- if key in conv2d_gradfix_cache:
116
- return conv2d_gradfix_cache[key]
117
-
118
- common_kwargs = dict(
119
- stride=stride, padding=padding, dilation=dilation, groups=groups
120
- )
121
-
122
- def calc_output_padding(input_shape, output_shape):
123
- if transpose:
124
- return [0, 0]
125
-
126
- return [
127
- input_shape[i + 2]
128
- - (output_shape[i + 2] - 1) * stride[i]
129
- - (1 - 2 * padding[i])
130
- - dilation[i] * (weight_shape[i + 2] - 1)
131
- for i in range(ndim)
132
- ]
133
-
134
- class Conv2d(autograd.Function):
135
- @staticmethod
136
- def forward(ctx, input, weight, bias):
137
- if not transpose:
138
- out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
-
140
- else:
141
- out = F.conv_transpose2d(
142
- input=input,
143
- weight=weight,
144
- bias=bias,
145
- output_padding=output_padding,
146
- **common_kwargs,
147
- )
148
-
149
- ctx.save_for_backward(input, weight)
150
-
151
- return out
152
-
153
- @staticmethod
154
- def backward(ctx, grad_output):
155
- input, weight = ctx.saved_tensors
156
- grad_input, grad_weight, grad_bias = None, None, None
157
-
158
- if ctx.needs_input_grad[0]:
159
- p = calc_output_padding(
160
- input_shape=input.shape, output_shape=grad_output.shape
161
- )
162
- grad_input = conv2d_gradfix(
163
- transpose=(not transpose),
164
- weight_shape=weight_shape,
165
- output_padding=p,
166
- **common_kwargs,
167
- ).apply(grad_output, weight, None)
168
-
169
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
-
172
- if ctx.needs_input_grad[2]:
173
- grad_bias = grad_output.sum((0, 2, 3))
174
-
175
- return grad_input, grad_weight, grad_bias
176
-
177
- class Conv2dGradWeight(autograd.Function):
178
- @staticmethod
179
- def forward(ctx, grad_output, input):
180
- op = torch._C._jit_get_operation(
181
- "aten::cudnn_convolution_backward_weight"
182
- if not transpose
183
- else "aten::cudnn_convolution_transpose_backward_weight"
184
- )
185
- flags = [
186
- torch.backends.cudnn.benchmark,
187
- torch.backends.cudnn.deterministic,
188
- torch.backends.cudnn.allow_tf32,
189
- ]
190
- grad_weight = op(
191
- weight_shape,
192
- grad_output,
193
- input,
194
- padding,
195
- stride,
196
- dilation,
197
- groups,
198
- *flags,
199
- )
200
- ctx.save_for_backward(grad_output, input)
201
-
202
- return grad_weight
203
-
204
- @staticmethod
205
- def backward(ctx, grad_grad_weight):
206
- grad_output, input = ctx.saved_tensors
207
- grad_grad_output, grad_grad_input = None, None
208
-
209
- if ctx.needs_input_grad[0]:
210
- grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
-
212
- if ctx.needs_input_grad[1]:
213
- p = calc_output_padding(
214
- input_shape=input.shape, output_shape=grad_output.shape
215
- )
216
- grad_grad_input = conv2d_gradfix(
217
- transpose=(not transpose),
218
- weight_shape=weight_shape,
219
- output_padding=p,
220
- **common_kwargs,
221
- ).apply(grad_output, grad_grad_weight, None)
222
-
223
- return grad_grad_output, grad_grad_input
224
-
225
- conv2d_gradfix_cache[key] = Conv2d
226
-
227
- return Conv2d
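The custom op above replaces torch.nn.functional.conv2d when higher-order gradients are needed (and lets weight gradients be skipped via no_weight_gradients); when could_use_op returns False it falls back to plain F.conv2d, through which double backward also works. A small sketch of such a double-backward pass, with illustrative shapes not tied to the models above:

import torch
from torch.nn import functional as F

x = torch.randn(2, 3, 16, 16, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

out = F.conv2d(x, w, padding=1)
# First-order gradient w.r.t. the input, kept in the graph so it can be
# differentiated again, as gradient penalties require.
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)

# Second-order term: differentiate the squared gradient norm w.r.t. the weights.
penalty = grad_x.pow(2).sum()
penalty.backward()
print(w.grad.shape)  # torch.Size([8, 3, 3, 3])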
op/fused_act.py DELETED
@@ -1,127 +0,0 @@
1
- import os
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from torch.autograd import Function
7
- from torch.utils.cpp_extension import load
8
-
9
-
10
- module_path = os.path.dirname(__file__)
11
- fused = load(
12
- "fused",
13
- sources=[
14
- os.path.join(module_path, "fused_bias_act.cpp"),
15
- os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
- ],
17
- )
18
-
19
-
20
- class FusedLeakyReLUFunctionBackward(Function):
21
- @staticmethod
22
- def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
- ctx.save_for_backward(out)
24
- ctx.negative_slope = negative_slope
25
- ctx.scale = scale
26
-
27
- empty = grad_output.new_empty(0)
28
-
29
- grad_input = fused.fused_bias_act(
30
- grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
- )
32
-
33
- dim = [0]
34
-
35
- if grad_input.ndim > 2:
36
- dim += list(range(2, grad_input.ndim))
37
-
38
- if bias:
39
- grad_bias = grad_input.sum(dim).detach()
40
-
41
- else:
42
- grad_bias = empty
43
-
44
- return grad_input, grad_bias
45
-
46
- @staticmethod
47
- def backward(ctx, gradgrad_input, gradgrad_bias):
48
- out, = ctx.saved_tensors
49
- gradgrad_out = fused.fused_bias_act(
50
- gradgrad_input.contiguous(),
51
- gradgrad_bias,
52
- out,
53
- 3,
54
- 1,
55
- ctx.negative_slope,
56
- ctx.scale,
57
- )
58
-
59
- return gradgrad_out, None, None, None, None
60
-
61
-
62
- class FusedLeakyReLUFunction(Function):
63
- @staticmethod
64
- def forward(ctx, input, bias, negative_slope, scale):
65
- empty = input.new_empty(0)
66
-
67
- ctx.bias = bias is not None
68
-
69
- if bias is None:
70
- bias = empty
71
-
72
- out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73
- ctx.save_for_backward(out)
74
- ctx.negative_slope = negative_slope
75
- ctx.scale = scale
76
-
77
- return out
78
-
79
- @staticmethod
80
- def backward(ctx, grad_output):
81
- out, = ctx.saved_tensors
82
-
83
- grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
- grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
- )
86
-
87
- if not ctx.bias:
88
- grad_bias = None
89
-
90
- return grad_input, grad_bias, None, None
91
-
92
-
93
- class FusedLeakyReLU(nn.Module):
94
- def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
- super().__init__()
96
-
97
- if bias:
98
- self.bias = nn.Parameter(torch.zeros(channel))
99
-
100
- else:
101
- self.bias = None
102
-
103
- self.negative_slope = negative_slope
104
- self.scale = scale
105
-
106
- def forward(self, input):
107
- return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
-
109
-
110
- def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
- if input.device.type == "cpu":
112
- if bias is not None:
113
- rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
- return (
115
- F.leaky_relu(
116
- input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
- )
118
- * scale
119
- )
120
-
121
- else:
122
- return F.leaky_relu(input, negative_slope=0.2) * scale
123
-
124
- else:
125
- return FusedLeakyReLUFunction.apply(
126
- input.contiguous(), bias, negative_slope, scale
127
- )
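A standalone reference for the arithmetic in fused_leaky_relu above (the CPU branch), usable without the compiled CUDA extension; the function name is hypothetical:

import torch
from torch.nn import functional as F

def fused_leaky_relu_ref(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # Broadcast the per-channel bias, apply leaky ReLU, then rescale
    # (scale = sqrt(2) keeps the activation variance roughly constant).
    if bias is not None:
        rest_dim = [1] * (input.ndim - 2)
        input = input + bias.view(1, -1, *rest_dim)
    return F.leaky_relu(input, negative_slope=negative_slope) * scale

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)
print(fused_leaky_relu_ref(x, b).shape)  # torch.Size([2, 8, 4, 4])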
op/fused_bias_act.cpp DELETED
@@ -1,32 +0,0 @@
1
-
2
- #include <ATen/ATen.h>
3
- #include <torch/extension.h>
4
-
5
- torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
- const torch::Tensor &bias,
7
- const torch::Tensor &refer, int act, int grad,
8
- float alpha, float scale);
9
-
10
- #define CHECK_CUDA(x) \
11
- TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
- #define CHECK_CONTIGUOUS(x) \
13
- TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
- #define CHECK_INPUT(x) \
15
- CHECK_CUDA(x); \
16
- CHECK_CONTIGUOUS(x)
17
-
18
- torch::Tensor fused_bias_act(const torch::Tensor &input,
19
- const torch::Tensor &bias,
20
- const torch::Tensor &refer, int act, int grad,
21
- float alpha, float scale) {
22
- CHECK_INPUT(input);
23
- CHECK_INPUT(bias);
24
-
25
- at::DeviceGuard guard(input.device());
26
-
27
- return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
- }
29
-
30
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
- m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
- }
op/fused_bias_act_kernel.cu DELETED
@@ -1,105 +0,0 @@
1
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
- //
3
- // This work is made available under the Nvidia Source Code License-NC.
4
- // To view a copy of this license, visit
5
- // https://nvlabs.github.io/stylegan2/license.html
6
-
7
- #include <torch/types.h>
8
-
9
- #include <ATen/ATen.h>
10
- #include <ATen/AccumulateType.h>
11
- #include <ATen/cuda/CUDAApplyUtils.cuh>
12
- #include <ATen/cuda/CUDAContext.h>
13
-
14
-
15
- #include <cuda.h>
16
- #include <cuda_runtime.h>
17
-
18
- template <typename scalar_t>
19
- static __global__ void
20
- fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
- const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
- scalar_t scale, int loop_x, int size_x, int step_b,
23
- int size_b, int use_bias, int use_ref) {
24
- int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
-
26
- scalar_t zero = 0.0;
27
-
28
- for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
- loop_idx++, xi += blockDim.x) {
30
- scalar_t x = p_x[xi];
31
-
32
- if (use_bias) {
33
- x += p_b[(xi / step_b) % size_b];
34
- }
35
-
36
- scalar_t ref = use_ref ? p_ref[xi] : zero;
37
-
38
- scalar_t y;
39
-
40
- switch (act * 10 + grad) {
41
- default:
42
- case 10:
43
- y = x;
44
- break;
45
- case 11:
46
- y = x;
47
- break;
48
- case 12:
49
- y = 0.0;
50
- break;
51
-
52
- case 30:
53
- y = (x > 0.0) ? x : x * alpha;
54
- break;
55
- case 31:
56
- y = (ref > 0.0) ? x : x * alpha;
57
- break;
58
- case 32:
59
- y = 0.0;
60
- break;
61
- }
62
-
63
- out[xi] = y * scale;
64
- }
65
- }
66
-
67
- torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
- const torch::Tensor &bias,
69
- const torch::Tensor &refer, int act, int grad,
70
- float alpha, float scale) {
71
- int curDevice = -1;
72
- cudaGetDevice(&curDevice);
73
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
-
75
- auto x = input.contiguous();
76
- auto b = bias.contiguous();
77
- auto ref = refer.contiguous();
78
-
79
- int use_bias = b.numel() ? 1 : 0;
80
- int use_ref = ref.numel() ? 1 : 0;
81
-
82
- int size_x = x.numel();
83
- int size_b = b.numel();
84
- int step_b = 1;
85
-
86
- for (int i = 1 + 1; i < x.dim(); i++) {
87
- step_b *= x.size(i);
88
- }
89
-
90
- int loop_x = 4;
91
- int block_size = 4 * 32;
92
- int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
-
94
- auto y = torch::empty_like(x);
95
-
96
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
- x.scalar_type(), "fused_bias_act_kernel", [&] {
98
- fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
- y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
- b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
- scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
- });
103
-
104
- return y;
105
- }
op/upfirdn2d.cpp DELETED
@@ -1,31 +0,0 @@
1
- #include <ATen/ATen.h>
2
- #include <torch/extension.h>
3
-
4
- torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
- const torch::Tensor &kernel, int up_x, int up_y,
6
- int down_x, int down_y, int pad_x0, int pad_x1,
7
- int pad_y0, int pad_y1);
8
-
9
- #define CHECK_CUDA(x) \
10
- TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
- #define CHECK_CONTIGUOUS(x) \
12
- TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
- #define CHECK_INPUT(x) \
14
- CHECK_CUDA(x); \
15
- CHECK_CONTIGUOUS(x)
16
-
17
- torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
- int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
- int pad_x1, int pad_y0, int pad_y1) {
20
- CHECK_INPUT(input);
21
- CHECK_INPUT(kernel);
22
-
23
- at::DeviceGuard guard(input.device());
24
-
25
- return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
- pad_y0, pad_y1);
27
- }
28
-
29
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
- m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
- }
op/upfirdn2d.py DELETED
@@ -1,209 +0,0 @@
1
- from collections import abc
2
- import os
3
-
4
- import torch
5
- from torch.nn import functional as F
6
- from torch.autograd import Function
7
- from torch.utils.cpp_extension import load
8
-
9
-
10
- module_path = os.path.dirname(__file__)
11
- upfirdn2d_op = load(
12
- "upfirdn2d",
13
- sources=[
14
- os.path.join(module_path, "upfirdn2d.cpp"),
15
- os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
- ],
17
- )
18
-
19
-
20
- class UpFirDn2dBackward(Function):
21
- @staticmethod
22
- def forward(
23
- ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
- ):
25
-
26
- up_x, up_y = up
27
- down_x, down_y = down
28
- g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
-
30
- grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
-
32
- grad_input = upfirdn2d_op.upfirdn2d(
33
- grad_output,
34
- grad_kernel,
35
- down_x,
36
- down_y,
37
- up_x,
38
- up_y,
39
- g_pad_x0,
40
- g_pad_x1,
41
- g_pad_y0,
42
- g_pad_y1,
43
- )
44
- grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
-
46
- ctx.save_for_backward(kernel)
47
-
48
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
-
50
- ctx.up_x = up_x
51
- ctx.up_y = up_y
52
- ctx.down_x = down_x
53
- ctx.down_y = down_y
54
- ctx.pad_x0 = pad_x0
55
- ctx.pad_x1 = pad_x1
56
- ctx.pad_y0 = pad_y0
57
- ctx.pad_y1 = pad_y1
58
- ctx.in_size = in_size
59
- ctx.out_size = out_size
60
-
61
- return grad_input
62
-
63
- @staticmethod
64
- def backward(ctx, gradgrad_input):
65
- kernel, = ctx.saved_tensors
66
-
67
- gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
-
69
- gradgrad_out = upfirdn2d_op.upfirdn2d(
70
- gradgrad_input,
71
- kernel,
72
- ctx.up_x,
73
- ctx.up_y,
74
- ctx.down_x,
75
- ctx.down_y,
76
- ctx.pad_x0,
77
- ctx.pad_x1,
78
- ctx.pad_y0,
79
- ctx.pad_y1,
80
- )
81
- # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
- gradgrad_out = gradgrad_out.view(
83
- ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
- )
85
-
86
- return gradgrad_out, None, None, None, None, None, None, None, None
87
-
88
-
89
- class UpFirDn2d(Function):
90
- @staticmethod
91
- def forward(ctx, input, kernel, up, down, pad):
92
- up_x, up_y = up
93
- down_x, down_y = down
94
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
-
96
- kernel_h, kernel_w = kernel.shape
97
- batch, channel, in_h, in_w = input.shape
98
- ctx.in_size = input.shape
99
-
100
- input = input.reshape(-1, in_h, in_w, 1)
101
-
102
- ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
-
104
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
- ctx.out_size = (out_h, out_w)
107
-
108
- ctx.up = (up_x, up_y)
109
- ctx.down = (down_x, down_y)
110
- ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
-
112
- g_pad_x0 = kernel_w - pad_x0 - 1
113
- g_pad_y0 = kernel_h - pad_y0 - 1
114
- g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
- g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
-
117
- ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
-
119
- out = upfirdn2d_op.upfirdn2d(
120
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
- )
122
- # out = out.view(major, out_h, out_w, minor)
123
- out = out.view(-1, channel, out_h, out_w)
124
-
125
- return out
126
-
127
- @staticmethod
128
- def backward(ctx, grad_output):
129
- kernel, grad_kernel = ctx.saved_tensors
130
-
131
- grad_input = None
132
-
133
- if ctx.needs_input_grad[0]:
134
- grad_input = UpFirDn2dBackward.apply(
135
- grad_output,
136
- kernel,
137
- grad_kernel,
138
- ctx.up,
139
- ctx.down,
140
- ctx.pad,
141
- ctx.g_pad,
142
- ctx.in_size,
143
- ctx.out_size,
144
- )
145
-
146
- return grad_input, None, None, None, None
147
-
148
-
149
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
- if not isinstance(up, abc.Iterable):
151
- up = (up, up)
152
-
153
- if not isinstance(down, abc.Iterable):
154
- down = (down, down)
155
-
156
- if len(pad) == 2:
157
- pad = (pad[0], pad[1], pad[0], pad[1])
158
-
159
- if input.device.type == "cpu":
160
- out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
-
162
- else:
163
- out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
-
165
- return out
166
-
167
-
168
- def upfirdn2d_native(
169
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
- ):
171
- _, channel, in_h, in_w = input.shape
172
- input = input.reshape(-1, in_h, in_w, 1)
173
-
174
- _, in_h, in_w, minor = input.shape
175
- kernel_h, kernel_w = kernel.shape
176
-
177
- out = input.view(-1, in_h, 1, in_w, 1, minor)
178
- out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
- out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
-
181
- out = F.pad(
182
- out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
- )
184
- out = out[
185
- :,
186
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
- :,
189
- ]
190
-
191
- out = out.permute(0, 3, 1, 2)
192
- out = out.reshape(
193
- [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
- )
195
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
- out = F.conv2d(out, w)
197
- out = out.reshape(
198
- -1,
199
- minor,
200
- in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
- )
203
- out = out.permute(0, 2, 3, 1)
204
- out = out[:, ::down_y, ::down_x, :]
205
-
206
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
-
209
- return out.view(-1, channel, out_h, out_w)
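The output-size arithmetic used both in UpFirDn2d.forward above and in the CUDA kernel, restated as a small helper with illustrative numbers (symmetric in x and y for brevity; the helper name is hypothetical):

def upfirdn2d_out_size(in_size, kernel_size, up=1, down=1, pad=(0, 0)):
    # out = (in * up + pad0 + pad1 - kernel + down) // down, per axis.
    pad0, pad1 = pad
    return (in_size * up + pad0 + pad1 - kernel_size + down) // down

# A 2x upsampling pass with a 4-tap blur kernel and pad=(2, 1), as in the
# generator's upsample path, maps 16 -> 32 exactly:
print(upfirdn2d_out_size(16, 4, up=2, down=1, pad=(2, 1)))  # 32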
op/upfirdn2d_kernel.cu DELETED
@@ -1,369 +0,0 @@
1
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
- //
3
- // This work is made available under the Nvidia Source Code License-NC.
4
- // To view a copy of this license, visit
5
- // https://nvlabs.github.io/stylegan2/license.html
6
-
7
- #include <torch/types.h>
8
-
9
- #include <ATen/ATen.h>
10
- #include <ATen/AccumulateType.h>
11
- #include <ATen/cuda/CUDAApplyUtils.cuh>
12
- #include <ATen/cuda/CUDAContext.h>
13
-
14
- #include <cuda.h>
15
- #include <cuda_runtime.h>
16
-
17
- static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
- int c = a / b;
19
-
20
- if (c * b > a) {
21
- c--;
22
- }
23
-
24
- return c;
25
- }
26
-
27
- struct UpFirDn2DKernelParams {
28
- int up_x;
29
- int up_y;
30
- int down_x;
31
- int down_y;
32
- int pad_x0;
33
- int pad_x1;
34
- int pad_y0;
35
- int pad_y1;
36
-
37
- int major_dim;
38
- int in_h;
39
- int in_w;
40
- int minor_dim;
41
- int kernel_h;
42
- int kernel_w;
43
- int out_h;
44
- int out_w;
45
- int loop_major;
46
- int loop_x;
47
- };
48
-
49
- template <typename scalar_t>
50
- __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
- const scalar_t *kernel,
52
- const UpFirDn2DKernelParams p) {
53
- int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
- int out_y = minor_idx / p.minor_dim;
55
- minor_idx -= out_y * p.minor_dim;
56
- int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
- int major_idx_base = blockIdx.z * p.loop_major;
58
-
59
- if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
- major_idx_base >= p.major_dim) {
61
- return;
62
- }
63
-
64
- int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
- int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
- int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
- int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
-
69
- for (int loop_major = 0, major_idx = major_idx_base;
70
- loop_major < p.loop_major && major_idx < p.major_dim;
71
- loop_major++, major_idx++) {
72
- for (int loop_x = 0, out_x = out_x_base;
73
- loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
- int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
- int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
- int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
- int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
-
79
- const scalar_t *x_p =
80
- &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
- minor_idx];
82
- const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
- int x_px = p.minor_dim;
84
- int k_px = -p.up_x;
85
- int x_py = p.in_w * p.minor_dim;
86
- int k_py = -p.up_y * p.kernel_w;
87
-
88
- scalar_t v = 0.0f;
89
-
90
- for (int y = 0; y < h; y++) {
91
- for (int x = 0; x < w; x++) {
92
- v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
- x_p += x_px;
94
- k_p += k_px;
95
- }
96
-
97
- x_p += x_py - w * x_px;
98
- k_p += k_py - w * k_px;
99
- }
100
-
101
- out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
- minor_idx] = v;
103
- }
104
- }
105
- }
106
-
107
- template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
- int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
- __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
- const scalar_t *kernel,
111
- const UpFirDn2DKernelParams p) {
112
- const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
- const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
-
115
- __shared__ volatile float sk[kernel_h][kernel_w];
116
- __shared__ volatile float sx[tile_in_h][tile_in_w];
117
-
118
- int minor_idx = blockIdx.x;
119
- int tile_out_y = minor_idx / p.minor_dim;
120
- minor_idx -= tile_out_y * p.minor_dim;
121
- tile_out_y *= tile_out_h;
122
- int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
- int major_idx_base = blockIdx.z * p.loop_major;
124
-
125
- if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
- major_idx_base >= p.major_dim) {
127
- return;
128
- }
129
-
130
- for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
- tap_idx += blockDim.x) {
132
- int ky = tap_idx / kernel_w;
133
- int kx = tap_idx - ky * kernel_w;
134
- scalar_t v = 0.0;
135
-
136
- if (kx < p.kernel_w & ky < p.kernel_h) {
137
- v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
- }
139
-
140
- sk[ky][kx] = v;
141
- }
142
-
143
- for (int loop_major = 0, major_idx = major_idx_base;
144
- loop_major < p.loop_major & major_idx < p.major_dim;
145
- loop_major++, major_idx++) {
146
- for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
- loop_x < p.loop_x & tile_out_x < p.out_w;
148
- loop_x++, tile_out_x += tile_out_w) {
149
- int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
- int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
- int tile_in_x = floor_div(tile_mid_x, up_x);
152
- int tile_in_y = floor_div(tile_mid_y, up_y);
153
-
154
- __syncthreads();
155
-
156
- for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
- in_idx += blockDim.x) {
158
- int rel_in_y = in_idx / tile_in_w;
159
- int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
- int in_x = rel_in_x + tile_in_x;
161
- int in_y = rel_in_y + tile_in_y;
162
-
163
- scalar_t v = 0.0;
164
-
165
- if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
- v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
- p.minor_dim +
168
- minor_idx];
169
- }
170
-
171
- sx[rel_in_y][rel_in_x] = v;
172
- }
173
-
174
- __syncthreads();
175
- for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
- out_idx += blockDim.x) {
177
- int rel_out_y = out_idx / tile_out_w;
178
- int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
- int out_x = rel_out_x + tile_out_x;
180
- int out_y = rel_out_y + tile_out_y;
181
-
182
- int mid_x = tile_mid_x + rel_out_x * down_x;
183
- int mid_y = tile_mid_y + rel_out_y * down_y;
184
- int in_x = floor_div(mid_x, up_x);
185
- int in_y = floor_div(mid_y, up_y);
186
- int rel_in_x = in_x - tile_in_x;
187
- int rel_in_y = in_y - tile_in_y;
188
- int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
- int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
-
191
- scalar_t v = 0.0;
192
-
193
- #pragma unroll
194
- for (int y = 0; y < kernel_h / up_y; y++)
195
- #pragma unroll
196
- for (int x = 0; x < kernel_w / up_x; x++)
197
- v += sx[rel_in_y + y][rel_in_x + x] *
198
- sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
-
200
- if (out_x < p.out_w & out_y < p.out_h) {
201
- out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
- minor_idx] = v;
203
- }
204
- }
205
- }
206
- }
207
- }
208
-
209
- torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
- const torch::Tensor &kernel, int up_x, int up_y,
211
- int down_x, int down_y, int pad_x0, int pad_x1,
212
- int pad_y0, int pad_y1) {
213
- int curDevice = -1;
214
- cudaGetDevice(&curDevice);
215
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
-
217
- UpFirDn2DKernelParams p;
218
-
219
- auto x = input.contiguous();
220
- auto k = kernel.contiguous();
221
-
222
- p.major_dim = x.size(0);
223
- p.in_h = x.size(1);
224
- p.in_w = x.size(2);
225
- p.minor_dim = x.size(3);
226
- p.kernel_h = k.size(0);
227
- p.kernel_w = k.size(1);
228
- p.up_x = up_x;
229
- p.up_y = up_y;
230
- p.down_x = down_x;
231
- p.down_y = down_y;
232
- p.pad_x0 = pad_x0;
233
- p.pad_x1 = pad_x1;
234
- p.pad_y0 = pad_y0;
235
- p.pad_y1 = pad_y1;
236
-
237
- p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
- p.down_y;
239
- p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
- p.down_x;
241
-
242
- auto out =
243
- at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
-
245
- int mode = -1;
246
-
247
- int tile_out_h = -1;
248
- int tile_out_w = -1;
249
-
250
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
- p.kernel_h <= 4 && p.kernel_w <= 4) {
252
- mode = 1;
253
- tile_out_h = 16;
254
- tile_out_w = 64;
255
- }
256
-
257
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
- p.kernel_h <= 3 && p.kernel_w <= 3) {
259
- mode = 2;
260
- tile_out_h = 16;
261
- tile_out_w = 64;
262
- }
263
-
264
- if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
- p.kernel_h <= 4 && p.kernel_w <= 4) {
266
- mode = 3;
267
- tile_out_h = 16;
268
- tile_out_w = 64;
269
- }
270
-
271
- if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
- p.kernel_h <= 2 && p.kernel_w <= 2) {
273
- mode = 4;
274
- tile_out_h = 16;
275
- tile_out_w = 64;
276
- }
277
-
278
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
- p.kernel_h <= 4 && p.kernel_w <= 4) {
280
- mode = 5;
281
- tile_out_h = 8;
282
- tile_out_w = 32;
283
- }
284
-
285
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
- p.kernel_h <= 2 && p.kernel_w <= 2) {
287
- mode = 6;
288
- tile_out_h = 8;
289
- tile_out_w = 32;
290
- }
291
-
292
- dim3 block_size;
293
- dim3 grid_size;
294
-
295
- if (tile_out_h > 0 && tile_out_w > 0) {
296
- p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
- p.loop_x = 1;
298
- block_size = dim3(32 * 8, 1, 1);
299
- grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
- (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
- (p.major_dim - 1) / p.loop_major + 1);
302
- } else {
303
- p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
- p.loop_x = 4;
305
- block_size = dim3(4, 32, 1);
306
- grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
- (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
- (p.major_dim - 1) / p.loop_major + 1);
309
- }
310
-
311
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
- switch (mode) {
313
- case 1:
314
- upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
- x.data_ptr<scalar_t>(),
317
- k.data_ptr<scalar_t>(), p);
318
-
319
- break;
320
-
321
- case 2:
322
- upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
- x.data_ptr<scalar_t>(),
325
- k.data_ptr<scalar_t>(), p);
326
-
327
- break;
328
-
329
- case 3:
330
- upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
- x.data_ptr<scalar_t>(),
333
- k.data_ptr<scalar_t>(), p);
334
-
335
- break;
336
-
337
- case 4:
338
- upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
- x.data_ptr<scalar_t>(),
341
- k.data_ptr<scalar_t>(), p);
342
-
343
- break;
344
-
345
- case 5:
346
- upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
- x.data_ptr<scalar_t>(),
349
- k.data_ptr<scalar_t>(), p);
350
-
351
- break;
352
-
353
- case 6:
354
- upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
- x.data_ptr<scalar_t>(),
357
- k.data_ptr<scalar_t>(), p);
358
-
359
- break;
360
-
361
- default:
362
- upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
- out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
- k.data_ptr<scalar_t>(), p);
365
- }
366
- });
367
-
368
- return out;
369
- }
ppl.py DELETED
@@ -1,130 +0,0 @@
1
- import argparse
2
-
3
- import torch
4
- from torch.nn import functional as F
5
- import numpy as np
6
- from tqdm import tqdm
7
-
8
- import lpips
9
- from model import Generator
10
-
11
-
12
- def normalize(x):
13
- return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
14
-
15
-
16
- def slerp(a, b, t):
17
- a = normalize(a)
18
- b = normalize(b)
19
- d = (a * b).sum(-1, keepdim=True)
20
- p = t * torch.acos(d)
21
- c = normalize(b - d * a)
22
- d = a * torch.cos(p) + c * torch.sin(p)
23
-
24
- return normalize(d)
25
-
26
-
27
- def lerp(a, b, t):
28
- return a + (b - a) * t
29
-
30
-
31
- if __name__ == "__main__":
32
- device = "cuda"
33
-
34
- parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")
35
-
36
- parser.add_argument(
37
- "--space", choices=["z", "w"], help="space that PPL calculated with"
38
- )
39
- parser.add_argument(
40
- "--batch", type=int, default=64, help="batch size for the models"
41
- )
42
- parser.add_argument(
43
- "--n_sample",
44
- type=int,
45
- default=5000,
46
- help="number of the samples for calculating PPL",
47
- )
48
- parser.add_argument(
49
- "--size", type=int, default=256, help="output image sizes of the generator"
50
- )
51
- parser.add_argument(
52
- "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
53
- )
54
- parser.add_argument(
55
- "--crop", action="store_true", help="apply center crop to the images"
56
- )
57
- parser.add_argument(
58
- "--sampling",
59
- default="end",
60
- choices=["end", "full"],
61
- help="set endpoint sampling method",
62
- )
63
- parser.add_argument(
64
- "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
65
- )
66
-
67
- args = parser.parse_args()
68
-
69
- latent_dim = 512
70
-
71
- ckpt = torch.load(args.ckpt)
72
-
73
- g = Generator(args.size, latent_dim, 8).to(device)
74
- g.load_state_dict(ckpt["g_ema"])
75
- g.eval()
76
-
77
- percept = lpips.PerceptualLoss(
78
- model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
79
- )
80
-
81
- distances = []
82
-
83
- n_batch = args.n_sample // args.batch
84
- resid = args.n_sample - (n_batch * args.batch)
85
- batch_sizes = [args.batch] * n_batch + [resid]
86
-
87
- with torch.no_grad():
88
- for batch in tqdm(batch_sizes):
89
- noise = g.make_noise()
90
-
91
- inputs = torch.randn([batch * 2, latent_dim], device=device)
92
- if args.sampling == "full":
93
- lerp_t = torch.rand(batch, device=device)
94
- else:
95
- lerp_t = torch.zeros(batch, device=device)
96
-
97
- if args.space == "w":
98
- latent = g.get_latent(inputs)
99
- latent_t0, latent_t1 = latent[::2], latent[1::2]
100
- latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
101
- latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
102
- latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
103
-
104
- image, _ = g([latent_e], input_is_latent=True, noise=noise)
105
-
106
- if args.crop:
107
- c = image.shape[2] // 8
108
- image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
109
-
110
- factor = image.shape[2] // 256
111
-
112
- if factor > 1:
113
- image = F.interpolate(
114
- image, size=(256, 256), mode="bilinear", align_corners=False
115
- )
116
-
117
- dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
118
- args.eps ** 2
119
- )
120
- distances.append(dist.to("cpu").numpy())
121
-
122
- distances = np.concatenate(distances, 0)
123
-
124
- lo = np.percentile(distances, 1, interpolation="lower")
125
- hi = np.percentile(distances, 99, interpolation="higher")
126
- filtered_dist = np.extract(
127
- np.logical_and(lo <= distances, distances <= hi), distances
128
- )
129
-
130
- print("ppl:", filtered_dist.mean())
 
prepare_data.py DELETED
@@ -1,101 +0,0 @@
1
- import argparse
2
- from io import BytesIO
3
- import multiprocessing
4
- from functools import partial
5
-
6
- from PIL import Image
7
- import lmdb
8
- from tqdm import tqdm
9
- from torchvision import datasets
10
- from torchvision.transforms import functional as trans_fn
11
-
12
-
13
- def resize_and_convert(img, size, resample, quality=100):
14
- img = trans_fn.resize(img, size, resample)
15
- img = trans_fn.center_crop(img, size)
16
- buffer = BytesIO()
17
- img.save(buffer, format="jpeg", quality=quality)
18
- val = buffer.getvalue()
19
-
20
- return val
21
-
22
-
23
- def resize_multiple(
24
- img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
25
- ):
26
- imgs = []
27
-
28
- for size in sizes:
29
- imgs.append(resize_and_convert(img, size, resample, quality))
30
-
31
- return imgs
32
-
33
-
34
- def resize_worker(img_file, sizes, resample):
35
- i, file = img_file
36
- img = Image.open(file)
37
- img = img.convert("RGB")
38
- out = resize_multiple(img, sizes=sizes, resample=resample)
39
-
40
- return i, out
41
-
42
-
43
- def prepare(
44
- env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
45
- ):
46
- resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
47
-
48
- files = sorted(dataset.imgs, key=lambda x: x[0])
49
- files = [(i, file) for i, (file, label) in enumerate(files)]
50
- total = 0
51
-
52
- with multiprocessing.Pool(n_worker) as pool:
53
- for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
54
- for size, img in zip(sizes, imgs):
55
- key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
56
-
57
- with env.begin(write=True) as txn:
58
- txn.put(key, img)
59
-
60
- total += 1
61
-
62
- with env.begin(write=True) as txn:
63
- txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
64
-
65
-
66
- if __name__ == "__main__":
67
- parser = argparse.ArgumentParser(description="Preprocess images for model training")
68
- parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
69
- parser.add_argument(
70
- "--size",
71
- type=str,
72
- default="128,256,512,1024",
73
- help="resolutions of images for the dataset",
74
- )
75
- parser.add_argument(
76
- "--n_worker",
77
- type=int,
78
- default=8,
79
- help="number of workers for preparing dataset",
80
- )
81
- parser.add_argument(
82
- "--resample",
83
- type=str,
84
- default="lanczos",
85
- help="resampling methods for resizing images",
86
- )
87
- parser.add_argument("path", type=str, help="path to the image dataset")
88
-
89
- args = parser.parse_args()
90
-
91
- resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
92
- resample = resample_map[args.resample]
93
-
94
- sizes = [int(s.strip()) for s in args.size.split(",")]
95
-
96
- print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
97
-
98
- imgset = datasets.ImageFolder(args.path)
99
-
100
- with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
101
- prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
 
projector.py DELETED
@@ -1,248 +0,0 @@
1
- import argparse
2
- import math
3
- import os
4
-
5
- import torch
6
- from torch import optim
7
- from torch.nn import functional as F
8
- from torchvision import transforms
9
- from PIL import Image
10
- from tqdm import tqdm
11
-
12
- import lpips
13
- from model import Generator
14
-
15
-
16
- def noise_regularize(noises):
17
- loss = 0
18
-
19
- for noise in noises:
20
- size = noise.shape[2]
21
-
22
- while True:
23
- loss = (
24
- loss
25
- + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
26
- + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
27
- )
28
-
29
- if size <= 8:
30
- break
31
-
32
- noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
33
- noise = noise.mean([3, 5])
34
- size //= 2
35
-
36
- return loss
37
-
38
-
39
- def noise_normalize_(noises):
40
- for noise in noises:
41
- mean = noise.mean()
42
- std = noise.std()
43
-
44
- noise.data.add_(-mean).div_(std)
45
-
46
-
47
- def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
48
- lr_ramp = min(1, (1 - t) / rampdown)
49
- lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
50
- lr_ramp = lr_ramp * min(1, t / rampup)
51
-
52
- return initial_lr * lr_ramp
53
-
54
-
55
- def latent_noise(latent, strength):
56
- noise = torch.randn_like(latent) * strength
57
-
58
- return latent + noise
59
-
60
-
61
- def make_image(tensor):
62
- return (
63
- tensor.detach()
64
- .clamp_(min=-1, max=1)
65
- .add(1)
66
- .div_(2)
67
- .mul(255)
68
- .type(torch.uint8)
69
- .permute(0, 2, 3, 1)
70
- .to("cpu")
71
- .numpy()
72
- )
73
-
74
-
75
- if __name__ == "__main__":
76
- device = "cuda"
77
-
78
- parser = argparse.ArgumentParser(
79
- description="Image projector to the generator latent spaces"
80
- )
81
- parser.add_argument(
82
- "--ckpt", type=str, required=True, help="path to the model checkpoint"
83
- )
84
- parser.add_argument(
85
- "--size", type=int, default=256, help="output image sizes of the generator"
86
- )
87
- parser.add_argument(
88
- "--lr_rampup",
89
- type=float,
90
- default=0.05,
91
- help="duration of the learning rate warmup",
92
- )
93
- parser.add_argument(
94
- "--lr_rampdown",
95
- type=float,
96
- default=0.25,
97
- help="duration of the learning rate decay",
98
- )
99
- parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
100
- parser.add_argument(
101
- "--noise", type=float, default=0.05, help="strength of the noise level"
102
- )
103
- parser.add_argument(
104
- "--noise_ramp",
105
- type=float,
106
- default=0.75,
107
- help="duration of the noise level decay",
108
- )
109
- parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
110
- parser.add_argument(
111
- "--noise_regularize",
112
- type=float,
113
- default=1e5,
114
- help="weight of the noise regularization",
115
- )
116
- parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
117
- parser.add_argument(
118
- "--w_plus",
119
- action="store_true",
120
- help="allow using distinct latent codes for each layer",
121
- )
122
- parser.add_argument(
123
- "files", metavar="FILES", nargs="+", help="path to image files to be projected"
124
- )
125
-
126
- args = parser.parse_args()
127
-
128
- n_mean_latent = 10000
129
-
130
- resize = min(args.size, 256)
131
-
132
- transform = transforms.Compose(
133
- [
134
- transforms.Resize(resize),
135
- transforms.CenterCrop(resize),
136
- transforms.ToTensor(),
137
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
138
- ]
139
- )
140
-
141
- imgs = []
142
-
143
- for imgfile in args.files:
144
- img = transform(Image.open(imgfile).convert("RGB"))
145
- imgs.append(img)
146
-
147
- imgs = torch.stack(imgs, 0).to(device)
148
-
149
- g_ema = Generator(args.size, 512, 8)
150
- g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
151
- g_ema.eval()
152
- g_ema = g_ema.to(device)
153
-
154
- with torch.no_grad():
155
- noise_sample = torch.randn(n_mean_latent, 512, device=device)
156
- latent_out = g_ema.style(noise_sample)
157
-
158
- latent_mean = latent_out.mean(0)
159
- latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
160
-
161
- percept = lpips.PerceptualLoss(
162
- model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
163
- )
164
-
165
- noises_single = g_ema.make_noise()
166
- noises = []
167
- for noise in noises_single:
168
- noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
169
-
170
- latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
171
-
172
- if args.w_plus:
173
- latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
174
-
175
- latent_in.requires_grad = True
176
-
177
- for noise in noises:
178
- noise.requires_grad = True
179
-
180
- optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
181
-
182
- pbar = tqdm(range(args.step))
183
- latent_path = []
184
-
185
- for i in pbar:
186
- t = i / args.step
187
- lr = get_lr(t, args.lr)
188
- optimizer.param_groups[0]["lr"] = lr
189
- noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
190
- latent_n = latent_noise(latent_in, noise_strength.item())
191
-
192
- img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)
193
-
194
- batch, channel, height, width = img_gen.shape
195
-
196
- if height > 256:
197
- factor = height // 256
198
-
199
- img_gen = img_gen.reshape(
200
- batch, channel, height // factor, factor, width // factor, factor
201
- )
202
- img_gen = img_gen.mean([3, 5])
203
-
204
- p_loss = percept(img_gen, imgs).sum()
205
- n_loss = noise_regularize(noises)
206
- mse_loss = F.mse_loss(img_gen, imgs)
207
-
208
- loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
209
-
210
- optimizer.zero_grad()
211
- loss.backward()
212
- optimizer.step()
213
-
214
- noise_normalize_(noises)
215
-
216
- if (i + 1) % 100 == 0:
217
- latent_path.append(latent_in.detach().clone())
218
-
219
- pbar.set_description(
220
- (
221
- f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
222
- f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
223
- )
224
- )
225
-
226
- img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)
227
-
228
- filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"
229
-
230
- img_ar = make_image(img_gen)
231
-
232
- result_file = {}
233
- for i, input_name in enumerate(args.files):
234
- noise_single = []
235
- for noise in noises:
236
- noise_single.append(noise[i : i + 1])
237
-
238
- result_file[input_name] = {
239
- "img": img_gen[i],
240
- "latent": latent_in[i],
241
- "noise": noise_single,
242
- }
243
-
244
- img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
245
- pil_img = Image.fromarray(img_ar[i])
246
- pil_img.save(img_name)
247
-
248
- torch.save(result_file, filename)
 
sample/.gitignore DELETED
@@ -1 +0,0 @@
1
- *.png
 
 
swagan.py DELETED
@@ -1,440 +0,0 @@
1
- import math
2
- import random
3
- import functools
4
- import operator
5
-
6
- import torch
7
- from torch import nn
8
- from torch.nn import functional as F
9
- from torch.autograd import Function
10
-
11
- from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
- from model import (
13
- ModulatedConv2d,
14
- StyledConv,
15
- ConstantInput,
16
- PixelNorm,
17
- Upsample,
18
- Downsample,
19
- Blur,
20
- EqualLinear,
21
- ConvLayer,
22
- )
23
-
24
-
25
- def get_haar_wavelet(in_channels):
26
- haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
27
- haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2)
28
- haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0]
29
-
30
- haar_wav_ll = haar_wav_l.T * haar_wav_l
31
- haar_wav_lh = haar_wav_h.T * haar_wav_l
32
- haar_wav_hl = haar_wav_l.T * haar_wav_h
33
- haar_wav_hh = haar_wav_h.T * haar_wav_h
34
-
35
- return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh
36
-
37
-
38
- def dwt_init(x):
39
- x01 = x[:, :, 0::2, :] / 2
40
- x02 = x[:, :, 1::2, :] / 2
41
- x1 = x01[:, :, :, 0::2]
42
- x2 = x02[:, :, :, 0::2]
43
- x3 = x01[:, :, :, 1::2]
44
- x4 = x02[:, :, :, 1::2]
45
- x_LL = x1 + x2 + x3 + x4
46
- x_HL = -x1 - x2 + x3 + x4
47
- x_LH = -x1 + x2 - x3 + x4
48
- x_HH = x1 - x2 - x3 + x4
49
-
50
- return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
51
-
52
-
53
- def iwt_init(x):
54
- r = 2
55
- in_batch, in_channel, in_height, in_width = x.size()
56
- # print([in_batch, in_channel, in_height, in_width])
57
- out_batch, out_channel, out_height, out_width = (
58
- in_batch,
59
- int(in_channel / (r ** 2)),
60
- r * in_height,
61
- r * in_width,
62
- )
63
- x1 = x[:, 0:out_channel, :, :] / 2
64
- x2 = x[:, out_channel : out_channel * 2, :, :] / 2
65
- x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
66
- x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2
67
-
68
- h = torch.zeros([out_batch, out_channel, out_height, out_width], dtype=x.dtype, device=x.device)
69
-
70
- h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
71
- h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
72
- h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
73
- h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
74
-
75
- return h
76
-
77
-
78
- class HaarTransform(nn.Module):
79
- def __init__(self, in_channels):
80
- super().__init__()
81
-
82
- ll, lh, hl, hh = get_haar_wavelet(in_channels)
83
-
84
- self.register_buffer("ll", ll)
85
- self.register_buffer("lh", lh)
86
- self.register_buffer("hl", hl)
87
- self.register_buffer("hh", hh)
88
-
89
- def forward(self, input):
90
- ll = upfirdn2d(input, self.ll, down=2)
91
- lh = upfirdn2d(input, self.lh, down=2)
92
- hl = upfirdn2d(input, self.hl, down=2)
93
- hh = upfirdn2d(input, self.hh, down=2)
94
-
95
- return torch.cat((ll, lh, hl, hh), 1)
96
-
97
-
98
- class InverseHaarTransform(nn.Module):
99
- def __init__(self, in_channels):
100
- super().__init__()
101
-
102
- ll, lh, hl, hh = get_haar_wavelet(in_channels)
103
-
104
- self.register_buffer("ll", ll)
105
- self.register_buffer("lh", -lh)
106
- self.register_buffer("hl", -hl)
107
- self.register_buffer("hh", hh)
108
-
109
- def forward(self, input):
110
- ll, lh, hl, hh = input.chunk(4, 1)
111
- ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
112
- lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
113
- hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
114
- hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))
115
-
116
- return ll + lh + hl + hh
117
-
118
-
119
- class ToRGB(nn.Module):
120
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
121
- super().__init__()
122
-
123
- if upsample:
124
- self.iwt = InverseHaarTransform(3)
125
- self.upsample = Upsample(blur_kernel)
126
- self.dwt = HaarTransform(3)
127
-
128
- self.conv = ModulatedConv2d(in_channel, 3 * 4, 1, style_dim, demodulate=False)
129
- self.bias = nn.Parameter(torch.zeros(1, 3 * 4, 1, 1))
130
-
131
- def forward(self, input, style, skip=None):
132
- out = self.conv(input, style)
133
- out = out + self.bias
134
-
135
- if skip is not None:
136
- skip = self.iwt(skip)
137
- skip = self.upsample(skip)
138
- skip = self.dwt(skip)
139
-
140
- out = out + skip
141
-
142
- return out
143
-
144
-
145
- class Generator(nn.Module):
146
- def __init__(
147
- self,
148
- size,
149
- style_dim,
150
- n_mlp,
151
- channel_multiplier=2,
152
- blur_kernel=[1, 3, 3, 1],
153
- lr_mlp=0.01,
154
- ):
155
- super().__init__()
156
-
157
- self.size = size
158
-
159
- self.style_dim = style_dim
160
-
161
- layers = [PixelNorm()]
162
-
163
- for i in range(n_mlp):
164
- layers.append(
165
- EqualLinear(
166
- style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
167
- )
168
- )
169
-
170
- self.style = nn.Sequential(*layers)
171
-
172
- self.channels = {
173
- 4: 512,
174
- 8: 512,
175
- 16: 512,
176
- 32: 512,
177
- 64: 256 * channel_multiplier,
178
- 128: 128 * channel_multiplier,
179
- 256: 64 * channel_multiplier,
180
- 512: 32 * channel_multiplier,
181
- 1024: 16 * channel_multiplier,
182
- }
183
-
184
- self.input = ConstantInput(self.channels[4])
185
- self.conv1 = StyledConv(
186
- self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
187
- )
188
- self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
189
-
190
- self.log_size = int(math.log(size, 2)) - 1
191
- self.num_layers = (self.log_size - 2) * 2 + 1
192
-
193
- self.convs = nn.ModuleList()
194
- self.upsamples = nn.ModuleList()
195
- self.to_rgbs = nn.ModuleList()
196
- self.noises = nn.Module()
197
-
198
- in_channel = self.channels[4]
199
-
200
- for layer_idx in range(self.num_layers):
201
- res = (layer_idx + 5) // 2
202
- shape = [1, 1, 2 ** res, 2 ** res]
203
- self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
204
-
205
- for i in range(3, self.log_size + 1):
206
- out_channel = self.channels[2 ** i]
207
-
208
- self.convs.append(
209
- StyledConv(
210
- in_channel,
211
- out_channel,
212
- 3,
213
- style_dim,
214
- upsample=True,
215
- blur_kernel=blur_kernel,
216
- )
217
- )
218
-
219
- self.convs.append(
220
- StyledConv(
221
- out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
222
- )
223
- )
224
-
225
- self.to_rgbs.append(ToRGB(out_channel, style_dim))
226
-
227
- in_channel = out_channel
228
-
229
- self.iwt = InverseHaarTransform(3)
230
-
231
- self.n_latent = self.log_size * 2 - 2
232
-
233
- def make_noise(self):
234
- device = self.input.input.device
235
-
236
- noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
237
-
238
- for i in range(3, self.log_size + 1):
239
- for _ in range(2):
240
- noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
241
-
242
- return noises
243
-
244
- def mean_latent(self, n_latent):
245
- latent_in = torch.randn(
246
- n_latent, self.style_dim, device=self.input.input.device
247
- )
248
- latent = self.style(latent_in).mean(0, keepdim=True)
249
-
250
- return latent
251
-
252
- def get_latent(self, input):
253
- return self.style(input)
254
-
255
- def forward(
256
- self,
257
- styles,
258
- return_latents=False,
259
- inject_index=None,
260
- truncation=1,
261
- truncation_latent=None,
262
- input_is_latent=False,
263
- noise=None,
264
- randomize_noise=True,
265
- ):
266
- if not input_is_latent:
267
- styles = [self.style(s) for s in styles]
268
-
269
- if noise is None:
270
- if randomize_noise:
271
- noise = [None] * self.num_layers
272
- else:
273
- noise = [
274
- getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
275
- ]
276
-
277
- if truncation < 1:
278
- style_t = []
279
-
280
- for style in styles:
281
- style_t.append(
282
- truncation_latent + truncation * (style - truncation_latent)
283
- )
284
-
285
- styles = style_t
286
-
287
- if len(styles) < 2:
288
- inject_index = self.n_latent
289
-
290
- if styles[0].ndim < 3:
291
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
292
-
293
- else:
294
- latent = styles[0]
295
-
296
- else:
297
- if inject_index is None:
298
- inject_index = random.randint(1, self.n_latent - 1)
299
-
300
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
301
- latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
302
-
303
- latent = torch.cat([latent, latent2], 1)
304
-
305
- out = self.input(latent)
306
- out = self.conv1(out, latent[:, 0], noise=noise[0])
307
-
308
- skip = self.to_rgb1(out, latent[:, 1])
309
-
310
- i = 1
311
- for conv1, conv2, noise1, noise2, to_rgb in zip(
312
- self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
313
- ):
314
- out = conv1(out, latent[:, i], noise=noise1)
315
- out = conv2(out, latent[:, i + 1], noise=noise2)
316
- skip = to_rgb(out, latent[:, i + 2], skip)
317
-
318
- i += 2
319
-
320
- image = self.iwt(skip)
321
-
322
- if return_latents:
323
- return image, latent
324
-
325
- else:
326
- return image, None
327
-
328
-
329
- class ConvBlock(nn.Module):
330
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
331
- super().__init__()
332
-
333
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
334
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
335
-
336
- def forward(self, input):
337
- out = self.conv1(input)
338
- out = self.conv2(out)
339
-
340
- return out
341
-
342
-
343
- class FromRGB(nn.Module):
344
- def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
345
- super().__init__()
346
-
347
- self.downsample = downsample
348
-
349
- if downsample:
350
- self.iwt = InverseHaarTransform(3)
351
- self.downsample = Downsample(blur_kernel)
352
- self.dwt = HaarTransform(3)
353
-
354
- self.conv = ConvLayer(3 * 4, out_channel, 3)
355
-
356
- def forward(self, input, skip=None):
357
- if self.downsample:
358
- input = self.iwt(input)
359
- input = self.downsample(input)
360
- input = self.dwt(input)
361
-
362
- out = self.conv(input)
363
-
364
- if skip is not None:
365
- out = out + skip
366
-
367
- return input, out
368
-
369
-
370
- class Discriminator(nn.Module):
371
- def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
372
- super().__init__()
373
-
374
- channels = {
375
- 4: 512,
376
- 8: 512,
377
- 16: 512,
378
- 32: 512,
379
- 64: 256 * channel_multiplier,
380
- 128: 128 * channel_multiplier,
381
- 256: 64 * channel_multiplier,
382
- 512: 32 * channel_multiplier,
383
- 1024: 16 * channel_multiplier,
384
- }
385
-
386
- self.dwt = HaarTransform(3)
387
-
388
- self.from_rgbs = nn.ModuleList()
389
- self.convs = nn.ModuleList()
390
-
391
- log_size = int(math.log(size, 2)) - 1
392
-
393
- in_channel = channels[size]
394
-
395
- for i in range(log_size, 2, -1):
396
- out_channel = channels[2 ** (i - 1)]
397
-
398
- self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size))
399
- self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel))
400
-
401
- in_channel = out_channel
402
-
403
- self.from_rgbs.append(FromRGB(channels[4]))
404
-
405
- self.stddev_group = 4
406
- self.stddev_feat = 1
407
-
408
- self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
409
- self.final_linear = nn.Sequential(
410
- EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
411
- EqualLinear(channels[4], 1),
412
- )
413
-
414
- def forward(self, input):
415
- input = self.dwt(input)
416
- out = None
417
-
418
- for from_rgb, conv in zip(self.from_rgbs, self.convs):
419
- input, out = from_rgb(input, out)
420
- out = conv(out)
421
-
422
- _, out = self.from_rgbs[-1](input, out)
423
-
424
- batch, channel, height, width = out.shape
425
- group = min(batch, self.stddev_group)
426
- stddev = out.view(
427
- group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
428
- )
429
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
430
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
431
- stddev = stddev.repeat(group, 1, height, width)
432
- out = torch.cat([out, stddev], 1)
433
-
434
- out = self.final_conv(out)
435
-
436
- out = out.view(batch, -1)
437
- out = self.final_linear(out)
438
-
439
- return out
440
-
 
train.py DELETED
@@ -1,531 +0,0 @@
1
- import argparse
2
- import math
3
- import random
4
- import os
5
-
6
- import numpy as np
7
- import torch
8
- from torch import nn, autograd, optim
9
- from torch.nn import functional as F
10
- from torch.utils import data
11
- import torch.distributed as dist
12
- from torchvision import transforms, utils
13
- from tqdm import tqdm
14
-
15
- try:
16
- import wandb
17
-
18
- except ImportError:
19
- wandb = None
20
-
21
-
22
- from dataset import MultiResolutionDataset
23
- from distributed import (
24
- get_rank,
25
- synchronize,
26
- reduce_loss_dict,
27
- reduce_sum,
28
- get_world_size,
29
- )
30
- from op import conv2d_gradfix
31
- from non_leaking import augment, AdaptiveAugment
32
-
33
-
34
- def data_sampler(dataset, shuffle, distributed):
35
- if distributed:
36
- return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
37
-
38
- if shuffle:
39
- return data.RandomSampler(dataset)
40
-
41
- else:
42
- return data.SequentialSampler(dataset)
43
-
44
-
45
- def requires_grad(model, flag=True):
46
- for p in model.parameters():
47
- p.requires_grad = flag
48
-
49
-
50
- def accumulate(model1, model2, decay=0.999):
51
- par1 = dict(model1.named_parameters())
52
- par2 = dict(model2.named_parameters())
53
-
54
- for k in par1.keys():
55
- par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
56
-
57
-
58
- def sample_data(loader):
59
- while True:
60
- for batch in loader:
61
- yield batch
62
-
63
-
64
- def d_logistic_loss(real_pred, fake_pred):
65
- real_loss = F.softplus(-real_pred)
66
- fake_loss = F.softplus(fake_pred)
67
-
68
- return real_loss.mean() + fake_loss.mean()
69
-
70
-
71
- def d_r1_loss(real_pred, real_img):
72
- with conv2d_gradfix.no_weight_gradients():
73
- grad_real, = autograd.grad(
74
- outputs=real_pred.sum(), inputs=real_img, create_graph=True
75
- )
76
- grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
77
-
78
- return grad_penalty
79
-
80
-
81
- def g_nonsaturating_loss(fake_pred):
82
- loss = F.softplus(-fake_pred).mean()
83
-
84
- return loss
85
-
86
-
87
- def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
88
- noise = torch.randn_like(fake_img) / math.sqrt(
89
- fake_img.shape[2] * fake_img.shape[3]
90
- )
91
- grad, = autograd.grad(
92
- outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
93
- )
94
- path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
95
-
96
- path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
97
-
98
- path_penalty = (path_lengths - path_mean).pow(2).mean()
99
-
100
- return path_penalty, path_mean.detach(), path_lengths
101
-
102
-
103
- def make_noise(batch, latent_dim, n_noise, device):
104
- if n_noise == 1:
105
- return torch.randn(batch, latent_dim, device=device)
106
-
107
- noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
108
-
109
- return noises
110
-
111
-
112
- def mixing_noise(batch, latent_dim, prob, device):
113
- if prob > 0 and random.random() < prob:
114
- return make_noise(batch, latent_dim, 2, device)
115
-
116
- else:
117
- return [make_noise(batch, latent_dim, 1, device)]
118
-
119
-
120
- def set_grad_none(model, targets):
121
- for n, p in model.named_parameters():
122
- if n in targets:
123
- p.grad = None
124
-
125
-
126
- def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
127
- loader = sample_data(loader)
128
-
129
- pbar = range(args.iter)
130
-
131
- if get_rank() == 0:
132
- pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
133
-
134
- mean_path_length = 0
135
-
136
- d_loss_val = 0
137
- r1_loss = torch.tensor(0.0, device=device)
138
- g_loss_val = 0
139
- path_loss = torch.tensor(0.0, device=device)
140
- path_lengths = torch.tensor(0.0, device=device)
141
- mean_path_length_avg = 0
142
- loss_dict = {}
143
-
144
- if args.distributed:
145
- g_module = generator.module
146
- d_module = discriminator.module
147
-
148
- else:
149
- g_module = generator
150
- d_module = discriminator
151
-
152
- accum = 0.5 ** (32 / (10 * 1000))
153
- ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
154
- r_t_stat = 0
155
-
156
- if args.augment and args.augment_p == 0:
157
- ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)
158
-
159
- sample_z = torch.randn(args.n_sample, args.latent, device=device)
160
-
161
- for idx in pbar:
162
- i = idx + args.start_iter
163
-
164
- if i > args.iter:
165
- print("Done!")
166
-
167
- break
168
-
169
- real_img = next(loader)
170
- real_img = real_img.to(device)
171
-
172
- requires_grad(generator, False)
173
- requires_grad(discriminator, True)
174
-
175
- noise = mixing_noise(args.batch, args.latent, args.mixing, device)
176
- fake_img, _ = generator(noise)
177
-
178
- if args.augment:
179
- real_img_aug, _ = augment(real_img, ada_aug_p)
180
- fake_img, _ = augment(fake_img, ada_aug_p)
181
-
182
- else:
183
- real_img_aug = real_img
184
-
185
- fake_pred = discriminator(fake_img)
186
- real_pred = discriminator(real_img_aug)
187
- d_loss = d_logistic_loss(real_pred, fake_pred)
188
-
189
- loss_dict["d"] = d_loss
190
- loss_dict["real_score"] = real_pred.mean()
191
- loss_dict["fake_score"] = fake_pred.mean()
192
-
193
- discriminator.zero_grad()
194
- d_loss.backward()
195
- d_optim.step()
196
-
197
- if args.augment and args.augment_p == 0:
198
- ada_aug_p = ada_augment.tune(real_pred)
199
- r_t_stat = ada_augment.r_t_stat
200
-
201
- d_regularize = i % args.d_reg_every == 0
202
-
203
- if d_regularize:
204
- real_img.requires_grad = True
205
-
206
- if args.augment:
207
- real_img_aug, _ = augment(real_img, ada_aug_p)
208
-
209
- else:
210
- real_img_aug = real_img
211
-
212
- real_pred = discriminator(real_img_aug)
213
- r1_loss = d_r1_loss(real_pred, real_img)
214
-
215
- discriminator.zero_grad()
216
- (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
217
-
218
- d_optim.step()
219
-
220
- loss_dict["r1"] = r1_loss
221
-
222
- requires_grad(generator, True)
223
- requires_grad(discriminator, False)
224
-
225
- noise = mixing_noise(args.batch, args.latent, args.mixing, device)
226
- fake_img, _ = generator(noise)
227
-
228
- if args.augment:
229
- fake_img, _ = augment(fake_img, ada_aug_p)
230
-
231
- fake_pred = discriminator(fake_img)
232
- g_loss = g_nonsaturating_loss(fake_pred)
233
-
234
- loss_dict["g"] = g_loss
235
-
236
- generator.zero_grad()
237
- g_loss.backward()
238
- g_optim.step()
239
-
240
- g_regularize = i % args.g_reg_every == 0
241
-
242
- if g_regularize:
243
- path_batch_size = max(1, args.batch // args.path_batch_shrink)
244
- noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
245
- fake_img, latents = generator(noise, return_latents=True)
246
-
247
- path_loss, mean_path_length, path_lengths = g_path_regularize(
248
- fake_img, latents, mean_path_length
249
- )
250
-
251
- generator.zero_grad()
252
- weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
253
-
254
- if args.path_batch_shrink:
255
- weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
256
-
257
- weighted_path_loss.backward()
258
-
259
- g_optim.step()
260
-
261
- mean_path_length_avg = (
262
- reduce_sum(mean_path_length).item() / get_world_size()
263
- )
264
-
265
- loss_dict["path"] = path_loss
266
- loss_dict["path_length"] = path_lengths.mean()
267
-
268
- accumulate(g_ema, g_module, accum)
269
-
270
- loss_reduced = reduce_loss_dict(loss_dict)
271
-
272
- d_loss_val = loss_reduced["d"].mean().item()
273
- g_loss_val = loss_reduced["g"].mean().item()
274
- r1_val = loss_reduced["r1"].mean().item()
275
- path_loss_val = loss_reduced["path"].mean().item()
276
- real_score_val = loss_reduced["real_score"].mean().item()
277
- fake_score_val = loss_reduced["fake_score"].mean().item()
278
- path_length_val = loss_reduced["path_length"].mean().item()
279
-
280
- if get_rank() == 0:
281
- pbar.set_description(
282
- (
283
- f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
284
- f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
285
- f"augment: {ada_aug_p:.4f}"
286
- )
287
- )
288
-
289
- if wandb and args.wandb:
290
- wandb.log(
291
- {
292
- "Generator": g_loss_val,
293
- "Discriminator": d_loss_val,
294
- "Augment": ada_aug_p,
295
- "Rt": r_t_stat,
296
- "R1": r1_val,
297
- "Path Length Regularization": path_loss_val,
298
- "Mean Path Length": mean_path_length,
299
- "Real Score": real_score_val,
300
- "Fake Score": fake_score_val,
301
- "Path Length": path_length_val,
302
- }
303
- )
304
-
305
- if i % 100 == 0:
306
- with torch.no_grad():
307
- g_ema.eval()
308
- sample, _ = g_ema([sample_z])
309
- utils.save_image(
310
- sample,
311
- f"sample/{str(i).zfill(6)}.png",
312
- nrow=int(args.n_sample ** 0.5),
313
- normalize=True,
314
- range=(-1, 1),
315
- )
316
-
317
- if i % 10000 == 0:
318
- torch.save(
319
- {
320
- "g": g_module.state_dict(),
321
- "d": d_module.state_dict(),
322
- "g_ema": g_ema.state_dict(),
323
- "g_optim": g_optim.state_dict(),
324
- "d_optim": d_optim.state_dict(),
325
- "args": args,
326
- "ada_aug_p": ada_aug_p,
327
- },
328
- f"checkpoint/{str(i).zfill(6)}.pt",
329
- )
330
-
331
-
332
- if __name__ == "__main__":
333
- device = "cuda"
334
-
335
- parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
336
-
337
- parser.add_argument("path", type=str, help="path to the lmdb dataset")
338
- parser.add_argument('--arch', type=str, default='stylegan2', choices=['stylegan2', 'swagan'], help='model architecture (stylegan2 | swagan)')
339
- parser.add_argument(
340
- "--iter", type=int, default=800000, help="total training iterations"
341
- )
342
- parser.add_argument(
343
- "--batch", type=int, default=16, help="batch size for each GPU"
344
- )
345
- parser.add_argument(
346
- "--n_sample",
347
- type=int,
348
- default=64,
349
- help="number of the samples generated during training",
350
- )
351
- parser.add_argument(
352
- "--size", type=int, default=256, help="image sizes for the model"
353
- )
354
- parser.add_argument(
355
- "--r1", type=float, default=10, help="weight of the r1 regularization"
356
- )
357
- parser.add_argument(
358
- "--path_regularize",
359
- type=float,
360
- default=2,
361
- help="weight of the path length regularization",
362
- )
363
- parser.add_argument(
364
- "--path_batch_shrink",
365
- type=int,
366
- default=2,
367
- help="batch size reduction factor for the path length regularization (reduces memory consumption)",
368
- )
369
- parser.add_argument(
370
- "--d_reg_every",
371
- type=int,
372
- default=16,
373
- help="interval of applying r1 regularization",
374
- )
375
- parser.add_argument(
376
- "--g_reg_every",
377
- type=int,
378
- default=4,
379
- help="interval of applying path length regularization",
380
- )
381
- parser.add_argument(
382
- "--mixing", type=float, default=0.9, help="probability of latent code mixing"
383
- )
384
- parser.add_argument(
385
- "--ckpt",
386
- type=str,
387
- default=None,
388
- help="path to the checkpoints to resume training",
389
- )
390
- parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
391
- parser.add_argument(
392
- "--channel_multiplier",
393
- type=int,
394
- default=2,
395
- help="channel multiplier factor for the model. config-f = 2, else = 1",
396
- )
397
- parser.add_argument(
398
- "--wandb", action="store_true", help="use weights and biases logging"
399
- )
400
- parser.add_argument(
401
- "--local_rank", type=int, default=0, help="local rank for distributed training"
402
- )
403
- parser.add_argument(
404
- "--augment", action="store_true", help="apply non leaking augmentation"
405
- )
406
- parser.add_argument(
407
- "--augment_p",
408
- type=float,
409
- default=0,
410
- help="probability of applying augmentation. 0 = use adaptive augmentation",
411
- )
412
- parser.add_argument(
413
- "--ada_target",
414
- type=float,
415
- default=0.6,
416
- help="target augmentation probability for adaptive augmentation",
417
- )
418
- parser.add_argument(
419
- "--ada_length",
420
- type=int,
421
- default=500 * 1000,
422
- help="target duration to reach the augmentation probability for adaptive augmentation",
423
- )
424
- parser.add_argument(
425
- "--ada_every",
426
- type=int,
427
- default=256,
428
- help="probability update interval of the adaptive augmentation",
429
- )
430
-
431
- args = parser.parse_args()
432
-
433
- n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
434
- args.distributed = n_gpu > 1
435
-
436
- if args.distributed:
437
- torch.cuda.set_device(args.local_rank)
438
- torch.distributed.init_process_group(backend="nccl", init_method="env://")
439
- synchronize()
440
-
441
- args.latent = 512
442
- args.n_mlp = 8
443
-
444
- args.start_iter = 0
445
-
446
- if args.arch == 'stylegan2':
447
- from model import Generator, Discriminator
448
-
449
- elif args.arch == 'swagan':
450
- from swagan import Generator, Discriminator
451
-
452
- generator = Generator(
453
- args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
454
- ).to(device)
455
- discriminator = Discriminator(
456
- args.size, channel_multiplier=args.channel_multiplier
457
- ).to(device)
458
- g_ema = Generator(
459
- args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
460
- ).to(device)
461
- g_ema.eval()
462
- accumulate(g_ema, generator, 0)
463
-
464
- g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
465
- d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
466
-
467
- g_optim = optim.Adam(
468
- generator.parameters(),
469
- lr=args.lr * g_reg_ratio,
470
- betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
471
- )
472
- d_optim = optim.Adam(
473
- discriminator.parameters(),
474
- lr=args.lr * d_reg_ratio,
475
- betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
476
- )
477
-
478
- if args.ckpt is not None:
479
- print("load model:", args.ckpt)
480
-
481
- ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
482
-
483
- try:
484
- ckpt_name = os.path.basename(args.ckpt)
485
- args.start_iter = int(os.path.splitext(ckpt_name)[0])
486
-
487
- except ValueError:
488
- pass
489
-
490
- generator.load_state_dict(ckpt["g"])
491
- discriminator.load_state_dict(ckpt["d"])
492
- g_ema.load_state_dict(ckpt["g_ema"])
493
-
494
- g_optim.load_state_dict(ckpt["g_optim"])
495
- d_optim.load_state_dict(ckpt["d_optim"])
496
-
497
- if args.distributed:
498
- generator = nn.parallel.DistributedDataParallel(
499
- generator,
500
- device_ids=[args.local_rank],
501
- output_device=args.local_rank,
502
- broadcast_buffers=False,
503
- )
504
-
505
- discriminator = nn.parallel.DistributedDataParallel(
506
- discriminator,
507
- device_ids=[args.local_rank],
508
- output_device=args.local_rank,
509
- broadcast_buffers=False,
510
- )
511
-
512
- transform = transforms.Compose(
513
- [
514
- transforms.RandomHorizontalFlip(),
515
- transforms.ToTensor(),
516
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
517
- ]
518
- )
519
-
520
- dataset = MultiResolutionDataset(args.path, transform, args.size)
521
- loader = data.DataLoader(
522
- dataset,
523
- batch_size=args.batch,
524
- sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
525
- drop_last=True,
526
- )
527
-
528
- if get_rank() == 0 and wandb is not None and args.wandb:
529
- wandb.init(project="stylegan 2")
530
-
531
- train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)