Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
from mmcv.runner import build_optimizer | |
from mmcv.utils import Registry | |
# Registry named 'optimizers'. Nothing in this function registers into it;
# presumably other modules register custom optimizer types here (confirm).
OPTIMIZERS = Registry(name='optimizers')
def build_optimizers(model, cfgs):
    """Build one optimizer, or several optimizers keyed by submodule name.

    If every value in ``cfgs`` is itself a dict, ``cfgs`` is treated as a
    mapping from attribute names of ``model`` to optimizer configs. A dict
    holding one optimizer per key is returned. Otherwise ``cfgs`` is treated
    as a single optimizer config, and the optimizer built from it is
    returned directly.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict(model1=torch.optim.Optimizer, model2=torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
            If it has a ``module`` attribute, that inner object is used
            instead. This is presumably meant to unwrap DataParallel-style
            wrappers; confirm.
        cfgs (dict): Either a single optimizer config, or a dict whose values
            are all optimizer configs, keyed by attribute name of ``model``.

    Returns:
        dict[str, :obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizer(s). An empty ``cfgs`` yields an empty
            dict, because ``all()`` over no values is ``True``.

    Raises:
        AttributeError: If a key of a multi-optimizer config does not name
            an attribute of ``model``.
    """
    # Operate on the wrapped model when a wrapper exposes it as ``.module``.
    if hasattr(model, 'module'):
        model = model.module

    # A multi-optimizer config is one whose values are all dicts. A single
    # optimizer config normally has non-dict values such as ``type='SGD'``.
    if all(isinstance(v, dict) for v in cfgs.values()):
        optimizers = {}
        for key, cfg in cfgs.items():
            # Shallow copy so the builder cannot alter the top-level keys of
            # the caller's config.
            cfg_ = cfg.copy()
            # No default here: a key that names no attribute raises.
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers

    return build_optimizer(model, cfgs)