Spaces:
Sleeping
Sleeping
import os, shutil | |
import torch | |
from tensorboardX import SummaryWriter | |
from config.options import * | |
import torch.distributed as dist | |
import time | |
""" ==================== Save ======================== """ | |
def make_path():
    """Build the run-specific directory name from the global options.

    The name joins the experiment tag, the save path, the batch size and
    the learning rate read from ``opts``, in that order.

    Returns:
        str: ``"<expri>_<savepath>_bs<batch_size>_lr<learn_rate>"``.
    """
    template = "{}_{}_bs{}_lr{}"
    fields = (opts.expri, opts.savepath, opts.batch_size, opts.learn_rate)
    return template.format(*fields)
def save_model(model, name):
    """Write ``model.state_dict()`` to the run's checkpoint directory.

    The file is stored at
    ``<config['checkpoint_base']>/<make_path()>/<name>``; the directory is
    created first when it does not exist yet.

    Args:
        model: Object exposing ``state_dict()`` (passed to ``torch.save``).
        name: File name of the checkpoint inside the run directory.
    """
    checkpoint_dir = os.path.join(config['checkpoint_base'], make_path())
    # exist_ok=True makes this a no-op when the directory is already there.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, name))
""" ==================== Tools ======================== """ | |
def is_dist_avail_and_initialized():
    """Tell whether ``torch.distributed`` can be used right now.

    Returns:
        bool: True only when the distributed package is available and the
        default process group has been initialized; False otherwise.
    """
    # Short-circuit: is_initialized() is only queried when the package exists.
    return bool(dist.is_available() and dist.is_initialized())
def get_rank():
    """Return the rank of the current process.

    Returns:
        int: ``dist.get_rank()`` when distributed mode is available and
        initialized, otherwise 0.
    """
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def makedir(path):
    """Create ``path`` (and any missing parents) if nothing exists there.

    Args:
        path: Directory to create. When something already exists at this
            location the call does nothing.
    """
    if os.path.exists(path):
        return
    # exist_ok=True closes the window between the check above and the
    # creation: if another process creates the directory in the meantime,
    # this no longer raises FileExistsError.
    os.makedirs(path, 0o777, exist_ok=True)
def visualizer():
    """Create the TensorBoard summary writer on the rank-0 process.

    The log directory is ``<config['visual_base']>/<make_path()>``. When
    ``opts.clear_visualizer`` is set and the directory already exists, it is
    removed first so old summaries do not overlap with the new ones.

    Returns:
        A ``SummaryWriter`` on rank 0; ``None`` on every other rank.
    """
    # NOTE(review): the writer is taken to be returned from inside the
    # rank-0 branch, so other ranks receive None — confirm.
    if get_rank() != 0:
        return None

    log_dir = os.path.join(config['visual_base'], make_path())
    stale_logs_present = opts.clear_visualizer and os.path.exists(log_dir)
    if stale_logs_present:
        # Remove the previous summaries so they do not overlap.
        shutil.rmtree(log_dir)
    makedir(log_dir)
    return SummaryWriter(log_dir, comment='visualizer')