File size: 1,634 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer
from mmcv.utils import Registry
# Registry instance created under the name 'optimizers'. Nothing in the
# visible part of this file reads it (``build_optimizers`` delegates to
# mmcv's ``build_optimizer`` instead) -- presumably other modules register
# or look up optimizers through it; verify against callers.
OPTIMIZERS = Registry('optimizers')
def build_optimizers(model, cfgs):
    """Construct one optimizer, or one optimizer per sub-module.

    The shape of ``cfgs`` selects the mode:

    * **Every value is a dict** -- ``cfgs`` is treated as a mapping from
      attribute name to optimizer config. For each ``name``, the optimizer
      is built for ``getattr(model, name)`` from a shallow copy of the
      matching config, and a ``{name: optimizer}`` dict is returned.

      .. code-block:: python

          optimizer_cfg = dict(
              model1=dict(type='SGD', lr=lr),
              model2=dict(type='SGD', lr=lr))
          # -> {'model1': <optimizer>, 'model2': <optimizer>}

      An empty ``cfgs`` also falls in this mode and yields ``{}``.

    * **Otherwise** -- ``cfgs`` is a single optimizer config and is handed
      to ``build_optimizer`` together with the whole model.

      .. code-block:: python

          optimizer_cfg = dict(type='SGD', lr=lr)
          # -> <optimizer>

    If ``model`` exposes a ``module`` attribute (as parallel/distributed
    wrappers do), that inner object is used in place of ``model``.

    Args:
        model (:obj:`nn.Module`): The model whose parameters are optimized.
        cfgs (dict): A single optimizer config, or a dict of such configs
            keyed by the name of a ``model`` attribute.

    Returns:
        dict[str, :obj:`torch.optim.Optimizer`] | \
            :obj:`torch.optim.Optimizer`: The built optimizer(s).

    Raises:
        AttributeError: In the multi-config mode, if ``model`` has no
            attribute named after one of the keys in ``cfgs``.
    """
    # Look through a wrapper to reach the underlying model, if there is one.
    target = getattr(model, 'module', model)

    # Any non-dict value means ``cfgs`` itself is the (single) optimizer
    # config rather than a collection of per-module configs.
    if any(not isinstance(value, dict) for value in cfgs.values()):
        return build_optimizer(target, cfgs)

    # One optimizer per named sub-module; each sub-config is copied so the
    # caller's dict object is not the one passed on to the builder.
    return {
        name: build_optimizer(getattr(target, name), sub_cfg.copy())
        for name, sub_cfg in cfgs.items()
    }
|