# Residue of the hosting page's file header, kept as comments so the module parses:
# ljy266987
# add lfs
# 12bfd03
# raw
# history blame
# 6.54 kB
import glob
import json
import os
import random
import sys
import time
import warnings
import matplotlib
import numpy as np
import torch
import yaml
from torch import distributed as dist
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
import re
import pathlib
def seed_everything(seed, cudnn_deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed for the global random state. ``None`` leaves every
            generator untouched.
        cudnn_deterministic: accepted for API compatibility but currently
            ignored; nothing in this function touches cuDNN settings.
    """
    if seed is None:
        return
    # Same order as before: stdlib, NumPy, torch CPU, then every CUDA device.
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
def is_primary():
    """Return True when this process is rank 0 (always True when not distributed)."""
    rank = get_rank()
    return rank == 0
def get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when ``torch.distributed`` is unavailable in this build or
    the process group has not been initialized.
    """
    # ``and`` short-circuits, so is_initialized() is only queried when the
    # distributed package is available.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
def load_yaml_config(path):
    """Read the YAML file at ``path`` and return the parsed object.

    Parsing uses ``yaml.full_load``; PyYAML documents that loader as not meant
    for untrusted input, so only point this at trusted config files.
    """
    with open(path) as stream:
        return yaml.full_load(stream)
def save_config_to_yaml(config, path):
    """Dump ``config`` as YAML to ``path``, overwriting any existing file.

    Args:
        config: object serializable by ``yaml.dump``.
        path: destination path; must end with ``.yaml``.

    Raises:
        ValueError: if ``path`` does not end with ``.yaml``.
    """
    # Explicit check rather than ``assert`` so the validation still runs
    # under ``python -O``.
    if not path.endswith('.yaml'):
        raise ValueError(
            'config path must end with .yaml, got {!r}'.format(path))
    # The ``with`` block closes the file; no separate close() is needed.
    with open(path, 'w') as f:
        yaml.dump(config, f)
def save_dict_to_json(d, path, indent=None):
    """Serialize ``d`` as JSON into ``path``, overwriting the file.

    Args:
        d: JSON-serializable object.
        path: destination file path.
        indent: forwarded to ``json.dump``; ``None`` writes compact output.
    """
    # Closing via ``with`` guarantees the data is flushed. The previous inline
    # ``open(...)`` left closing to the garbage collector.
    with open(path, 'w') as f:
        json.dump(d, f, indent=indent)
def load_dict_from_json(path):
    """Parse the JSON file at ``path`` and return the resulting object.

    The file handle is closed deterministically via ``with`` instead of being
    left open for the garbage collector.
    """
    with open(path, 'r') as f:
        return json.load(f)
def write_args(args, path):
    """Append environment info and every public attribute of ``args`` to ``path``.

    The file is opened in append mode. It receives the torch and cuDNN
    versions, the raw ``sys.argv`` list, then one `` name: value`` line per
    attribute of ``args`` whose name does not start with ``_``, sorted by name.
    """
    public_names = [name for name in dir(args) if not name.startswith('_')]
    args_dict = {name: getattr(args, name) for name in public_names}
    with open(path, 'a') as args_file:
        lines = [
            '==> torch version: {}\n'.format(torch.__version__),
            '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()),
            '==> Cmd:\n',
            str(sys.argv),
            '\n==> args:\n',
        ]
        lines.extend(' %s: %s\n' % (str(key), str(value))
                     for key, value in sorted(args_dict.items()))
        args_file.writelines(lines)
class Logger(object):
    """Rank-aware experiment logger.

    On the primary process (rank 0) it:
    - creates ``<save_dir>/configs`` and ``<save_dir>/logs``;
    - records ``args`` into ``configs/args.txt`` via ``write_args``;
    - appends timestamped messages to ``logs/log.txt``;
    - when ``args.tensorboard`` is truthy, forwards the ``add_*`` calls to a
      TensorBoard ``SummaryWriter`` writing into ``logs/``.

    On every other rank, file and TensorBoard output is skipped.

    Args:
        args: object exposing at least ``save_dir`` and ``tensorboard``
            (presumably an ``argparse.Namespace``; confirm with callers).
    """

    def __init__(self, args):
        self.args = args
        self.save_dir = args.save_dir
        self.is_primary = is_primary()
        # Defined on every rank so attribute access never raises; only the
        # primary rank replaces these with real writers below.
        self.text_writer = None
        self.tb_writer = None
        if not self.is_primary:
            return

        os.makedirs(self.save_dir, exist_ok=True)

        # Save the args the run was launched with.
        self.config_dir = os.path.join(self.save_dir, 'configs')
        os.makedirs(self.config_dir, exist_ok=True)
        write_args(args, os.path.join(self.config_dir, 'args.txt'))

        log_dir = os.path.join(self.save_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        # Append mode: messages from earlier runs in the same save_dir are kept.
        self.text_writer = open(os.path.join(log_dir, 'log.txt'), 'a')

        if args.tensorboard:
            self.log_info('using tensorboard')
            # `import torch` is not guaranteed to load the optional
            # torch.utils.tensorboard submodule, so import it explicitly, and
            # only when requested (it needs the tensorboard package installed).
            from torch.utils.tensorboard import SummaryWriter
            self.tb_writer = SummaryWriter(log_dir=log_dir)

    def save_config(self, config):
        """Write ``config`` to ``configs/config.yaml`` (primary rank only)."""
        if self.is_primary:
            save_config_to_yaml(config,
                                os.path.join(self.config_dir, 'config.yaml'))

    def log_info(self, info, check_primary=True):
        """Print ``info`` and, on the primary rank, append it to ``log.txt``.

        Args:
            info: message; converted with ``str`` before writing to the file,
                prefixed with a ``%Y-%m-%d-%H-%M`` timestamp.
            check_primary: when False, non-primary ranks also print ``info``
                (they still never write to the log file).
        """
        if self.is_primary or (not check_primary):
            print(info)
        if self.is_primary:
            info = str(info)
            time_str = time.strftime('%Y-%m-%d-%H-%M')
            info = '{}: {}'.format(time_str, info)
            if not info.endswith('\n'):
                info += '\n'
            self.text_writer.write(info)
            # Flush per message so the log is readable while training runs.
            self.text_writer.flush()

    def add_scalar(self, **kwargs):
        """Log a single scalar via ``SummaryWriter.add_scalar`` (no-op without TensorBoard)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_scalar(**kwargs)

    def add_scalars(self, **kwargs):
        """Log a group of scalars via ``SummaryWriter.add_scalars`` (no-op without TensorBoard)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_scalars(**kwargs)

    def add_image(self, **kwargs):
        """Log one image via ``SummaryWriter.add_image`` (no-op without TensorBoard)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_image(**kwargs)

    def add_images(self, **kwargs):
        """Log a batch of images via ``SummaryWriter.add_images`` (no-op without TensorBoard)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_images(**kwargs)

    def close(self):
        """Close the log file and, if one was created, the TensorBoard writer."""
        if self.is_primary:
            self.text_writer.close()
            # tb_writer is None when tensorboard was not requested; closing it
            # unconditionally used to raise AttributeError.
            if self.tb_writer is not None:
                self.tb_writer.close()
def plot_spectrogram(spectrogram):
    """Render ``spectrogram`` as a 10x2-inch heatmap with a colour bar.

    The input is passed straight to ``imshow`` (presumably a 2-D
    frequency-by-time array; confirm with callers), with row 0 drawn at the
    bottom. The figure is drawn, then closed in pyplot so repeated calls do not
    accumulate open figures. The Figure object itself is returned to the
    caller.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation='none',
    )
    plt.colorbar(image, ax=axes)
    figure.canvas.draw()
    plt.close()
    return figure
def init_weights(m, mean=0.0, std=0.01):
    """Resample ``m.weight`` from N(mean, std) in place if ``m`` is a conv layer.

    "Conv layer" means any module whose class name contains ``"Conv"``. Other
    modules and all biases are left untouched. The single-module signature
    makes it usable as ``model.apply(init_weights)``.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
    """Wrap ``m`` with ``torch.nn.utils.weight_norm`` if its class name contains "Conv".

    Non-conv modules are left as they are; suitable for ``model.apply``.
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        weight_norm(m)
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size - 1) * dilation / 2`` truncated to an int.

    For an odd kernel size this is the symmetric padding that keeps the
    sequence length unchanged in a stride-1 convolution.
    """
    # Distance covered by the dilated kernel beyond its first tap.
    receptive_extent = kernel_size * dilation - dilation
    return int(receptive_extent / 2)
def load_checkpoint(filepath, device):
    """Load and return the object stored at ``filepath`` with ``torch.load``.

    Tensors are mapped onto ``device`` via ``map_location``. Progress is
    printed before and after loading.

    Raises:
        AssertionError: if ``filepath`` is not an existing regular file
            (the check disappears under ``python -O``).
    """
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded
def save_checkpoint(filepath, obj, num_ckpt_keep=5):
    """Save ``obj`` to ``filepath`` with ``torch.save`` and prune old siblings.

    Pruning applies only when the file name starts with ``g_<digits>`` or
    ``do_<digits>``. Other files in the same directory that share that prefix
    (``g_*`` / ``do_*``) are sorted lexicographically; the lowest-sorting ones
    are deleted so that, once ``filepath`` is written, at most
    ``num_ckpt_keep`` such files remain. Step numbers should be zero-padded so
    that lexicographic order matches training order. Pruning happens before
    the save, so disk space is freed first.

    Args:
        filepath: destination path for the checkpoint.
        obj: object passed to ``torch.save``.
        num_ckpt_keep: how many same-prefix checkpoints to retain, counting
            the new one. A value <= 0 disables pruning.
    """
    path = pathlib.Path(filepath)
    match = re.match(r'(do|g)_\d+', path.name)
    # Names that are not g_N / do_N used to crash on ``None.group``; they are
    # now saved without pruning.
    if match is not None and num_ckpt_keep > 0:
        prefix = match.group(1)
        # Exclude the target itself so an overwrite of the same step is not
        # counted as an "old" checkpoint.
        others = sorted(p for p in path.parent.glob(f'{prefix}_*')
                        if p.name != path.name)
        # Keep num_ckpt_keep - 1 existing files to leave room for the new one.
        # The old code kept num_ckpt_keep, ending with one extra on disk.
        n_stale = max(len(others) - (num_ckpt_keep - 1), 0)
        for stale in others[:n_stale]:
            os.remove(stale)
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")
def scan_checkpoint(cp_dir, prefix):
    """Return the latest checkpoint path in ``cp_dir`` for ``prefix``, or ``None``.

    Candidates are files named ``prefix`` followed by exactly eight characters.
    "Latest" means the lexicographically greatest path, which matches the
    newest step when step numbers are zero-padded to eight digits.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '????????'))
    return max(candidates) if candidates else None