Heekyung committed
Commit dad8cec · 1 Parent(s): e975653

Upload 9 files
LICENSE.txt ADDED
@@ -0,0 +1,97 @@
+ Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
+
+
+ NVIDIA Source Code License for StyleGAN3
+
+
+ =======================================================================
+
+ 1. Definitions
+
+ "Licensor" means any person or entity that distributes its Work.
+
+ "Software" means the original work of authorship made available under
+ this License.
+
+ "Work" means the Software and any additions to or derivative works of
+ the Software that are made available under this License.
+
+ The terms "reproduce," "reproduction," "derivative works," and
+ "distribution" have the meaning as provided under U.S. copyright law;
+ provided, however, that for the purposes of this License, derivative
+ works shall not include works that remain separable from, or merely
+ link (or bind by name) to the interfaces of, the Work.
+
+ Works, including the Software, are "made available" under this License
+ by including in or with the Work either (a) a copyright notice
+ referencing the applicability of this License to the Work, or (b) a
+ copy of this License.
+
+ 2. License Grants
+
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
+ License, each Licensor grants to you a perpetual, worldwide,
+ non-exclusive, royalty-free, copyright license to reproduce,
+ prepare derivative works of, publicly display, publicly perform,
+ sublicense and distribute its Work and any resulting derivative
+ works in any form.
+
+ 3. Limitations
+
+ 3.1 Redistribution. You may reproduce or distribute the Work only
+ if (a) you do so under this License, (b) you include a complete
+ copy of this License with your distribution, and (c) you retain
+ without modification any copyright, patent, trademark, or
+ attribution notices that are present in the Work.
+
+ 3.2 Derivative Works. You may specify that additional or different
+ terms apply to the use, reproduction, and distribution of your
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
+ provide that the use limitation in Section 3.3 applies to your
+ derivative works, and (b) you identify the specific derivative
+ works that are subject to Your Terms. Notwithstanding Your Terms,
+ this License (including the redistribution requirements in Section
+ 3.1) will continue to apply to the Work itself.
+
+ 3.3 Use Limitation. The Work and any derivative works thereof only
+ may be used or intended for use non-commercially. Notwithstanding
+ the foregoing, NVIDIA and its affiliates may use the Work and any
+ derivative works commercially. As used herein, "non-commercially"
+ means for research or evaluation purposes only.
+
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+ against any Licensor (including any claim, cross-claim or
+ counterclaim in a lawsuit) to enforce any patents that you allege
+ are infringed by any Work, then your rights under this License from
+ such Licensor (including the grant in Section 2.1) will terminate
+ immediately.
+
+ 3.5 Trademarks. This License does not grant any rights to use any
+ Licensor's or its affiliates' names, logos, or trademarks, except
+ as necessary to reproduce the notices described in this License.
+
+ 3.6 Termination. If you violate any term of this License, then your
+ rights under this License (including the grant in Section 2.1) will
+ terminate immediately.
+
+ 4. Disclaimer of Warranty.
+
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+ THIS LICENSE.
+
+ 5. Limitation of Liability.
+
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+ THE POSSIBILITY OF SUCH DAMAGES.
+
+ =======================================================================
arial.ttf ADDED
Binary file (276 kB)
environment.yml ADDED
@@ -0,0 +1,24 @@
+ name: stylegan3
+ channels:
+   - pytorch
+   - nvidia
+ dependencies:
+   - python >= 3.8
+   - pip
+   - numpy>=1.20
+   - click>=8.0
+   - pillow=8.3.1
+   - scipy=1.7.1
+   - pytorch=1.9.1
+   - cudatoolkit=11.1
+   - requests=2.26.0
+   - tqdm=4.62.2
+   - ninja=1.10.2
+   - matplotlib=3.4.2
+   - imageio=2.9.0
+   - pip:
+     - imgui==1.3.0
+     - glfw==2.2.0
+     - pyopengl==3.1.5
+     - imageio-ffmpeg==0.4.3
+     - pyspng
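The environment above can be instantiated in the usual way; a minimal sketch, assuming a working Conda installation and a CUDA 11.1-capable driver:

    conda env create -f environment.yml
    conda activate stylegan3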
gen_images.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Generate images using pretrained network pickle."""
+
+ import os
+ import re
+ from typing import List, Optional, Tuple, Union
+
+ import click
+ import dnnlib
+ import numpy as np
+ import PIL.Image
+ import torch
+
+ import legacy
+
+ #----------------------------------------------------------------------------
+
+ def parse_range(s: Union[str, List]) -> List[int]:
+     '''Parse a comma-separated list of numbers or ranges and return a list of ints.
+
+     Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
+     '''
+     if isinstance(s, list): return s
+     ranges = []
+     range_re = re.compile(r'^(\d+)-(\d+)$')
+     for p in s.split(','):
+         m = range_re.match(p)
+         if m:
+             ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+         else:
+             ranges.append(int(p))
+     return ranges
+
+ #----------------------------------------------------------------------------
+
+ def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
+     '''Parse a floating-point 2-vector of syntax 'a,b'.
+
+     Example:
+         '0,1' returns (0,1)
+     '''
+     if isinstance(s, tuple): return s
+     parts = s.split(',')
+     if len(parts) == 2:
+         return (float(parts[0]), float(parts[1]))
+     raise ValueError(f'cannot parse 2-vector {s}')
+
+ #----------------------------------------------------------------------------
+
+ def make_transform(translate: Tuple[float,float], angle: float):
+     # 2D rotation + translation as a 3x3 homogeneous matrix; angle is in degrees.
+     m = np.eye(3)
+     s = np.sin(angle/360.0*np.pi*2)
+     c = np.cos(angle/360.0*np.pi*2)
+     m[0][0] = c
+     m[0][1] = s
+     m[0][2] = translate[0]
+     m[1][0] = -s
+     m[1][1] = c
+     m[1][2] = translate[1]
+     return m
+
+ #----------------------------------------------------------------------------
+
+ @click.command()
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+ @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+ @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
+ @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+ def generate_images(
+     network_pkl: str,
+     seeds: List[int],
+     truncation_psi: float,
+     noise_mode: str,
+     outdir: str,
+     translate: Tuple[float,float],
+     rotate: float,
+     class_idx: Optional[int]
+ ):
+     """Generate images using pretrained network pickle.
+
+     Examples:
+
+     \b
+     # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
+     python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
+         --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+     \b
+     # Generate uncurated images with truncation using the MetFaces-U dataset
+     python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
+         --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
+     """
+
+     print('Loading networks from "%s"...' % network_pkl)
+     device = torch.device('cuda')
+     with dnnlib.util.open_url(network_pkl) as f:
+         G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+         # import pickle
+         # G = legacy.load_network_pkl(f)
+         # output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
+         # pickle.dump(G, output)
+
+     os.makedirs(outdir, exist_ok=True)
+
+     # Labels.
+     label = torch.zeros([1, G.c_dim], device=device)
+     if G.c_dim != 0:
+         if class_idx is None:
+             raise click.ClickException('Must specify class label with --class when using a conditional network')
+         label[:, class_idx] = 1
+     else:
+         if class_idx is not None:
+             print('warn: --class=lbl ignored when running on an unconditional network')
+
+     # Generate images.
+     for seed_idx, seed in enumerate(seeds):
+         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+         z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+
+         # Construct an inverse rotation/translation matrix and pass to the generator. The
+         # generator expects this matrix as an inverse to avoid potentially failing numerical
+         # operations in the network.
+         if hasattr(G.synthesis, 'input'):
+             m = make_transform(translate, rotate)
+             m = np.linalg.inv(m)
+             G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+         img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
+         img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+         PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     generate_images() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
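The same generation path can also be driven from Python rather than the CLI. This minimal sketch mirrors the body of generate_images() above for a single unconditional seed; it assumes the repository's dnnlib and legacy modules are importable and a CUDA device is available:

    import numpy as np
    import PIL.Image
    import torch

    import dnnlib
    import legacy

    # Network URL taken from the docstring example above.
    network_pkl = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl'
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

    z = torch.from_numpy(np.random.RandomState(2).randn(1, G.z_dim)).to(device)
    label = torch.zeros([1, G.c_dim], device=device)  # unconditional
    img = G(z, label, truncation_psi=1.0, noise_mode='const')
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('seed0002.png')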
legacy.py ADDED
@@ -0,0 +1,323 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Converting legacy network pickle into the new format."""
+
+ import click
+ import pickle
+ import re
+ import copy
+ import numpy as np
+ import torch
+ import dnnlib
+ from torch_utils import misc
+
+ #----------------------------------------------------------------------------
+
+ def load_network_pkl(f, force_fp16=False):
+     data = _LegacyUnpickler(f).load()
+
+     # Legacy TensorFlow pickle => convert.
+     if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+         tf_G, tf_D, tf_Gs = data
+         G = convert_tf_generator(tf_G)
+         D = convert_tf_discriminator(tf_D)
+         G_ema = convert_tf_generator(tf_Gs)
+         data = dict(G=G, D=D, G_ema=G_ema)
+
+     # Add missing fields.
+     if 'training_set_kwargs' not in data:
+         data['training_set_kwargs'] = None
+     if 'augment_pipe' not in data:
+         data['augment_pipe'] = None
+
+     # Validate contents.
+     assert isinstance(data['G'], torch.nn.Module)
+     assert isinstance(data['D'], torch.nn.Module)
+     assert isinstance(data['G_ema'], torch.nn.Module)
+     assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+     assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+     # Force FP16.
+     if force_fp16:
+         for key in ['G', 'D', 'G_ema']:
+             old = data[key]
+             kwargs = copy.deepcopy(old.init_kwargs)
+             fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+             fp16_kwargs.num_fp16_res = 4
+             fp16_kwargs.conv_clamp = 256
+             if kwargs != old.init_kwargs:
+                 new = type(old)(**kwargs).eval().requires_grad_(False)
+                 misc.copy_params_and_buffers(old, new, require_all=True)
+                 data[key] = new
+     return data
+
+ #----------------------------------------------------------------------------
+
+ class _TFNetworkStub(dnnlib.EasyDict):
+     pass
+
+ class _LegacyUnpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         if module == 'dnnlib.tflib.network' and name == 'Network':
+             return _TFNetworkStub
+         return super().find_class(module, name)
+
+ #----------------------------------------------------------------------------
+
+ def _collect_tf_params(tf_net):
+     # pylint: disable=protected-access
+     tf_params = dict()
+     def recurse(prefix, tf_net):
+         for name, value in tf_net.variables:
+             tf_params[prefix + name] = value
+         for name, comp in tf_net.components.items():
+             recurse(prefix + name + '/', comp)
+     recurse('', tf_net)
+     return tf_params
+
+ #----------------------------------------------------------------------------
+
+ def _populate_module_params(module, *patterns):
+     for name, tensor in misc.named_params_and_buffers(module):
+         found = False
+         value = None
+         for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+             match = re.fullmatch(pattern, name)
+             if match:
+                 found = True
+                 if value_fn is not None:
+                     value = value_fn(*match.groups())
+                 break
+         try:
+             assert found
+             if value is not None:
+                 tensor.copy_(torch.from_numpy(np.array(value)))
+         except:
+             print(name, list(tensor.shape))
+             raise
+
+ #----------------------------------------------------------------------------
+
+ def convert_tf_generator(tf_G):
+     if tf_G.version < 4:
+         raise ValueError('TensorFlow pickle version too low')
+
+     # Collect kwargs.
+     tf_kwargs = tf_G.static_kwargs
+     known_kwargs = set()
+     def kwarg(tf_name, default=None, none=None):
+         known_kwargs.add(tf_name)
+         val = tf_kwargs.get(tf_name, default)
+         return val if val is not None else none
+
+     # Convert kwargs.
+     from training import networks_stylegan2
+     network_class = networks_stylegan2.Generator
+     kwargs = dnnlib.EasyDict(
+         z_dim = kwarg('latent_size', 512),
+         c_dim = kwarg('label_size', 0),
+         w_dim = kwarg('dlatent_size', 512),
+         img_resolution = kwarg('resolution', 1024),
+         img_channels = kwarg('num_channels', 3),
+         channel_base = kwarg('fmap_base', 16384) * 2,
+         channel_max = kwarg('fmap_max', 512),
+         num_fp16_res = kwarg('num_fp16_res', 0),
+         conv_clamp = kwarg('conv_clamp', None),
+         architecture = kwarg('architecture', 'skip'),
+         resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+         use_noise = kwarg('use_noise', True),
+         activation = kwarg('nonlinearity', 'lrelu'),
+         mapping_kwargs = dnnlib.EasyDict(
+             num_layers = kwarg('mapping_layers', 8),
+             embed_features = kwarg('label_fmaps', None),
+             layer_features = kwarg('mapping_fmaps', None),
+             activation = kwarg('mapping_nonlinearity', 'lrelu'),
+             lr_multiplier = kwarg('mapping_lrmul', 0.01),
+             w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+         ),
+     )
+
+     # Check for unknown kwargs.
+     kwarg('truncation_psi')
+     kwarg('truncation_cutoff')
+     kwarg('style_mixing_prob')
+     kwarg('structure')
+     kwarg('conditioning')
+     kwarg('fused_modconv')
+     unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+     if len(unknown_kwargs) > 0:
+         raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+     # Collect params.
+     tf_params = _collect_tf_params(tf_G)
+     for name, value in list(tf_params.items()):
+         match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+         if match:
+             r = kwargs.img_resolution // (2 ** int(match.group(1)))
+             tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+             kwargs.synthesis.kwargs.architecture = 'orig'
+     #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+     # Convert params.
+     G = network_class(**kwargs).eval().requires_grad_(False)
+     # pylint: disable=unnecessary-lambda
+     # pylint: disable=f-string-without-interpolation
+     _populate_module_params(G,
+         r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+         r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+         r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+         r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+         r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+         r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+         r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+         r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+         r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+         r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+         r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+         r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+         r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+         r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+         r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+         r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+         r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+         r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+         r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+         r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+         r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+         r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+         r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+         r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+         r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+         r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+         r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+         r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+         r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+         r'.*\.resample_filter', None,
+         r'.*\.act_filter', None,
+     )
+     return G
+
+ #----------------------------------------------------------------------------
+
+ def convert_tf_discriminator(tf_D):
+     if tf_D.version < 4:
+         raise ValueError('TensorFlow pickle version too low')
+
+     # Collect kwargs.
+     tf_kwargs = tf_D.static_kwargs
+     known_kwargs = set()
+     def kwarg(tf_name, default=None):
+         known_kwargs.add(tf_name)
+         return tf_kwargs.get(tf_name, default)
+
+     # Convert kwargs.
+     kwargs = dnnlib.EasyDict(
+         c_dim = kwarg('label_size', 0),
+         img_resolution = kwarg('resolution', 1024),
+         img_channels = kwarg('num_channels', 3),
+         architecture = kwarg('architecture', 'resnet'),
+         channel_base = kwarg('fmap_base', 16384) * 2,
+         channel_max = kwarg('fmap_max', 512),
+         num_fp16_res = kwarg('num_fp16_res', 0),
+         conv_clamp = kwarg('conv_clamp', None),
+         cmap_dim = kwarg('mapping_fmaps', None),
+         block_kwargs = dnnlib.EasyDict(
+             activation = kwarg('nonlinearity', 'lrelu'),
+             resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+             freeze_layers = kwarg('freeze_layers', 0),
+         ),
+         mapping_kwargs = dnnlib.EasyDict(
+             num_layers = kwarg('mapping_layers', 0),
+             embed_features = kwarg('mapping_fmaps', None),
+             layer_features = kwarg('mapping_fmaps', None),
+             activation = kwarg('nonlinearity', 'lrelu'),
+             lr_multiplier = kwarg('mapping_lrmul', 0.1),
+         ),
+         epilogue_kwargs = dnnlib.EasyDict(
+             mbstd_group_size = kwarg('mbstd_group_size', None),
+             mbstd_num_channels = kwarg('mbstd_num_features', 1),
+             activation = kwarg('nonlinearity', 'lrelu'),
+         ),
+     )
+
+     # Check for unknown kwargs.
+     kwarg('structure')
+     kwarg('conditioning')
+     unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+     if len(unknown_kwargs) > 0:
+         raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+     # Collect params.
+     tf_params = _collect_tf_params(tf_D)
+     for name, value in list(tf_params.items()):
+         match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+         if match:
+             r = kwargs.img_resolution // (2 ** int(match.group(1)))
+             tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+             kwargs.architecture = 'orig'
+     #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+     # Convert params.
+     from training import networks_stylegan2
+     D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
+     # pylint: disable=unnecessary-lambda
+     # pylint: disable=f-string-without-interpolation
+     _populate_module_params(D,
+         r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+         r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+         r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+         r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+         r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+         r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+         r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+         r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+         r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+         r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+         r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+         r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+         r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+         r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+         r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+         r'.*\.resample_filter', None,
+     )
+     return D
+
+ #----------------------------------------------------------------------------
+
+ @click.command()
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+ def convert_network_pickle(source, dest, force_fp16):
+     """Convert legacy network pickle into the native PyTorch format.
+
+     The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+     It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+     Example:
+
+     \b
+     python legacy.py \\
+         --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+         --dest=stylegan2-cat-config-f.pkl
+     """
+     print(f'Loading "{source}"...')
+     with dnnlib.util.open_url(source) as f:
+         data = load_network_pkl(f, force_fp16=force_fp16)
+     print(f'Saving "{dest}"...')
+     with open(dest, 'wb') as f:
+         pickle.dump(data, f)
+     print('Done.')
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
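The conversion can also be invoked programmatically, following the same pattern as the CLI body above; a minimal sketch, with the pickle paths purely illustrative:

    import pickle

    import dnnlib
    import legacy

    # Load a legacy TensorFlow pickle (or pass a native one through),
    # optionally forcing FP16 layers as with the --force-fp16 flag above.
    with dnnlib.util.open_url('stylegan2-cat-config-f.pkl') as f:
        data = legacy.load_network_pkl(f, force_fp16=False)

    G_ema = data['G_ema']  # converted torch.nn.Module, ready for inference
    with open('converted.pkl', 'wb') as f:
        pickle.dump(data, f)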
pre-requirements.txt ADDED
@@ -0,0 +1 @@
+ pip
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchvision
+ Ninja
+ gradio
+ huggingface_hub
+ hf_transfer
+ Pillow==9.5.0
+ psutil
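The split into two files suggests the install order below (an assumption based on the naming: pre-requirements brings in pip itself before the main install):

    pip install -r pre-requirements.txt
    pip install -r requirements.txt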
visualizer_drag.py ADDED
@@ -0,0 +1,404 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import click
+ import os
+
+ import multiprocessing
+ import numpy as np
+ import torch
+ import imgui
+ import dnnlib
+ from gui_utils import imgui_window
+ from gui_utils import imgui_utils
+ from gui_utils import gl_utils
+ from gui_utils import text_utils
+ from viz import renderer
+ from viz import pickle_widget
+ from viz import latent_widget
+ from viz import drag_widget
+ from viz import capture_widget
+
+ #----------------------------------------------------------------------------
+
+ class Visualizer(imgui_window.ImguiWindow):
+     def __init__(self, capture_dir=None):
+         super().__init__(title='DragGAN', window_width=3840, window_height=2160)
+
+         # Internals.
+         self._last_error_print = None
+         self._async_renderer = AsyncRenderer()
+         self._defer_rendering = 0
+         self._tex_img = None
+         self._tex_obj = None
+         self._mask_obj = None
+         self._image_area = None
+         self._status = dnnlib.EasyDict()
+
+         # Widget interface.
+         self.args = dnnlib.EasyDict()
+         self.result = dnnlib.EasyDict()
+         self.pane_w = 0
+         self.label_w = 0
+         self.button_w = 0
+         self.image_w = 0
+         self.image_h = 0
+
+         # Widgets.
+         self.pickle_widget = pickle_widget.PickleWidget(self)
+         self.latent_widget = latent_widget.LatentWidget(self)
+         self.drag_widget = drag_widget.DragWidget(self)
+         self.capture_widget = capture_widget.CaptureWidget(self)
+
+         if capture_dir is not None:
+             self.capture_widget.path = capture_dir
+
+         # Initialize window.
+         self.set_position(0, 0)
+         self._adjust_font_size()
+         self.skip_frame()  # Layout may change after first frame.
+
+     def close(self):
+         super().close()
+         if self._async_renderer is not None:
+             self._async_renderer.close()
+             self._async_renderer = None
+
+     def add_recent_pickle(self, pkl, ignore_errors=False):
+         self.pickle_widget.add_recent(pkl, ignore_errors=ignore_errors)
+
+     def load_pickle(self, pkl, ignore_errors=False):
+         self.pickle_widget.load(pkl, ignore_errors=ignore_errors)
+
+     def print_error(self, error):
+         error = str(error)
+         if error != self._last_error_print:
+             print('\n' + error + '\n')
+             self._last_error_print = error
+
+     def defer_rendering(self, num_frames=1):
+         self._defer_rendering = max(self._defer_rendering, num_frames)
+
+     def clear_result(self):
+         self._async_renderer.clear_result()
+
+     def set_async(self, is_async):
+         if is_async != self._async_renderer.is_async:
+             self._async_renderer.set_async(is_async)
+             self.clear_result()
+             if 'image' in self.result:
+                 self.result.message = 'Switching rendering process...'
+             self.defer_rendering()
+
+     def _adjust_font_size(self):
+         old = self.font_size
+         self.set_font_size(min(self.content_width / 120, self.content_height / 60))
+         if self.font_size != old:
+             self.skip_frame()  # Layout changed.
+
+     def check_update_mask(self, **args):
+         update_mask = False
+         if 'pkl' in self._status:
+             if self._status.pkl != args['pkl']:
+                 update_mask = True
+         self._status.pkl = args['pkl']
+         if 'w0_seed' in self._status:
+             if self._status.w0_seed != args['w0_seed']:
+                 update_mask = True
+         self._status.w0_seed = args['w0_seed']
+         return update_mask
+
+     def capture_image_frame(self):
+         self.capture_next_frame()
+         captured_frame = self.pop_captured_frame()
+         captured_image = None
+         if captured_frame is not None:
+             x1, y1, w, h = self._image_area
+             captured_image = captured_frame[y1:y1+h, x1:x1+w, :]
+         return captured_image
+
+     def get_drag_info(self):
+         seed = self.latent_widget.seed
+         points = self.drag_widget.points
+         targets = self.drag_widget.targets
+         mask = self.drag_widget.mask
+         w = self._async_renderer._renderer_obj.w
+         return seed, points, targets, mask, w
+
+     def draw_frame(self):
+         self.begin_frame()
+         self.args = dnnlib.EasyDict()
+         self.pane_w = self.font_size * 18
+         self.button_w = self.font_size * 5
+         self.label_w = round(self.font_size * 4.5)
+
+         # Detect mouse dragging in the result area.
+         if self._image_area is not None:
+             if not hasattr(self.drag_widget, 'width'):
+                 self.drag_widget.init_mask(self.image_w, self.image_h)
+             clicked, down, img_x, img_y = imgui_utils.click_hidden_window(
+                 '##image_area', self._image_area[0], self._image_area[1], self._image_area[2], self._image_area[3], self.image_w, self.image_h)
+             self.drag_widget.action(clicked, down, img_x, img_y)
+
+         # Begin control pane.
+         imgui.set_next_window_position(0, 0)
+         imgui.set_next_window_size(self.pane_w, self.content_height)
+         imgui.begin('##control_pane', closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
+
+         # Widgets.
+         expanded, _visible = imgui_utils.collapsing_header('Network & latent', default=True)
+         self.pickle_widget(expanded)
+         self.latent_widget(expanded)
+         expanded, _visible = imgui_utils.collapsing_header('Drag', default=True)
+         self.drag_widget(expanded)
+         expanded, _visible = imgui_utils.collapsing_header('Capture', default=True)
+         self.capture_widget(expanded)
+
+         # Render.
+         if self.is_skipping_frames():
+             pass
+         elif self._defer_rendering > 0:
+             self._defer_rendering -= 1
+         elif self.args.pkl is not None:
+             self._async_renderer.set_args(**self.args)
+             result = self._async_renderer.get_result()
+             if result is not None:
+                 self.result = result
+                 if 'stop' in self.result and self.result.stop:
+                     self.drag_widget.stop_drag()
+                 if 'points' in self.result:
+                     self.drag_widget.set_points(self.result.points)
+                 if 'init_net' in self.result:
+                     if self.result.init_net:
+                         self.drag_widget.reset_point()
+
+         if self.check_update_mask(**self.args):
+             h, w, _ = self.result.image.shape
+             self.drag_widget.init_mask(w, h)
+
+         # Display.
+         max_w = self.content_width - self.pane_w
+         max_h = self.content_height
+         pos = np.array([self.pane_w + max_w / 2, max_h / 2])
+         if 'image' in self.result:
+             if self._tex_img is not self.result.image:
+                 self._tex_img = self.result.image
+                 if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img):
+                     self._tex_obj = gl_utils.Texture(image=self._tex_img, bilinear=False, mipmap=False)
+                 else:
+                     self._tex_obj.update(self._tex_img)
+             self.image_h, self.image_w = self._tex_obj.height, self._tex_obj.width
+             zoom = min(max_w / self._tex_obj.width, max_h / self._tex_obj.height)
+             zoom = np.floor(zoom) if zoom >= 1 else zoom
+             self._tex_obj.draw(pos=pos, zoom=zoom, align=0.5, rint=True)
+             if self.drag_widget.show_mask and hasattr(self.drag_widget, 'mask'):
+                 mask = ((1-self.drag_widget.mask.unsqueeze(-1)) * 255).to(torch.uint8)
+                 if self._mask_obj is None or not self._mask_obj.is_compatible(image=self._tex_img):
+                     self._mask_obj = gl_utils.Texture(image=mask, bilinear=False, mipmap=False)
+                 else:
+                     self._mask_obj.update(mask)
+                 self._mask_obj.draw(pos=pos, zoom=zoom, align=0.5, rint=True, alpha=0.15)
+
+             if self.drag_widget.mode in ['flexible', 'fixed']:
+                 posx, posy = imgui.get_mouse_pos()
+                 if posx >= self.pane_w:
+                     pos_c = np.array([posx, posy])
+                     gl_utils.draw_circle(center=pos_c, radius=self.drag_widget.r_mask * zoom, alpha=0.5)
+
+             rescale = self._tex_obj.width / 512 * zoom
+
+             for point in self.drag_widget.targets:
+                 pos_x = self.pane_w + max_w / 2 + (point[1] - self.image_w//2) * zoom
+                 pos_y = max_h / 2 + (point[0] - self.image_h//2) * zoom
+                 gl_utils.draw_circle(center=np.array([pos_x, pos_y]), color=[0,0,1], radius=9 * rescale)
+
+             for point in self.drag_widget.points:
+                 pos_x = self.pane_w + max_w / 2 + (point[1] - self.image_w//2) * zoom
+                 pos_y = max_h / 2 + (point[0] - self.image_h//2) * zoom
+                 gl_utils.draw_circle(center=np.array([pos_x, pos_y]), color=[1,0,0], radius=9 * rescale)
+
+             for point, target in zip(self.drag_widget.points, self.drag_widget.targets):
+                 t_x = self.pane_w + max_w / 2 + (target[1] - self.image_w//2) * zoom
+                 t_y = max_h / 2 + (target[0] - self.image_h//2) * zoom
+
+                 p_x = self.pane_w + max_w / 2 + (point[1] - self.image_w//2) * zoom
+                 p_y = max_h / 2 + (point[0] - self.image_h//2) * zoom
+
+                 gl_utils.draw_arrow(p_x, p_y, t_x, t_y, l=8 * rescale, width=3 * rescale)
+
+             imshow_w = int(self._tex_obj.width * zoom)
+             imshow_h = int(self._tex_obj.height * zoom)
+             self._image_area = [int(self.pane_w + max_w / 2 - imshow_w / 2), int(max_h / 2 - imshow_h / 2), imshow_w, imshow_h]
+         if 'error' in self.result:
+             self.print_error(self.result.error)
+             if 'message' not in self.result:
+                 self.result.message = str(self.result.error)
+         if 'message' in self.result:
+             tex = text_utils.get_texture(self.result.message, size=self.font_size, max_width=max_w, max_height=max_h, outline=2)
+             tex.draw(pos=pos, align=0.5, rint=True, color=1)
+
+         # End frame.
+         self._adjust_font_size()
+         imgui.end()
+         self.end_frame()
+
+ #----------------------------------------------------------------------------
+
+ class AsyncRenderer:
+     def __init__(self):
+         self._closed = False
+         self._is_async = False
+         self._cur_args = None
+         self._cur_result = None
+         self._cur_stamp = 0
+         self._renderer_obj = None
+         self._args_queue = None
+         self._result_queue = None
+         self._process = None
+
+     def close(self):
+         self._closed = True
+         self._renderer_obj = None
+         if self._process is not None:
+             self._process.terminate()
+         self._process = None
+         self._args_queue = None
+         self._result_queue = None
+
+     @property
+     def is_async(self):
+         return self._is_async
+
+     def set_async(self, is_async):
+         self._is_async = is_async
+
+     def set_args(self, **args):
+         assert not self._closed
+         args2 = args.copy()
+         args_mask = args2.pop('mask')
+         if self._cur_args:
+             _cur_args = self._cur_args.copy()
+             cur_args_mask = _cur_args.pop('mask')
+         else:
+             _cur_args = self._cur_args
+         # if args != self._cur_args:
+         if args2 != _cur_args:
+             if self._is_async:
+                 self._set_args_async(**args)
+             else:
+                 self._set_args_sync(**args)
+             self._cur_args = args
+
+     def _set_args_async(self, **args):
+         if self._process is None:
+             self._args_queue = multiprocessing.Queue()
+             self._result_queue = multiprocessing.Queue()
+             try:
+                 multiprocessing.set_start_method('spawn')
+             except RuntimeError:
+                 pass
+             self._process = multiprocessing.Process(target=self._process_fn, args=(self._args_queue, self._result_queue), daemon=True)
+             self._process.start()
+         self._args_queue.put([args, self._cur_stamp])
+
+     def _set_args_sync(self, **args):
+         if self._renderer_obj is None:
+             self._renderer_obj = renderer.Renderer()
+         self._cur_result = self._renderer_obj.render(**args)
+
+     def get_result(self):
+         assert not self._closed
+         if self._result_queue is not None:
+             while self._result_queue.qsize() > 0:
+                 result, stamp = self._result_queue.get()
+                 if stamp == self._cur_stamp:
+                     self._cur_result = result
+         return self._cur_result
+
+     def clear_result(self):
+         assert not self._closed
+         self._cur_args = None
+         self._cur_result = None
+         self._cur_stamp += 1
+
+     @staticmethod
+     def _process_fn(args_queue, result_queue):
+         renderer_obj = renderer.Renderer()
+         cur_args = None
+         cur_stamp = None
+         while True:
+             args, stamp = args_queue.get()
+             while args_queue.qsize() > 0:
+                 args, stamp = args_queue.get()
+             if args != cur_args or stamp != cur_stamp:
+                 result = renderer_obj.render(**args)
+                 if 'error' in result:
+                     result.error = renderer.CapturedException(result.error)
+                 result_queue.put([result, stamp])
+                 cur_args = args
+                 cur_stamp = stamp
+
+ #----------------------------------------------------------------------------
+
+ @click.command()
+ @click.argument('pkls', metavar='PATH', nargs=-1)
+ @click.option('--capture-dir', help='Where to save screenshot captures', metavar='PATH', default=None)
+ @click.option('--browse-dir', help='Specify model path for the \'Browse...\' button', metavar='PATH')
+ def main(
+     pkls,
+     capture_dir,
+     browse_dir
+ ):
+     """Interactive model visualizer.
+
+     Optional PATH argument can be used to specify which .pkl file to load.
+     """
+     viz = Visualizer(capture_dir=capture_dir)
+
+     if browse_dir is not None:
+         viz.pickle_widget.search_dirs = [browse_dir]
+
+     # List pickles.
+     if len(pkls) > 0:
+         for pkl in pkls:
+             viz.add_recent_pickle(pkl)
+         viz.load_pickle(pkls[0])
+     else:
+         pretrained = [
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqdog-512x512.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqv2-512x512.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqwild-512x512.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-brecahad-512x512.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-celebahq-256x256.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-cifar10-32x32.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-256x256.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-lsundog-256x256.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfaces-1024x1024.pkl',
+             'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfacesu-1024x1024.pkl'
+         ]
+
+         # Populate recent pickles list with pretrained model URLs.
+         for url in pretrained:
+             viz.add_recent_pickle(url)
+
+     # Run.
+     while not viz.should_close():
+         viz.draw_frame()
+     viz.close()
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     main()
+
+ #----------------------------------------------------------------------------
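Following the click interface defined in main() above, the visualizer takes zero or more pickle paths as positional arguments; for example (the checkpoint path is illustrative):

    python visualizer_drag.py checkpoints/stylegan_human_v2_512.pkl
    python visualizer_drag.py --browse-dir ./checkpoints

With no arguments, the recent-pickles list is pre-populated with the pretrained StyleGAN2 model URLs shown in the code.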
visualizer_drag_gradio.py ADDED
@@ -0,0 +1,940 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/DragGan/DragGan-Models
2
+ # https://arxiv.org/abs/2305.10973
3
+ import os
4
+ import os.path as osp
5
+ from argparse import ArgumentParser
6
+ from functools import partial
7
+ from pathlib import Path
8
+ import time
9
+
10
+ import psutil
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+
17
+ import dnnlib
18
+ from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
19
+ get_latest_points_pair, get_valid_mask,
20
+ on_change_single_global_state)
21
+ from viz.renderer import Renderer, add_watermark_np
22
+
23
+
24
+ # download models from Hugging Face hub
25
+ from huggingface_hub import snapshot_download
26
+
27
+ model_dir = Path('./checkpoints')
28
+ snapshot_download('DragGan/DragGan-Models',
29
+ repo_type='model', local_dir=model_dir)
30
+
31
+ parser = ArgumentParser()
32
+ parser.add_argument('--share', action='store_true')
33
+ parser.add_argument('--cache-dir', type=str, default='./checkpoints')
34
+ args = parser.parse_args()
35
+
36
+ cache_dir = args.cache_dir
37
+
38
+ device = 'cuda'
39
+ IS_SPACE = "DragGan/DragGan" in os.environ.get('SPACE_ID', '')
40
+ TIMEOUT = 80
41
+
42
+
43
+ def reverse_point_pairs(points):
44
+ new_points = []
45
+ for p in points:
46
+ new_points.append([p[1], p[0]])
47
+ return new_points
48
+
49
+
50
+ def clear_state(global_state, target=None):
51
+ """Clear target history state from global_state
52
+ If target is not defined, points and mask will be both removed.
53
+ 1. set global_state['points'] as empty dict
54
+ 2. set global_state['mask'] as full-one mask.
55
+ """
56
+ if target is None:
57
+ target = ['point', 'mask']
58
+ if not isinstance(target, list):
59
+ target = [target]
60
+ if 'point' in target:
61
+ global_state['points'] = dict()
62
+ print('Clear Points State!')
63
+ if 'mask' in target:
64
+ image_raw = global_state["images"]["image_raw"]
65
+ global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
66
+ dtype=np.uint8)
67
+ print('Clear mask State!')
68
+
69
+ return global_state
70
+
71
+
72
+ def init_images(global_state):
73
+ """This function is called only ones with Gradio App is started.
74
+ 0. pre-process global_state, unpack value from global_state of need
75
+ 1. Re-init renderer
76
+ 2. run `renderer._render_drag_impl` with `is_drag=False` to generate
77
+ new image
78
+ 3. Assign images to global state and re-generate mask
79
+ """
80
+
81
+ if isinstance(global_state, gr.State):
82
+ state = global_state.value
83
+ else:
84
+ state = global_state
85
+
86
+ state['renderer'].init_network(
87
+ state['generator_params'], # res
88
+ valid_checkpoints_dict[state['pretrained_weight']], # pkl
89
+ state['params']['seed'], # w0_seed,
90
+ None, # w_load
91
+ state['params']['latent_space'] == 'w+', # w_plus
92
+ 'const',
93
+ state['params']['trunc_psi'], # trunc_psi,
94
+ state['params']['trunc_cutoff'], # trunc_cutoff,
95
+ None, # input_transform
96
+ state['params']['lr'] # lr,
97
+ )
98
+
99
+ state['renderer']._render_drag_impl(state['generator_params'],
100
+ is_drag=False,
101
+ to_pil=True)
102
+
103
+ init_image = state['generator_params'].image
104
+ state['images']['image_orig'] = init_image
105
+ state['images']['image_raw'] = init_image
106
+ state['images']['image_show'] = Image.fromarray(
107
+ add_watermark_np(np.array(init_image)))
108
+ state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
109
+ dtype=np.uint8)
110
+ return global_state
111
+
112
+
113
+ def update_image_draw(image, points, mask, show_mask, global_state=None):
114
+
115
+ image_draw = draw_points_on_image(image, points)
116
+ if show_mask and mask is not None and not (mask == 0).all() and not (
117
+ mask == 1).all():
118
+ image_draw = draw_mask_on_image(image_draw, mask)
119
+
120
+ image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
121
+ if global_state is not None:
122
+ global_state['images']['image_show'] = image_draw
123
+ return image_draw
124
+
125
+
126
+ def preprocess_mask_info(global_state, image):
127
+ """Function to handle mask information.
128
+ 1. last_mask is None: Do not need to change mask, return mask
129
+ 2. last_mask is not None:
130
+ 2.1 global_state is remove_mask:
131
+ 2.2 global_state is add_mask:
132
+ """
133
+ if isinstance(image, dict):
134
+ last_mask = get_valid_mask(image['mask'])
135
+ else:
136
+ last_mask = None
137
+ mask = global_state['mask']
138
+
139
+ # mask in global state is a placeholder with all 1.
140
+ if (mask == 1).all():
141
+ mask = last_mask
142
+
143
+ # last_mask = global_state['last_mask']
144
+ editing_mode = global_state['editing_state']
145
+
146
+ if last_mask is None:
147
+ return global_state
148
+
149
+ if editing_mode == 'remove_mask':
150
+ updated_mask = np.clip(mask - last_mask, 0, 1)
151
+ print(f'Last editing_state is {editing_mode}, do remove.')
152
+ elif editing_mode == 'add_mask':
153
+ updated_mask = np.clip(mask + last_mask, 0, 1)
154
+ print(f'Last editing_state is {editing_mode}, do add.')
155
+ else:
156
+ updated_mask = mask
157
+ print(f'Last editing_state is {editing_mode}, '
158
+ 'do nothing to mask.')
159
+
160
+ global_state['mask'] = updated_mask
161
+ # global_state['last_mask'] = None # clear buffer
162
+ return global_state
163
+
164
+
165
+ def print_memory_usage():
166
+ # Print system memory usage
167
+ print(f"System memory usage: {psutil.virtual_memory().percent}%")
168
+
169
+ # Print GPU memory usage
170
+ if torch.cuda.is_available():
171
+ device = torch.device("cuda")
172
+ print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
173
+ print(
174
+ f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
175
+ device_properties = torch.cuda.get_device_properties(device)
176
+ available_memory = device_properties.total_memory - \
177
+ torch.cuda.max_memory_allocated()
178
+ print(f"Available GPU memory: {available_memory / 1e9} GB")
179
+ else:
180
+ print("No GPU available")
181
+
182
+
183
+ # filter large models running on SPACES
184
+ allowed_checkpoints = [] # all checkpoints
185
+ if IS_SPACE:
186
+ allowed_checkpoints = ["stylegan_human_v2_512.pkl",
187
+ "stylegan2_dogs_1024_pytorch.pkl"]
188
+
189
+ valid_checkpoints_dict = {
190
+ f.name.split('.')[0]: str(f)
191
+ for f in Path(cache_dir).glob('*.pkl')
192
+ if f.name in allowed_checkpoints or not IS_SPACE
193
+ }
194
+ print('Valid checkpoint file:')
195
+ print(valid_checkpoints_dict)
196
+
197
+ init_pkl = 'stylegan_human_v2_512'
198
+
199
+ with gr.Blocks() as app:
200
+ gr.Markdown("""
201
+ # DragGAN - Drag Your GAN
202
+ ## Interactive Point-based Manipulation on the Generative Image Manifold
203
+ ### Unofficial Gradio Demo
204
+
205
+ **Due to high demand, only one model can be run at a time, or you can duplicate the space and run your own copy.**
206
+
207
+ <a href="https://huggingface.co/spaces/radames/DragGan?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
208
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
209
+
210
+ * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
211
+ * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) Β© [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
212
+ """)
213
+
214
+ # renderer = Renderer()
215
+ global_state = gr.State({
216
+ "images": {
217
+ # image_orig: the original image, change with seed/model is changed
218
+ # image_raw: image with mask and points, change durning optimization
219
+ # image_show: image showed on screen
220
+ },
221
+ "temporal_params": {
222
+ # stop
223
+ },
224
+ 'mask':
225
+ None, # mask for visualization, 1 for editing and 0 for unchange
226
+ 'last_mask': None, # last edited mask
227
+ 'show_mask': True, # add button
228
+ "generator_params": dnnlib.EasyDict(),
229
+ "params": {
230
+ "seed": int(np.random.randint(0, 2**32 - 1)),
231
+ "motion_lambda": 20,
232
+ "r1_in_pixels": 3,
233
+ "r2_in_pixels": 12,
234
+ "magnitude_direction_in_pixels": 1.0,
235
+ "latent_space": "w+",
236
+ "trunc_psi": 0.7,
237
+ "trunc_cutoff": None,
238
+ "lr": 0.001,
239
+ },
240
+ "device": device,
241
+ "draw_interval": 1,
242
+ "renderer": Renderer(disable_timing=True),
243
+ "points": {},
244
+ "curr_point": None,
245
+ "curr_type_point": "start",
246
+ 'editing_state': 'add_points',
247
+ 'pretrained_weight': init_pkl
248
+ })
249
+
250
+ # init image
251
+ global_state = init_images(global_state)
252
+ with gr.Row():
253
+
254
+ with gr.Row():
255
+
256
+ # Left --> tools
257
+ with gr.Column(scale=3):
258
+
259
+ # Pickle
260
+ with gr.Row():
261
+
262
+ with gr.Column(scale=1, min_width=10):
263
+ gr.Markdown(value='Pickle', show_label=False)
264
+
265
+ with gr.Column(scale=4, min_width=10):
266
+ form_pretrained_dropdown = gr.Dropdown(
267
+ choices=list(valid_checkpoints_dict.keys()),
268
+ label="Pretrained Model",
269
+ value=init_pkl,
270
+ )
271
+
272
+ # Latent
273
+ with gr.Row():
274
+ with gr.Column(scale=1, min_width=10):
275
+ gr.Markdown(value='Latent', show_label=False)
276
+
277
+ with gr.Column(scale=4, min_width=10):
278
+ form_seed_number = gr.Slider(
279
+ mininium=0,
280
+ maximum=2**32-1,
281
+ step=1,
282
+ value=global_state.value['params']['seed'],
283
+ interactive=True,
284
+ # randomize=True,
285
+ label="Seed",
286
+ )
287
+ form_lr_number = gr.Number(
288
+ value=global_state.value["params"]["lr"],
289
+ interactive=True,
290
+ label="Step Size")
291
+
+ with gr.Row():
+ with gr.Column(scale=2, min_width=10):
+ form_reset_image = gr.Button("Reset Image")
+ with gr.Column(scale=3, min_width=10):
+ form_latent_space = gr.Radio(
+ ['w', 'w+'],
+ value=global_state.value['params']
+ ['latent_space'],
+ interactive=True,
+ label='Latent space to optimize',
+ show_label=False,
+ )
+
+ # Drag
+ with gr.Row():
+ with gr.Column(scale=1, min_width=10):
+ gr.Markdown(value='Drag', show_label=False)
+ with gr.Column(scale=4, min_width=10):
+ with gr.Row():
+ with gr.Column(scale=1, min_width=10):
+ enable_add_points = gr.Button('Add Points')
+ with gr.Column(scale=1, min_width=10):
+ undo_points = gr.Button('Reset Points')
+ with gr.Row():
+ with gr.Column(scale=1, min_width=10):
+ form_start_btn = gr.Button("Start")
+ with gr.Column(scale=1, min_width=10):
+ form_stop_btn = gr.Button("Stop")
+
+ form_steps_number = gr.Number(value=0,
+ label="Steps",
+ interactive=False)
+
+ # Mask
+ with gr.Row():
+ with gr.Column(scale=1, min_width=10):
+ gr.Markdown(value='Mask', show_label=False)
+ with gr.Column(scale=4, min_width=10):
+ enable_add_mask = gr.Button('Edit Flexible Area')
+ with gr.Row():
+ with gr.Column(scale=1, min_width=10):
+ form_reset_mask_btn = gr.Button("Reset mask")
+ with gr.Column(scale=1, min_width=10):
+ show_mask = gr.Checkbox(
+ label='Show Mask',
+ value=global_state.value['show_mask'],
+ show_label=False)
+
+ with gr.Row():
+ form_lambda_number = gr.Number(
+ value=global_state.value["params"]
+ ["motion_lambda"],
+ interactive=True,
+ label="Lambda",
+ )
+
+ form_draw_interval_number = gr.Number(
+ value=global_state.value["draw_interval"],
+ label="Draw Interval (steps)",
+ interactive=True,
+ visible=False)
+
+ # Right --> Image
+ with gr.Column(scale=8):
+ form_image = ImageMask(
+ value=global_state.value['images']['image_show'],
+ brush_radius=20).style(
+ width=768,
+ height=768) # NOTE: the image size is hard-coded here.
+ gr.Markdown("""
+ ## Quick Start
+
+ 1. Select the desired `Pretrained Model` and adjust `Seed` to generate an
+ initial image.
+ 2. Click on the image to add control points.
+ 3. Click `Start` and enjoy!
+
+ ## Advanced Usage
+
+ 1. Change `Step Size` to adjust the learning rate of the drag optimization.
+ 2. Select `w` or `w+` to change the latent space to optimize:
+ * Optimizing in `w` space may influence the image more strongly.
+ * Optimizing in `w+` space may be slower than `w`, but usually achieves
+ better results.
+ * Note that changing the latent space will reset the image, points and
+ mask (this has the same effect as the `Reset Image` button).
+ 3. Click `Edit Flexible Area` to create a mask and constrain the
+ unmasked region to remain unchanged.
+ """)
+ gr.HTML("""
+ <style>
+ .container {
+ position: absolute;
+ height: 50px;
+ text-align: center;
+ line-height: 50px;
+ width: 100%;
+ }
+ </style>
+ <div class="container">
+ Gradio demo supported by
+ <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
+ <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
+ </div>
+ """)
+ # Network & latents tab listeners
+
+ def on_change_pretrained_dropdown(pretrained_value, global_state):
+ """Function to handle model change.
+ 1. Set the pretrained value in global_state
+ 2. Re-init images and clear all states
+ """
+
+ global_state['pretrained_weight'] = pretrained_value
+ init_images(global_state)
+ clear_state(global_state)
+
+ return global_state, global_state["images"]['image_show']
+
+ form_pretrained_dropdown.change(
+ on_change_pretrained_dropdown,
+ inputs=[form_pretrained_dropdown, global_state],
+ outputs=[global_state, form_image],
+ queue=True,
+ )
+
+ def on_click_reset_image(global_state):
+ """Reset the image to the original one and clear all states
+ 1. Re-init images
+ 2. Clear all states
+ """
+
+ init_images(global_state)
+ clear_state(global_state)
+
+ return global_state, global_state['images']['image_show']
+
+ form_reset_image.click(
+ on_click_reset_image,
+ inputs=[global_state],
+ outputs=[global_state, form_image],
+ queue=False,
+ )
+
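`init_images` and `clear_state` are defined earlier in the file; from the way the handlers here call them, `clear_state` has to wipe the points and/or the mask buffer. A hedged reconstruction of that contract, not the file's actual implementation:

import numpy as np

def clear_state(global_state, target=None):
    """Illustrative sketch: reset points and/or mask, mirroring how the
    handlers above call it; `target` narrows the reset ('point' or 'mask')."""
    if target is None:
        target = ['point', 'mask']
    if not isinstance(target, list):
        target = [target]
    if 'point' in target:
        global_state['points'] = {}          # drop all control points
    if 'mask' in target:
        image = global_state['images']['image_raw']
        # all-ones mask == "everything is editable" (see on_click_reset_mask)
        global_state['mask'] = np.ones((image.size[1], image.size[0]),
                                       dtype=np.uint8)
    return global_state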
+ # Update parameters
+ def on_change_update_image_seed(seed, global_state):
+ """Function to handle generation seed change.
+ 1. Set the seed in global_state
+ 2. Re-init images and clear all states
+ """
+
+ global_state["params"]["seed"] = int(seed)
+ init_images(global_state)
+ clear_state(global_state)
+
+ return global_state, global_state['images']['image_show']
+
+ form_seed_number.change(
+ on_change_update_image_seed,
+ inputs=[form_seed_number, global_state],
+ outputs=[global_state, form_image],
+ )
+
+ def on_click_latent_space(latent_space, global_state):
+ """Function to reset the latent space to optimize.
+ NOTE: this function resets the image and all controls
+ 1. Set the latent space in global_state
+ 2. Re-init images and clear all states
+ """
+
+ global_state['params']['latent_space'] = latent_space
+ init_images(global_state)
+ clear_state(global_state)
+
+ return global_state, global_state['images']['image_show']
+
+ form_latent_space.change(on_click_latent_space,
+ inputs=[form_latent_space, global_state],
+ outputs=[global_state, form_image])
+
+ # ==== Params
+ form_lambda_number.change(
+ partial(on_change_single_global_state, ["params", "motion_lambda"]),
+ inputs=[form_lambda_number, global_state],
+ outputs=[global_state],
+ )
+
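`on_change_single_global_state` (defined earlier in the file) is curried with `functools.partial` so one generic setter can serve many controls. A hedged sketch of what such a setter looks like, inferred from how it is called here and in the draw-interval listener below:

from functools import partial

def on_change_single_global_state(keys, value, global_state,
                                  map_transform=None):
    """Illustrative generic setter: walk `keys` into the nested state dict
    and assign `value`, optionally transformed first."""
    if map_transform is not None:
        value = map_transform(value)
    if isinstance(keys, str):
        keys = [keys]
    target = global_state
    for k in keys[:-1]:
        target = target[k]
    target[keys[-1]] = value
    return global_state

# curried exactly as above: the component's new value and the state arrive
# as the remaining positional arguments
setter = partial(on_change_single_global_state, ["params", "motion_lambda"])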
+ def on_change_lr(lr, global_state):
+ if lr == 0:
+ print('lr is 0, do nothing.')
+ return global_state
+ else:
+ global_state["params"]["lr"] = lr
+ renderer = global_state['renderer']
+ renderer.update_lr(lr)
+ print('New optimizer: ')
+ print(renderer.w_optim)
+ return global_state
+
+ form_lr_number.change(
+ on_change_lr,
+ inputs=[form_lr_number, global_state],
+ outputs=[global_state],
+ queue=False,
+ )
+
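`Renderer.update_lr` lives in the renderer module; judging from the `w_optim` attribute printed above, it most likely rewrites the learning rate of the latent optimizer in place. A hedged sketch of that idea on a plain `torch.optim.Adam` (class and attribute names assumed, not the renderer's real code):

import torch

class LatentOptimizerHolder:
    """Illustrative stand-in for the renderer's optimizer handling."""

    def __init__(self, w):
        self.w = w.requires_grad_(True)
        self.w_optim = torch.optim.Adam([self.w], lr=0.001)

    def update_lr(self, lr):
        # changing lr on the existing param groups avoids rebuilding the
        # optimizer and losing its Adam moment estimates mid-drag
        for pg in self.w_optim.param_groups:
            pg['lr'] = lr

holder = LatentOptimizerHolder(torch.randn(1, 512))
holder.update_lr(0.01)
print(holder.w_optim.param_groups[0]['lr'])  # 0.01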
+ def on_click_start(global_state, image):
+ p_in_pixels = []
+ t_in_pixels = []
+ valid_points = []
+
+ # handle starting a drag while in mask-editing mode
+ global_state = preprocess_mask_info(global_state, image)
+
+ # Prepare the points for the inference
+ if len(global_state["points"]) == 0:
+ # yield on_click_start_wo_points(global_state, image)
+ image_raw = global_state['images']['image_raw']
+ update_image_draw(
+ image_raw,
+ global_state['points'],
+ global_state['mask'],
+ global_state['show_mask'],
+ global_state,
+ )
+
+ yield (
+ global_state,
+ 0,
+ global_state['images']['image_show'],
+ # gr.File.update(visible=False),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ # latent space
+ gr.Radio.update(interactive=True),
+ gr.Button.update(interactive=True),
+ # NOTE: disable stop button
+ gr.Button.update(interactive=False),
+
+ # update other comps
+ gr.Dropdown.update(interactive=True),
+ gr.Number.update(interactive=True),
+ gr.Number.update(interactive=True),
+ gr.Checkbox.update(interactive=True),
+ gr.Number.update(interactive=True),
+ )
+ else:
+
+ # Transform the points into torch tensors
+ for key_point, point in global_state["points"].items():
+ try:
+ p_start = point.get("start_temp", point["start"])
+ p_end = point["target"]
+
+ if p_start is None or p_end is None:
+ continue
+
+ except KeyError:
+ continue
+
+ p_in_pixels.append(p_start)
+ t_in_pixels.append(p_end)
+ valid_points.append(key_point)
+
+ mask = torch.tensor(global_state['mask']).float()
+ drag_mask = 1 - mask
+
+ renderer: Renderer = global_state["renderer"]
+ global_state['temporal_params']['stop'] = False
+ global_state['editing_state'] = 'running'
+
+ # reverse the point order (x, y) -> (y, x)
+ p_to_opt = reverse_point_pairs(p_in_pixels)
+ t_to_opt = reverse_point_pairs(t_in_pixels)
+ print('Running with:')
+ print(f' Source: {p_in_pixels}')
+ print(f' Target: {t_in_pixels}')
+ step_idx = 0
+ last_time = time.time()
+ while True:
+ print_memory_usage()
+ # add a TIMEOUT break
+ print(f'Running time: {time.time() - last_time}')
+ if IS_SPACE and time.time() - last_time > TIMEOUT:
+ print('Timeout break!')
+ break
+ if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
+ break
+
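`reverse_point_pairs` is defined earlier in the file; the swap-back to `[p_i[1], p_i[0]]` in the drawing code below implies it flips the UI's `(x, y)` clicks into the `(row, col)` order the renderer expects. A hedged one-liner version:

def reverse_point_pairs(points):
    """Illustrative sketch: Gradio reports clicks as (x, y), while the
    renderer indexes feature maps as (row, col) = (y, x)."""
    return [[p[1], p[0]] for p in points]

print(reverse_point_pairs([[10, 40], [7, 3]]))  # [[40, 10], [3, 7]]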
+ # do the drag step here!
+ renderer._render_drag_impl(
+ global_state['generator_params'],
+ p_to_opt, # point
+ t_to_opt, # target
+ drag_mask, # mask,
+ global_state['params']['motion_lambda'], # lambda_mask
+ reg=0,
+ feature_idx=5, # NOTE: changing this is not supported for now
+ r1=global_state['params']['r1_in_pixels'], # r1
+ r2=global_state['params']['r2_in_pixels'], # r2
+ # random_seed = 0,
+ # noise_mode = 'const',
+ trunc_psi=global_state['params']['trunc_psi'],
+ # force_fp32 = False,
+ # layer_name = None,
+ # sel_channels = 3,
+ # base_channel = 0,
+ # img_scale_db = 0,
+ # img_normalize = False,
+ # untransform = False,
+ is_drag=True,
+ to_pil=True)
+
+ if step_idx % global_state['draw_interval'] == 0:
+ print('Current Source:')
+ for key_point, p_i, t_i in zip(valid_points, p_to_opt,
+ t_to_opt):
+ global_state["points"][key_point]["start_temp"] = [
+ p_i[1],
+ p_i[0],
+ ]
+ global_state["points"][key_point]["target"] = [
+ t_i[1],
+ t_i[0],
+ ]
+ start_temp = global_state["points"][key_point][
+ "start_temp"]
+ print(f' {start_temp}')
+
+ image_result = global_state['generator_params']['image']
+ image_draw = update_image_draw(
+ image_result,
+ global_state['points'],
+ global_state['mask'],
+ global_state['show_mask'],
+ global_state,
+ )
+ global_state['images']['image_raw'] = image_result
+
+ yield (
+ global_state,
+ step_idx,
+ global_state['images']['image_show'],
+ # gr.File.update(visible=False),
+ gr.Button.update(interactive=False),
+ gr.Button.update(interactive=False),
+ gr.Button.update(interactive=False),
+ gr.Button.update(interactive=False),
+ gr.Button.update(interactive=False),
+ # latent space
+ gr.Radio.update(interactive=False),
+ gr.Button.update(interactive=False),
+ # enable stop button in loop
+ gr.Button.update(interactive=True),
+
+ # update other comps
+ gr.Dropdown.update(interactive=False),
+ gr.Number.update(interactive=False),
+ gr.Number.update(interactive=False),
+ gr.Checkbox.update(interactive=False),
+ gr.Number.update(interactive=False),
+ )
+
+ # increment the step counter
+ step_idx += 1
+
+ image_result = global_state['generator_params']['image']
+ global_state['images']['image_raw'] = image_result
+ image_draw = update_image_draw(image_result,
+ global_state['points'],
+ global_state['mask'],
+ global_state['show_mask'],
+ global_state)
+
+ # fp = NamedTemporaryFile(suffix=".png", delete=False)
+ # image_result.save(fp, "PNG")
+
+ global_state['editing_state'] = 'add_points'
+
+ yield (
+ global_state,
+ 0, # reset step to 0 after stop.
+ global_state['images']['image_show'],
+ # gr.File.update(visible=True, value=fp.name),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ gr.Button.update(interactive=True),
+ # latent space
+ gr.Radio.update(interactive=True),
+ gr.Button.update(interactive=True),
+ # NOTE: disable stop button when the loop finishes
+ gr.Button.update(interactive=False),
+
+ # update other comps
+ gr.Dropdown.update(interactive=True),
+ gr.Number.update(interactive=True),
+ gr.Number.update(interactive=True),
+ gr.Checkbox.update(interactive=True),
+ gr.Number.update(interactive=True),
+ )
+
+ form_start_btn.click(
+ on_click_start,
+ inputs=[global_state, form_image],
+ outputs=[
+ global_state,
+ form_steps_number,
+ form_image,
+ # form_download_result_file,
+ # >>> buttons
+ form_reset_image,
+ enable_add_points,
+ enable_add_mask,
+ undo_points,
+ form_reset_mask_btn,
+ form_latent_space,
+ form_start_btn,
+ form_stop_btn,
+ # <<< buttons
+ # >>> input comps
+ form_pretrained_dropdown,
+ form_seed_number,
+ form_lr_number,
+ show_mask,
+ form_lambda_number,
+ ],
+ )
+
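Because `on_click_start` is a generator, each `yield` streams an intermediate frame and fresh component states to the browser before the handler returns. A minimal sketch of the same pattern (component names illustrative, not from this app):

import time
import gradio as gr

def run_steps(total):
    # each yield pushes an update to the UI mid-run
    for step in range(int(total)):
        time.sleep(0.1)              # stand-in for one optimization step
        yield step, f"running step {step}"
    yield int(total), "done"

with gr.Blocks() as demo:
    n = gr.Number(value=10, label="Steps")
    step_out = gr.Number(label="Current step")
    status = gr.Textbox(label="Status")
    gr.Button("Start").click(run_steps, inputs=[n],
                             outputs=[step_out, status])

demo.queue()  # generator handlers require the queue, as app.queue() does below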
+ def on_click_stop(global_state):
+ """Function to handle a click on the Stop button.
+ 1. Send a stop signal by setting global_state["temporal_params"]["stop"] to True
+ 2. Disable the Stop button
+ """
+ global_state["temporal_params"]["stop"] = True
+
+ return global_state, gr.Button.update(interactive=False)
+
+ form_stop_btn.click(on_click_stop,
+ inputs=[global_state],
+ outputs=[global_state, form_stop_btn],
+ queue=False)
+
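The Stop button works by cooperative cancellation: it runs with `queue=False` so it is handled immediately, flips a flag in the shared state, and the optimization loop checks that flag once per step. The same pattern in isolation (names illustrative):

import gradio as gr

with gr.Blocks() as demo:
    state = gr.State({"stop": False})
    out = gr.Textbox(label="Progress")

    def long_job(state):
        state["stop"] = False
        for i in range(1000):
            if state["stop"]:           # checked once per iteration
                yield state, f"stopped at {i}"
                return
            yield state, f"step {i}"

    def request_stop(state):
        state["stop"] = True            # the loop sees this on its next check
        return state

    start, stop = gr.Button("Start"), gr.Button("Stop")
    start.click(long_job, inputs=[state], outputs=[state, out])
    stop.click(request_stop, inputs=[state], outputs=[state], queue=False)

demo.queue(concurrency_count=1)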
+ form_draw_interval_number.change(
+ partial(
+ on_change_single_global_state,
+ "draw_interval",
+ map_transform=lambda x: int(x),
+ ),
+ inputs=[form_draw_interval_number, global_state],
+ outputs=[global_state],
+ queue=False,
+ )
+
+
+ def on_click_remove_point(global_state):
+ choice = global_state["curr_point"]
+ del global_state["points"][choice]
+
+ choices = list(global_state["points"].keys())
+
+ if len(choices) > 0:
+ global_state["curr_point"] = choices[0]
+
+ return (
+ gr.Dropdown.update(choices=choices,
+ value=choices[0] if choices else None),
+ global_state,
+ )
+
+ # Mask
+ def on_click_reset_mask(global_state):
+ global_state['mask'] = np.ones(
+ (
+ global_state["images"]["image_raw"].size[1],
+ global_state["images"]["image_raw"].size[0],
+ ),
+ dtype=np.uint8,
+ )
+ image_draw = update_image_draw(global_state['images']['image_raw'],
+ global_state['points'],
+ global_state['mask'],
+ global_state['show_mask'], global_state)
+ return global_state, image_draw
+
+ form_reset_mask_btn.click(
+ on_click_reset_mask,
+ inputs=[global_state],
+ outputs=[global_state, form_image],
+ )
+
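Note the convention established here and in `on_click_start`: the stored mask is 1 where the image may change, and the renderer receives `drag_mask = 1 - mask`, i.e. 1 where the image should stay fixed. In NumPy terms (a hedged illustration of the inferred convention, not code from this file):

import numpy as np

h, w = 512, 512
mask = np.ones((h, w), dtype=np.uint8)   # reset: the whole image is flexible
mask[:, : w // 2] = 0                    # pretend the left half was left unpainted

drag_mask = 1 - mask                     # what the renderer sees: 1 = keep fixed
assert drag_mask[:, : w // 2].all()      # the unpainted region is constrained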
+ # Image
+ def on_click_enable_draw(global_state, image):
+ """Function to start the add-mask mode.
+ 1. Preprocess mask info from the last state
+ 2. Change editing state to add_mask
+ 3. Set the current image with points and mask
+ """
+ global_state = preprocess_mask_info(global_state, image)
+ global_state['editing_state'] = 'add_mask'
+ image_raw = global_state['images']['image_raw']
+ image_draw = update_image_draw(image_raw, global_state['points'],
+ global_state['mask'], True,
+ global_state)
+ return (global_state,
+ gr.Image.update(value=image_draw, interactive=True))
+
+ def on_click_remove_draw(global_state, image):
+ """Function to start the remove-mask mode.
+ 1. Preprocess mask info from the last state
+ 2. Change editing state to remove_mask
+ 3. Set the current image with points and mask
+ """
+ global_state = preprocess_mask_info(global_state, image)
+ global_state['editing_state'] = 'remove_mask'
+ image_raw = global_state['images']['image_raw']
+ image_draw = update_image_draw(image_raw, global_state['points'],
+ global_state['mask'], True,
+ global_state)
+ return (global_state,
+ gr.Image.update(value=image_draw, interactive=True))
+
+ enable_add_mask.click(on_click_enable_draw,
+ inputs=[global_state, form_image],
+ outputs=[
+ global_state,
+ form_image,
+ ],
+ queue=False)
+
+ def on_click_add_point(global_state, image: dict):
+ """Function to switch from add-mask mode to add-points mode.
+ 1. Update the mask buffer if needed
+ 2. Change global_state['editing_state'] to 'add_points'
+ 3. Set the current image with mask
+ """
+
+ global_state = preprocess_mask_info(global_state, image)
+ global_state['editing_state'] = 'add_points'
+ mask = global_state['mask']
+ image_raw = global_state['images']['image_raw']
+ image_draw = update_image_draw(image_raw, global_state['points'], mask,
+ global_state['show_mask'], global_state)
+
+ return (global_state,
+ gr.Image.update(value=image_draw, interactive=False))
+
+ enable_add_points.click(on_click_add_point,
+ inputs=[global_state, form_image],
+ outputs=[global_state, form_image],
+ queue=False)
+
+ def on_click_image(global_state, evt: gr.SelectData):
+ """This function only supports clicks for point selection.
+ """
+ xy = evt.index
+ if global_state['editing_state'] != 'add_points':
+ print(f'In {global_state["editing_state"]} state. '
+ 'Do not add points.')
+
+ return global_state, global_state['images']['image_show']
+
+ points = global_state["points"]
+
+ point_idx = get_latest_points_pair(points)
+ if point_idx is None:
+ points[0] = {'start': xy, 'target': None}
+ print(f'Click Image - Start - {xy}')
+ elif points[point_idx].get('target', None) is None:
+ points[point_idx]['target'] = xy
+ print(f'Click Image - Target - {xy}')
+ else:
+ points[point_idx + 1] = {'start': xy, 'target': None}
+ print(f'Click Image - Start - {xy}')
+
+ image_raw = global_state['images']['image_raw']
+ image_draw = update_image_draw(
+ image_raw,
+ global_state['points'],
+ global_state['mask'],
+ global_state['show_mask'],
+ global_state,
+ )
+
+ return global_state, image_draw
+
+ form_image.select(
+ on_click_image,
+ inputs=[global_state],
+ outputs=[global_state, form_image],
+ queue=False,
+ )
+
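Clicks therefore alternate between opening a new pair (`start`) and closing it (`target`), keyed by an increasing integer. `get_latest_points_pair` is defined elsewhere in the file; under the assumption that it simply returns the highest existing key, the click logic can be exercised in isolation:

def get_latest_points_pair(points):
    # assumed behavior: highest key so far, or None when no points exist
    return max(points.keys()) if points else None

points = {}
for xy in [(10, 20), (30, 40), (50, 60)]:
    idx = get_latest_points_pair(points)
    if idx is None:
        points[0] = {'start': xy, 'target': None}
    elif points[idx].get('target') is None:
        points[idx]['target'] = xy
    else:
        points[idx + 1] = {'start': xy, 'target': None}

# {0: {'start': (10, 20), 'target': (30, 40)}, 1: {'start': (50, 60), 'target': None}}
print(points)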
+ def on_click_clear_points(global_state):
+ """Function to handle clearing all control points
+ 1. clear global_state['points'] (clear_state)
+ 2. reset the renderer's feature references
+ 3. re-draw the image
+ """
+ clear_state(global_state, target='point')
+
+ renderer: Renderer = global_state["renderer"]
+ renderer.feat_refs = None
+
+ image_raw = global_state['images']['image_raw']
+ image_draw = update_image_draw(image_raw, {}, global_state['mask'],
+ global_state['show_mask'], global_state)
+ return global_state, image_draw
+
+ undo_points.click(on_click_clear_points,
+ inputs=[global_state],
+ outputs=[global_state, form_image],
+ queue=False)
+
+ def on_click_show_mask(global_state, show_mask):
+ """Function to control whether the mask is shown on the image."""
+ global_state['show_mask'] = show_mask
+
+ image_raw = global_state['images']['image_raw']
+ image_draw = update_image_draw(
+ image_raw,
+ global_state['points'],
+ global_state['mask'],
+ global_state['show_mask'],
+ global_state,
+ )
+ return global_state, image_draw
+
+ show_mask.change(
+ on_click_show_mask,
+ inputs=[global_state, show_mask],
+ outputs=[global_state, form_image],
+ queue=False,
+ )
+
+
+ print("SHAReD: Start app", parser.parse_args())
+ gr.close_all()
+ app.queue(concurrency_count=1, max_size=200, api_open=False)
+ app.launch(share=args.share, show_api=False)