|
|
|
import os.path as osp |
|
import platform |
|
import shutil |
|
|
|
import torch |
|
from torch.optim import Optimizer |
|
|
|
import mmcv |
|
from mmcv.runner import RUNNERS, EpochBasedRunner |
|
from .checkpoint import save_checkpoint |
|
|
|
# apex (NVIDIA's mixed-precision library) is an optional dependency: the AMP
# runner below only needs it when an 'amp' state is actually saved/resumed.
# Catch only ImportError — a bare `except:` would also swallow SystemExit and
# KeyboardInterrupt raised during import.
try:
    import apex
except ImportError:
    print('apex is not installed')
|
|
|
|
|
@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
    """Epoch-based Runner with (apex) AMP support.

    This runner trains models epoch by epoch. In addition to the regular
    checkpoint contents, ``resume`` restores the ``apex.amp`` state when the
    checkpoint contains one, so mixed-precision training can be resumed.
    """

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither ``None`` nor a dict.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        # Merge the runner-level meta first, then write epoch/iter LAST:
        # a stale 'epoch'/'iter' left in self.meta (e.g. from a resumed
        # checkpoint) must not overwrite the current training progress.
        if self.meta is not None:
            meta.update(self.meta)
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)

        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                # Windows needs elevated privileges for symlinks; copy the
                # checkpoint instead.
                shutil.copy(filepath, dst_file)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume training state (epoch, iter, optimizer, amp) from a file.

        Args:
            checkpoint (str): Path of the checkpoint to resume from.
            resume_optimizer (bool, optional): Whether to also restore the
                optimizer state. Defaults to True.
            map_location (str, optional): 'default' maps storages to the
                current CUDA device when available; any other value is passed
                straight to ``load_checkpoint``. Defaults to 'default'.

        Raises:
            TypeError: If ``self.optimizer`` is neither a dict nor a
                ``torch.optim.Optimizer`` while an optimizer state is resumed.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                # Map saved tensors onto this process's current GPU.
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                # One optimizer per key (e.g. multi-model setups): restore
                # each state dict under its matching key.
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        if 'amp' in checkpoint:
            # NOTE(review): if the module-level `import apex` failed, this
            # raises NameError on an AMP checkpoint — apex is required to
            # resume AMP state.
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
|
|