|
import logging |
|
import torch |
|
from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule |
|
|
|
|
|
def get_training_model_class(kind):
    """Map a trainer ``kind`` string to its training-module class.

    Args:
        kind: trainer identifier; only ``'default'`` is recognised here.

    Returns:
        ``DefaultInpaintingTrainingModule`` when ``kind == 'default'``.

    Raises:
        ValueError: for any other ``kind``.
    """
    # Reject unknown kinds up front; the happy path is the single known kind.
    if kind != 'default':
        raise ValueError(f'Unknown trainer module {kind}')
    return DefaultInpaintingTrainingModule
|
|
|
|
|
def make_training_model(config):
    """Instantiate the training module described by ``config``.

    ``config.training_model.kind`` selects the class (via
    ``get_training_model_class``); every other entry of
    ``config.training_model`` is forwarded as a keyword argument, together
    with a derived ``use_ddp`` flag.

    Args:
        config: experiment config exposing ``training_model`` and
            ``trainer.kwargs`` with attribute access (presumably an
            OmegaConf ``DictConfig`` -- confirm against callers).

    Returns:
        ``cls(config, **kwargs)`` for the selected class.

    Raises:
        ValueError: if ``kind`` is not a known trainer module.
    """
    kind = config.training_model.kind
    kwargs = dict(config.training_model)
    kwargs.pop('kind')  # consumed above; not a constructor argument

    # Older PyTorch Lightning selected DDP through ``accelerator='ddp'``;
    # newer releases use ``strategy='ddp'``. Accept either so ``use_ddp`` is
    # not silently False for newer configs. The membership test keeps configs
    # without a ``strategy`` key behaving exactly as before.
    trainer_kwargs = config.trainer.kwargs
    use_ddp = trainer_kwargs.get('accelerator', None) == 'ddp'
    if not use_ddp and 'strategy' in trainer_kwargs:
        use_ddp = trainer_kwargs['strategy'] == 'ddp'
    kwargs['use_ddp'] = use_ddp

    logging.info(f'Make training model {kind}')

    cls = get_training_model_class(kind)
    return cls(config, **kwargs)
|
|
|
|
|
def load_checkpoint(train_config, path, map_location='cuda', strict=True, weights_only=None):
    """Build a training model from ``train_config`` and restore a checkpoint into it.

    Args:
        train_config: config passed unchanged to ``make_training_model``.
        path: checkpoint location, passed to ``torch.load``.
        map_location: forwarded to ``torch.load`` (default ``'cuda'``).
        strict: forwarded to ``model.load_state_dict``.
        weights_only: forwarded to ``torch.load`` only when not ``None``.
            Left as ``None``, ``torch.load`` uses its own default. Recent
            PyTorch releases default to ``True``, which can reject full
            Lightning checkpoints, so pass ``False`` for trusted checkpoints
            in that case.

    Returns:
        The model with ``state['state_dict']`` loaded and
        ``on_load_checkpoint(state)`` called.
    """
    model: torch.nn.Module = make_training_model(train_config)

    load_kwargs = {'map_location': map_location}
    if weights_only is not None:
        # Only forward when requested: older torch.load has no such kwarg.
        load_kwargs['weights_only'] = weights_only
    # SECURITY: torch.load unpickles data; only load checkpoints from trusted
    # sources, especially with weights_only=False.
    state = torch.load(path, **load_kwargs)

    model.load_state_dict(state['state_dict'], strict=strict)
    # Let the module restore any extra state stored alongside the weights.
    model.on_load_checkpoint(state)
    return model
|
|