import copy

# Registry mapping a model name to the class (or factory) registered under it.
models = {}


def register(name):
    """Return a class decorator that records the decorated class under *name*.

    The decorated class is stored in the module-level ``models`` registry and
    returned unchanged. Registering again under an existing name silently
    replaces the earlier entry.
    """
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def make(model_spec, args=None, load_sd=False):
    """Instantiate a registered model from a spec mapping.

    Args:
        model_spec: mapping with a ``'name'`` key naming a registered model,
            an optional ``'args'`` mapping of constructor keyword arguments
            (treated as empty when absent), and, when ``load_sd`` is true,
            an ``'sd'`` entry passed to ``load_state_dict``.
        args: optional mapping of keyword arguments that override/extend
            ``model_spec['args']``. The spec itself is never mutated.
        load_sd: when true, call ``model.load_state_dict(model_spec['sd'])``
            on the freshly constructed model.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``model_spec['name']`` is not a registered model, or a
            required key (``'name'``, or ``'sd'`` when ``load_sd``) is
            missing from ``model_spec``.
    """
    name = model_spec['name']
    try:
        model_cls = models[name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"unknown model {name!r}; registered models: {list(models)}"
        ) from None

    base_args = model_spec.get('args', {})
    if args is not None:
        # Deep-copy so merging the overrides never mutates the caller's spec.
        model_args = copy.deepcopy(base_args)
        model_args.update(args)
    else:
        model_args = base_args

    model = model_cls(**model_args)
    if load_sd:
        model.load_state_dict(model_spec['sd'])
    return model