import logging | |
import torch | |
from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule | |
def get_training_model_class(kind):
    """Return the training-module class registered under ``kind``.

    Args:
        kind: Name of the trainer module. Only ``'default'`` is recognised.

    Returns:
        ``DefaultInpaintingTrainingModule`` when ``kind == 'default'``.

    Raises:
        ValueError: If ``kind`` is anything other than ``'default'``.
    """
    # Guard clause: reject unknown kinds first so the happy path reads flat.
    if kind != 'default':
        raise ValueError(f'Unknown trainer module {kind}')
    return DefaultInpaintingTrainingModule
def make_training_model(config):
    """Build the training module described by ``config.training_model``.

    Every entry of ``config.training_model`` except ``'kind'`` is forwarded
    to the module constructor as a keyword argument, together with a
    ``use_ddp`` flag.

    Args:
        config: Configuration object exposing ``training_model`` (with a
            ``kind`` entry, convertible via ``dict(...)``) and
            ``trainer.kwargs`` (supporting ``.get``).
            NOTE(review): presumably an OmegaConf/Hydra config -- confirm.

    Returns:
        An instance of the class selected by ``get_training_model_class``,
        constructed as ``cls(config, **kwargs)``.

    Raises:
        ValueError: Propagated from ``get_training_model_class`` for an
            unknown ``kind``.
    """
    model_cfg = config.training_model
    kind = model_cfg.kind

    # Copy into a plain dict so dropping 'kind' does not touch the config.
    ctor_kwargs = dict(model_cfg)
    del ctor_kwargs['kind']

    # use_ddp is True only when the trainer's accelerator is exactly 'ddp'.
    accelerator = config.trainer.kwargs.get('accelerator', None)
    ctor_kwargs['use_ddp'] = accelerator == 'ddp'

    logging.info(f'Make training model {kind}')
    model_cls = get_training_model_class(kind)
    return model_cls(config, **ctor_kwargs)
def load_checkpoint(train_config, path, map_location='cuda', strict=True):
    """Create a training model and restore its weights from a checkpoint.

    Args:
        train_config: Configuration passed to ``make_training_model``.
        path: Location of the checkpoint, handed to ``torch.load``.
        map_location: Forwarded to ``torch.load``; defaults to ``'cuda'``,
            so pass e.g. ``'cpu'`` explicitly when that is what you need.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The model after ``load_state_dict`` and ``on_load_checkpoint`` have
        been applied.

    Note:
        ``torch.load`` relies on pickle for deserialisation; only load
        checkpoint files from a trusted source.
    """
    net: torch.nn.Module = make_training_model(train_config)
    checkpoint = torch.load(path, map_location=map_location)

    # Weights live under the 'state_dict' key of the checkpoint.
    weights = checkpoint['state_dict']
    net.load_state_dict(weights, strict=strict)

    # Pass the full checkpoint to the model's own hook.
    # NOTE(review): presumably the PyTorch Lightning hook -- confirm.
    net.on_load_checkpoint(checkpoint)
    return net