# pcsr_carn/models/models.py
# Uploaded by 3587jjh ("Upload 10 files", commit 61522a1); original size 508 bytes.
import copy
models = dict()  # registry mapping a model name to its class; populated via @register


def register(name):
    """Build a class decorator that records the decorated class in ``models``.

    The class is stored under ``name`` (replacing any earlier entry with the
    same name) and returned unchanged, so decorating a class has no other
    effect on it.
    """
    def _add_to_registry(model_cls):
        models[name] = model_cls
        return model_cls

    return _add_to_registry
def make(model_spec, args=None, load_sd=False):
    """Instantiate a registered model from a spec mapping.

    Args:
        model_spec: Mapping with keys ``'name'`` (a key registered in
            ``models``), ``'args'`` (keyword arguments for the model class)
            and, when ``load_sd`` is true, ``'sd'`` (the state dict to load).
        args: Optional mapping of keyword arguments that override or extend
            ``model_spec['args']``.
        load_sd: If true, call ``model.load_state_dict(model_spec['sd'])`` on
            the new instance (presumably a ``torch.nn.Module`` -- confirm).

    Returns:
        The constructed model instance.

    Raises:
        KeyError: If ``model_spec['name']`` is not registered, or a required
            key (``'name'``, ``'args'``, or ``'sd'`` when loading) is missing.
    """
    name = model_spec['name']
    try:
        model_cls = models[name]
    except KeyError:
        raise KeyError(
            f"model {name!r} is not registered; available: {list(models)}"
        ) from None

    # Always work on a private deep copy so neither the overrides below nor
    # the model constructor can mutate (possibly nested) values in the
    # caller's spec; previously this copy only happened when `args` was given.
    model_args = copy.deepcopy(model_spec['args'])
    if args is not None:
        model_args.update(args)

    model = model_cls(**model_args)
    if load_sd:
        model.load_state_dict(model_spec['sd'])
    return model