Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn | |
class ModuleProxyWrapper(nn.Module):
    """
    Wrap a DistributedDataParallel module and forward requests for missing
    attributes to the module wrapped by DDP (the twice-wrapped module).
    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap; it must itself expose a
            ``module`` attribute (the twice-wrapped module)
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module"), \
            "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to the wrapped modules.

        Lookup order: ``nn.Module``'s own lookup on this wrapper, then the
        once-wrapped module (``self.module``), then the twice-wrapped module
        (``self.module.module``).

        Raises:
            AttributeError: if none of the three objects has *name*, or if
                the wrapped module has not been registered on this instance.
        """
        try:
            # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            # ``self.module`` is itself resolved through ``__getattr__``
            # (it is a registered submodule, not a plain instance attribute).
            # If that lookup is the one that failed, reading ``self.module``
            # below would recurse without bound, so surface the
            # AttributeError instead.
            if name == "module":
                raise

        once_wrapped = self.module
        try:
            # forward to the once-wrapped module
            return getattr(once_wrapped, name)
        except AttributeError:
            # forward to the twice-wrapped module
            return getattr(once_wrapped.module, name)

    def state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call the once-wrapped module with the given arguments."""
        return self.module(*args, **kwargs)