|
|
|
from mmcv.runner import build_optimizer |
|
from mmcv.utils import Registry |
|
|
|
# Module-level registry created under the name 'optimizers'. Nothing in the
# visible code registers into it or reads from it; presumably it is meant for
# registering optimizer classes elsewhere in the project -- TODO confirm.
OPTIMIZERS = Registry('optimizers') |
|
|
|
|
|
def build_optimizers(model, cfgs):
    """Create either one optimizer or a keyed set of optimizers from config.

    The shape of ``cfgs`` selects the mode:

    * Every value in ``cfgs`` is itself a dict: each key is treated as the
      name of an attribute on ``model``, and one optimizer is built for that
      attribute from the matching sub-config. The result is a dict that maps
      the same keys to the built optimizers.

      .. code-block:: python

          optimizer_cfg = dict(
              model1=dict(type='SGD', lr=lr),
              model2=dict(type='SGD', lr=lr))

      yields ``{'model1': <optimizer>, 'model2': <optimizer>}``.

    * Otherwise ``cfgs`` is taken as a single optimizer config for the whole
      model and the built optimizer is returned directly.

      .. code-block:: python

          optimizer_cfg = dict(type='SGD', lr=lr)

    If ``model`` exposes a ``module`` attribute, that inner object is used in
    place of ``model`` for both modes.

    Args:
        model (:obj:`nn.Module`): The model whose parameters are optimized.
        cfgs (dict): A single optimizer config, or a dict of such configs
            keyed by attribute name on the model.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The constructed optimizer(s). An empty ``cfgs`` gives an empty
            dict, because it trivially satisfies the "all values are dicts"
            check.
    """
    # Work on the inner object when the model exposes one as ``.module``.
    net = getattr(model, 'module', model)

    # Any non-dict value means ``cfgs`` describes one optimizer for the
    # whole model rather than a per-attribute collection.
    is_single_config = any(
        not isinstance(entry, dict) for entry in cfgs.values())
    if is_single_config:
        return build_optimizer(net, cfgs)

    built = {}
    for name, sub_cfg in cfgs.items():
        # Hand the builder a shallow copy rather than the caller's own dict.
        sub_cfg_copy = sub_cfg.copy()
        target = getattr(net, name)
        built[name] = build_optimizer(target, sub_cfg_copy)
    return built
|
|