# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil

import torch
from torch.optim import Optimizer

import mmcv
from mmcv.runner import RUNNERS, IterBasedRunner

from .checkpoint import save_checkpoint

try:
    import apex
except ImportError:
    print('apex is not installed')


# register the runner so it can be selected from configs by its type name
@RUNNERS.register_module()
class IterBasedRunnerAmp(IterBasedRunner):
    """Iteration-based Runner with apex AMP support.

    This runner trains models iteration by iteration.
    """

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=False):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether to save the optimizer
                state. Defaults to True.
            create_symlink (bool, optional): Whether to create a symlink to
                the latest checkpoint file. Defaults to False.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments `os.symlink` is not supported, so the symlink
        # to the latest checkpoint is not created here; the original logic is
        # kept below for reference
        # if create_symlink:
        #     dst_file = osp.join(out_dir, 'latest.pth')
        #     if platform.system() != 'Windows':
        #         mmcv.symlink(filename, dst_file)
        #     else:
        #         shutil.copy(filepath, dst_file)
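
    # A hedged usage sketch (not from the original file): a checkpoint hook
    # attached to this runner typically ends up invoking something like
    #
    #     runner.save_checkpoint('work_dirs/my_exp', filename_tmpl='iter_{}.pth')
    #
    # which writes 'work_dirs/my_exp/iter_<iter + 1>.pth' holding the model
    # weights, the optimizer state (when save_optimizer=True) and the meta
    # dict; 'work_dirs/my_exp' is an illustrative path, not part of this repo.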

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model, optimizer and apex AMP state from a checkpoint.

        Args:
            checkpoint (str): Path of the checkpoint to resume from.
            resume_optimizer (bool, optional): Whether to resume the optimizer
                state. Defaults to True.
            map_location (str, optional): Device onto which the checkpoint is
                loaded. 'default' maps to the current CUDA device when
                available. Defaults to 'default'.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        # restore the apex AMP loss-scaling state so mixed-precision training
        # continues exactly where it left off
        if 'amp' in checkpoint:
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
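

# A minimal usage sketch (assumptions, not part of the original module): with the
# runner registered above, an mmcv-style config can select it by type name, e.g.
#
#     runner = dict(type='IterBasedRunnerAmp', max_iters=80000)
#
# and `mmcv.runner.build_runner` (in recent mmcv versions) then builds
# IterBasedRunnerAmp instead of the stock IterBasedRunner. Resuming a run,
# including the apex AMP loss-scale state, looks like
#
#     runner.resume('work_dirs/my_exp/iter_40000.pth')
#
# where the path, iteration count and `max_iters` value are illustrative only.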