""" Model creation / weight loading / state_dict helpers |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import logging |
|
import os |
|
from collections import OrderedDict |
|
from typing import Any, Callable, Dict, Optional, Union |
|
|
|
import torch |
|
try: |
|
import safetensors.torch |
|
_has_safetensors = True |
|
except ImportError: |
|
_has_safetensors = False |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint'] |
|
|
|
|
|
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # 'clean' checkpoint by removing the 'module.' prefix that DataParallel / DistributedDataParallel
    # wrappers prepend to parameter names when a wrapped model is saved
    cleaned_state_dict = {}
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        cleaned_state_dict[name] = v
    return cleaned_state_dict
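
# Illustrative usage, kept as a comment so nothing executes at import time; the checkpoint path
# and model below are hypothetical, not part of this module:
#
#   checkpoint = torch.load('ddp_checkpoint.pth.tar', map_location='cpu')
#   model.load_state_dict(clean_state_dict(checkpoint['state_dict']))
#
# This is the usual fix when weights saved from a DataParallel / DistributedDataParallel wrapped
# model fail to load into an unwrapped model because every key carries a 'module.' prefix.
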
def load_state_dict(
        checkpoint_path: str,
        use_ema: bool = True,
        device: Union[str, torch.device] = 'cpu',
) -> Dict[str, Any]:
    if checkpoint_path and os.path.isfile(checkpoint_path):
        # check whether this is a safetensors file and load the weights accordingly
        if str(checkpoint_path).endswith(".safetensors"):
            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
            checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device)

        # prefer EMA weights when requested, otherwise fall back through common state dict keys
        state_dict_key = ''
        if isinstance(checkpoint, dict):
            if use_ema and checkpoint.get('state_dict_ema', None) is not None:
                state_dict_key = 'state_dict_ema'
            elif use_ema and checkpoint.get('model_ema', None) is not None:
                state_dict_key = 'model_ema'
            elif 'state_dict' in checkpoint:
                state_dict_key = 'state_dict'
            elif 'model' in checkpoint:
                state_dict_key = 'model'
        state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
        _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
        return state_dict
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()
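
# Illustrative usage (comment only; 'model_best.pth.tar' is a hypothetical path):
#
#   state_dict = load_state_dict('model_best.pth.tar', use_ema=True, device='cpu')
#   model.load_state_dict(state_dict)
#
# With use_ema=True the EMA weights ('state_dict_ema' or 'model_ema') are preferred when present;
# otherwise the function falls back to 'state_dict', 'model', or the raw checkpoint dict.
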
def load_checkpoint(
        model: torch.nn.Module,
        checkpoint_path: str,
        use_ema: bool = True,
        device: Union[str, torch.device] = 'cpu',
        strict: bool = True,
        remap: bool = False,
        filter_fn: Optional[Callable] = None,
):
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError('Model cannot load numpy checkpoint')
        return

    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
    if remap:
        state_dict = remap_state_dict(state_dict, model)
    elif filter_fn:
        state_dict = filter_fn(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
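
# Illustrative usage (comment only; the model and checkpoint path are hypothetical):
#
#   model = MyModel()  # any torch.nn.Module whose keys match the checkpoint
#   incompatible = load_checkpoint(model, 'checkpoint.pth.tar', use_ema=True, strict=False)
#
# With strict=False, load_state_dict returns the missing / unexpected keys instead of raising,
# which is useful when fine-tuning with a replaced classifier head.
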
def remap_state_dict(
        state_dict: Dict[str, Any],
        model: torch.nn.Module,
        allow_reshape: bool = True
):
    """ Remap a checkpoint by iterating over the state dicts in order, ignoring the original keys.

    This assumes the model and the originating state dict were created with parameters registered
    in the same order.
    """
    out_dict = {}
    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
        assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        if va.shape != vb.shape:
            if allow_reshape:
                vb = vb.reshape(va.shape)
            else:
                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        out_dict[ka] = vb
    return out_dict
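
# Illustrative usage (comment only): positional remapping is a last resort for checkpoints whose
# keys were renamed (e.g. after a module refactor) but whose parameter order and sizes still match.
#
#   state_dict = load_state_dict('renamed_keys.pth.tar')  # hypothetical path
#   model.load_state_dict(remap_state_dict(state_dict, model))
#
# The same behaviour is available via load_checkpoint(..., remap=True).
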
def resume_checkpoint(
        model: torch.nn.Module,
        checkpoint_path: str,
        optimizer: Optional[torch.optim.Optimizer] = None,
        loss_scaler: Any = None,
        log_info: bool = True,
):
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            if log_info:
                _logger.info('Restoring model state from checkpoint...')
            state_dict = clean_state_dict(checkpoint['state_dict'])
            model.load_state_dict(state_dict)

            if optimizer is not None and 'optimizer' in checkpoint:
                if log_info:
                    _logger.info('Restoring optimizer state from checkpoint...')
                optimizer.load_state_dict(checkpoint['optimizer'])

            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
                if log_info:
                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

            if 'epoch' in checkpoint:
                resume_epoch = checkpoint['epoch']
                if 'version' in checkpoint and checkpoint['version'] > 1:
                    resume_epoch += 1  # start at the next epoch, old checkpoints incremented epoch before save

            if log_info:
                _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
        else:
            model.load_state_dict(checkpoint)
            if log_info:
                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return resume_epoch
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()
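
# Illustrative usage in a training loop (comment only; the path and optimizer are hypothetical):
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   resume_epoch = resume_checkpoint(model, 'last.pth.tar', optimizer=optimizer)
#   start_epoch = resume_epoch if resume_epoch is not None else 0
#
# The loss_scaler argument is duck-typed: as the code above assumes, it only needs a
# state_dict_key attribute naming its entry in the checkpoint dict and a load_state_dict() method.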