# CharacterGAN / netdissect / statedict.py
# mfrashad's picture
# Init code
# 8f87579
'''
Utilities for dealing with simple state dicts as npz files instead of pth files.
'''
import torch
from collections.abc import MutableMapping, Mapping
def load_from_numpy_dict(model, numpy_dict, prefix='', examples=None):
    '''
    Loads a model from numpy_dict using load_state_dict.
    Converts numpy types to torch types using the current state_dict
    of the model to determine types and devices for the tensors.
    Supports loading a subdict by prepending the given prefix to all keys.

    Args:
        model: object providing `state_dict()` and `load_state_dict()`
            (presumably a torch.nn.Module).
        numpy_dict: mapping from parameter names to numpy values
            (e.g. the contents of an npz file).
        prefix: if non-empty, only the keys of numpy_dict that start with
            this prefix are loaded, with the prefix stripped.  A trailing
            '.' is appended when it is missing.
        examples: optional mapping of example values whose types, dtypes
            and devices the loaded values are converted to match.
            Defaults to `model.state_dict()`.
    '''
    if prefix:
        if not prefix.endswith('.'):
            prefix += '.'
        numpy_dict = PrefixSubDict(numpy_dict, prefix)
    if examples is None:
        # Use the model's own current tensors as the type/device templates.
        examples = model.state_dict()
    torch_state_dict = TorchTypeMatchingDict(numpy_dict, examples)
    model.load_state_dict(torch_state_dict)
def save_to_numpy_dict(model, numpy_dict, prefix=''):
    '''
    Saves a model by copying tensors to numpy_dict.
    Converts torch types to numpy types using `t.detach().cpu().numpy()`.
    Supports saving a subdict by prepending the given prefix to all keys.

    Args:
        model: object providing `state_dict()` (presumably a
            torch.nn.Module).
        numpy_dict: mutable mapping that receives one entry per
            state_dict key; it is modified in place.
        prefix: if non-empty, it is prepended to every key written.
            A trailing '.' is appended when it is missing.
    '''
    if prefix and not prefix.endswith('.'):
        prefix += '.'
    # The model's parameters and buffers come from state_dict().
    for name, value in model.state_dict().items():
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        # Non-tensor entries are stored unchanged.
        numpy_dict[prefix + name] = value
class TorchTypeMatchingDict(Mapping):
    '''
    Provides a view of a dict of numpy values as torch tensors, where the
    types are converted to match the types and devices in the given
    dict of examples.

    Values are converted lazily on first access and the converted value
    is cached, so repeated lookups of the same key return the same object.
    Keys that have no entry in `examples` are returned unconverted (and
    are not cached).
    '''

    def __init__(self, data, examples):
        # data: mapping of key -> raw (numpy) value.
        # examples: mapping of key -> value whose type/dtype/device to match.
        self.data = data
        self.examples = examples
        self.cached_data = {}

    def __getitem__(self, key):
        if key in self.cached_data:
            return self.cached_data[key]
        val = self.data[key]
        if key not in self.examples:
            # Nothing to match against: hand back the raw value.
            return val
        example = self.examples[key]
        if isinstance(example, torch.Tensor):
            # Only wrap values that are not tensors yet; calling
            # torch.from_numpy on an existing tensor would raise.
            if not isinstance(val, torch.Tensor):
                val = torch.from_numpy(val)
            val = val.to(dtype=example.dtype, device=example.device)
        elif example is not None and type(val) is not type(example):
            # Non-tensor example: coerce with the example's constructor.
            val = type(example)(val)
        self.cached_data[key] = val
        return val

    def __iter__(self):
        # Must return an iterator; a keys() view is only an iterable and
        # makes iter(), list(), dict() and items() raise TypeError.
        return iter(self.data)

    def __len__(self):
        return len(self.data)
class PrefixSubDict(MutableMapping):
    '''
    Provides a view of the subset of a dict where string keys begin with
    the given prefix. The prefix is stripped from all keys of the view.

    Reads, writes and deletes go straight through to the underlying dict
    with the prefix re-attached. Non-string keys of the underlying dict
    are ignored when enumerating the view.

    Note: the list of matching keys is cached after the first iteration
    or len() call and is only invalidated by writes and deletes made
    through this view; changes made directly to the underlying dict
    afterwards are not reflected in iteration or len().
    '''

    def __init__(self, data, prefix=''):
        self.data = data
        self.prefix = prefix
        # Lazily built list of stripped keys; None means "not computed".
        self._cached_keys = None

    def __getitem__(self, key):
        return self.data[self.prefix + key]

    def __setitem__(self, key, value):
        pkey = self.prefix + key
        # Adding a brand-new key changes the key listing: drop the cache.
        if self._cached_keys is not None and pkey not in self.data:
            self._cached_keys = None
        self.data[pkey] = value

    def __delitem__(self, key):
        pkey = self.prefix + key
        # Removing an existing key changes the key listing: drop the cache.
        if self._cached_keys is not None and pkey in self.data:
            self._cached_keys = None
        del self.data[pkey]

    def __cached_keys(self):
        # Returns the stripped keys, computing them on first use.
        if self._cached_keys is None:
            plen = len(self.prefix)
            # Skip non-string keys instead of failing on k.startswith.
            self._cached_keys = [
                k[plen:] for k in self.data
                if isinstance(k, str) and k.startswith(self.prefix)]
        return self._cached_keys

    def __iter__(self):
        return iter(self.__cached_keys())

    def __len__(self):
        return len(self.__cached_keys())