# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import importlib
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import SGD, Adam, RMSprop, lr_scheduler

from imaginaire.optimizers import Fromage, Madam
from imaginaire.utils.distributed import get_rank, get_world_size
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.init_weight import weights_init, weights_rescale
from imaginaire.utils.model_average import ModelAverage


def set_random_seed(seed, by_rank=False):
    r"""Set random seeds for everything.

    Args:
        seed (int): Random seed.
        by_rank (bool): Whether to offset the seed by the process rank so
            that each process ends up with a different seed.
    """
    if by_rank:
        seed += get_rank()
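        # e.g., with seed=0 on 4 processes, ranks 0..3 end up with seeds 0..3.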
    print(f"Using random seed {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_trainer(cfg, net_G, net_D=None,
                opt_G=None, opt_D=None,
                sch_G=None, sch_D=None,
                train_data_loader=None,
                val_data_loader=None):
    """Return the trainer object.

    Args:
        cfg (Config): Loaded config object.
        net_G (obj): Generator network object.
        net_D (obj): Discriminator network object.
        opt_G (obj): Generator optimizer object.
        opt_D (obj): Discriminator optimizer object.
        sch_G (obj): Generator optimizer scheduler object.
        sch_D (obj): Discriminator optimizer scheduler object.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.

    Returns:
        (obj): Trainer object.
    """
    trainer_lib = importlib.import_module(cfg.trainer.type)
    trainer = trainer_lib.Trainer(cfg, net_G, net_D,
                                  opt_G, opt_D,
                                  sch_G, sch_D,
                                  train_data_loader, val_data_loader)
    return trainer


def get_model_optimizer_and_scheduler(cfg, seed=0):
    r"""Return the networks, the optimizers, and the schedulers. We will
    first set the random seed to a fixed value so that each GPU copy will be
    initialized to have the same network weights. We will then use different
    random seeds for different GPUs. After this we will wrap the generator
    with a moving average model if applicable. It is followed by getting the
    optimizers and data distributed data parallel wrapping.

    Args:
        cfg (obj): Global configuration.
        seed (int): Random seed.

    Returns:
        (dict):
          - net_G (obj): Generator network object.
          - net_D (obj): Discriminator network object.
          - opt_G (obj): Generator optimizer object.
          - opt_D (obj): Discriminator optimizer object.
          - sch_G (obj): Generator optimizer scheduler object.
          - sch_D (obj): Discriminator optimizer scheduler object.
    """
    # First set the same random seed on every process so that each copy of
    # the network is initialized with exactly the same weights and other
    # parameters.
    set_random_seed(seed, by_rank=False)
    # Construct networks
    lib_G = importlib.import_module(cfg.gen.type)
    lib_D = importlib.import_module(cfg.dis.type)
    net_G = lib_G.Generator(cfg.gen, cfg.data)
    net_D = lib_D.Discriminator(cfg.dis, cfg.data)
    print('Initialize net_G and net_D weights using '
          'type: {} gain: {}'.format(cfg.trainer.init.type,
                                     cfg.trainer.init.gain))
    init_bias = getattr(cfg.trainer.init, 'bias', None)
    net_G.apply(weights_init(
        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_D.apply(weights_init(
        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_G.apply(weights_rescale())
    net_D.apply(weights_rescale())
    # for name, p in net_G.named_parameters():
    #     if 'modulation' in name and 'bias' in name:
    #         nn.init.constant_(p.data, 1.)
    net_G = net_G.to('cuda')
    net_D = net_D.to('cuda')
    # From this point on, different GPU copies of the same model will draw
    # different random numbers (e.g., for noise inputs), because
    # set_random_seed with by_rank=True gives GPU #K the seed `seed + K`.
    set_random_seed(seed, by_rank=True)
    print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G)))
    print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D)))

    # Optimizer
    opt_G = get_optimizer(cfg.gen_opt, net_G)
    opt_D = get_optimizer(cfg.dis_opt, net_D)
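    # The optimizers are created on the unwrapped modules; model averaging
    # and DDP wrapping are applied afterwards in wrap_model_and_optimizer.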

    net_G, net_D, opt_G, opt_D = \
        wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)

    # Scheduler
    sch_G = get_scheduler(cfg.gen_opt, opt_G)
    sch_D = get_scheduler(cfg.dis_opt, opt_D)

    return net_G, net_D, opt_G, opt_D, sch_G, sch_D


def wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D):
    r"""Wrap the networks and the optimizers with AMP DDP and (optionally)
    model average.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network object.
        net_D (obj): Discriminator network object.
        opt_G (obj): Generator optimizer object.
        opt_D (obj): Discriminator optimizer object.

    Returns:
        (dict):
          - net_G (obj): Generator network object.
          - net_D (obj): Discriminator network object.
          - opt_G (obj): Generator optimizer object.
          - opt_D (obj): Discriminator optimizer object.
    """
    # Apply model average wrapper.
    if cfg.trainer.model_average_config.enabled:
        if hasattr(cfg.trainer.model_average_config, 'g_smooth_img'):
            # Specifies half-life of the running average of generator weights.
            cfg.trainer.model_average_config.beta = \
                0.5 ** (cfg.data.train.batch_size *
                        get_world_size() / cfg.trainer.model_average_config.g_smooth_img)
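            # Illustrative numbers (not defaults): batch_size=4, 8 GPUs and
            # g_smooth_img=10000 give beta = 0.5 ** (32 / 10000), about 0.9978.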
            print(f"EMA Decay Factor: {cfg.trainer.model_average_config.beta}")
        net_G = ModelAverage(net_G, cfg.trainer.model_average_config.beta,
                             cfg.trainer.model_average_config.start_iteration,
                             cfg.trainer.model_average_config.remove_sn)
        net_G_module = net_G.module
    else:
        net_G_module = net_G
    if hasattr(net_G_module, 'custom_init'):
        net_G_module.custom_init()

    net_G = _wrap_model(cfg, net_G)
    net_D = _wrap_model(cfg, net_D)
    return net_G, net_D, opt_G, opt_D


def _calculate_model_size(model):
    r"""Calculate number of parameters in a PyTorch network.

    Args:
        model (obj): PyTorch network.

    Returns:
        (int): Number of parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class WrappedModel(nn.Module):
    r"""Dummy wrapping the module.
    """

    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""PyTorch module forward function overload."""
        return self.module(*args, **kwargs)


def _wrap_model(cfg, model):
    r"""Wrap a model for distributed data parallel training.

    Args:
        cfg (obj): Global configuration.
        model (obj): PyTorch network model.

    Returns:
        (obj): Wrapped PyTorch network model.
    """
    if torch.distributed.is_available() and dist.is_initialized():
        # ddp = cfg.trainer.distributed_data_parallel
        find_unused_parameters = cfg.trainer.distributed_data_parallel_params.find_unused_parameters
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.local_rank],
            output_device=cfg.local_rank,
            find_unused_parameters=find_unused_parameters,
            broadcast_buffers=False
        )
        # if ddp == 'pytorch':
        #     return torch.nn.parallel.DistributedDataParallel(
        #         model,
        #         device_ids=[cfg.local_rank],
        #         output_device=cfg.local_rank,
        #         find_unused_parameters=find_unused_parameters,
        #         broadcast_buffers=False)
        # else:
        #     delay_allreduce = cfg.trainer.delay_allreduce
        #     return apex.parallel.DistributedDataParallel(
        #         model, delay_allreduce=delay_allreduce)
    else:
        return WrappedModel(model)


def get_scheduler(cfg_opt, opt):
    """Return the scheduler object.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        opt (obj): PyTorch optimizer object.

    Returns:
        (obj): Scheduler
    """
    if cfg_opt.lr_policy.type == 'step':
        scheduler = lr_scheduler.StepLR(
            opt,
            step_size=cfg_opt.lr_policy.step_size,
            gamma=cfg_opt.lr_policy.gamma)
    elif cfg_opt.lr_policy.type == 'constant':
        scheduler = lr_scheduler.LambdaLR(opt, lambda x: 1)
    elif cfg_opt.lr_policy.type == 'linear':
        # Start linear decay from here.
        decay_start = cfg_opt.lr_policy.decay_start
        # End linear decay here.
        # Continue to train using the lowest learning rate till the end.
        decay_end = cfg_opt.lr_policy.decay_end
        # Lowest learning rate multiplier.
        decay_target = cfg_opt.lr_policy.decay_target
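        # The resulting multiplier is 1 before decay_start, decreases
        # linearly to decay_target at decay_end, and stays at decay_target
        # afterwards.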

        def sch(x):
            return min(
                max(((x - decay_start) * decay_target + decay_end - x) / (
                    decay_end - decay_start
                ), decay_target), 1.
            )
        scheduler = lr_scheduler.LambdaLR(opt, sch)
    else:
        raise NotImplementedError('Learning rate policy {} is not implemented.'.
                                  format(cfg_opt.lr_policy.type))
    return scheduler


def get_optimizer(cfg_opt, net):
    r"""Return the scheduler object.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        net (obj): PyTorch network object.

    Returns:
        (obj): PyTorch optimizer object.
    """
    if hasattr(net, 'get_param_groups'):
        # Allow the network to use different hyper-parameters (e.g., learning
        # rate) for different parameters.
        params = net.get_param_groups(cfg_opt)
    else:
        params = net.parameters()
    return get_optimizer_for_params(cfg_opt, params)


def get_optimizer_for_params(cfg_opt, params):
    r"""Return the scheduler object.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        params (obj): Parameters to be optimized.

    Returns:
        (obj): Optimizer
    """
    # Use the fused apex optimizers when requested in the config and apex is
    # available; otherwise fall back to the vanilla PyTorch optimizers.
    fused_opt = cfg_opt.fused_opt
    try:
        from apex.optimizers import FusedAdam
    except Exception:  # apex is optional.
        fused_opt = False

    if cfg_opt.type == 'adam':
        if fused_opt:
            opt = FusedAdam(params,
                            lr=cfg_opt.lr, eps=cfg_opt.eps,
                            betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))
        else:
            opt = Adam(params,
                       lr=cfg_opt.lr, eps=cfg_opt.eps,
                       betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))

    elif cfg_opt.type == 'madam':
        g_bound = getattr(cfg_opt, 'g_bound', None)
        opt = Madam(params, lr=cfg_opt.lr,
                    scale=cfg_opt.scale, g_bound=g_bound)
    elif cfg_opt.type == 'fromage':
        opt = Fromage(params, lr=cfg_opt.lr)
    elif cfg_opt.type == 'rmsprop':
        opt = RMSprop(params, lr=cfg_opt.lr,
                      eps=cfg_opt.eps, weight_decay=cfg_opt.weight_decay)
    elif cfg_opt.type == 'sgd':
        if fused_opt:
            from apex.optimizers import FusedSGD
            opt = FusedSGD(params,
                           lr=cfg_opt.lr,
                           momentum=cfg_opt.momentum,
                           weight_decay=cfg_opt.weight_decay)
        else:
            opt = SGD(params,
                      lr=cfg_opt.lr,
                      momentum=cfg_opt.momentum,
                      weight_decay=cfg_opt.weight_decay)
    else:
        raise NotImplementedError(
            'Optimizer {} is not yet implemented.'.format(cfg_opt.type))
    return opt
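

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the library):
    # build an Adam optimizer and a step scheduler for a toy network with
    # the helpers above. The attribute names mirror what
    # get_optimizer_for_params and get_scheduler read from the config;
    # the values here are arbitrary.
    from types import SimpleNamespace

    toy_net = nn.Linear(8, 8)
    toy_cfg_opt = SimpleNamespace(
        type='adam', lr=1e-4, eps=1e-8,
        adam_beta1=0., adam_beta2=0.999,
        fused_opt=False,
        lr_policy=SimpleNamespace(type='step', step_size=100, gamma=0.1))
    toy_opt = get_optimizer(toy_cfg_opt, toy_net)
    toy_sch = get_scheduler(toy_cfg_opt, toy_opt)
    assert isinstance(toy_opt, Adam)
    assert isinstance(toy_sch, lr_scheduler.StepLR)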