Spaces:
Sleeping
Sleeping
File size: 1,577 Bytes
172a1e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import os, shutil
import torch
from tensorboardX import SummaryWriter
from config.options import *
import torch.distributed as dist
import time
""" ==================== Save ======================== """
def make_path():
    """Build the run directory name: <expri>_<savepath>_bs<batch_size>_lr<learn_rate>."""
    return f"{opts.expri}_{opts.savepath}_bs{opts.batch_size}_lr{opts.learn_rate}"
def save_model(model, name):
    """Save a model's state_dict under checkpoint_base/<run_dir>/<name>.

    Args:
        model: torch module whose ``state_dict()`` is serialized.
        name: checkpoint filename inside the run directory.
    """
    save_dir = os.path.join(config['checkpoint_base'], make_path())
    # exist_ok=True makes the original isdir() pre-check redundant and
    # removes the check-then-create race between processes.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, name))
""" ==================== Tools ======================== """
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
def get_rank():
    """Return the distributed rank of this process, or 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
def makedir(path):
    """Create *path* (and parents, mode 0o777) unless it already exists."""
    if os.path.exists(path):
        return
    os.makedirs(path, 0o777)
def visualizer():
    """Create a TensorBoard SummaryWriter for the main process.

    Returns:
        A ``SummaryWriter`` rooted at config['visual_base']/<run_dir> on
        rank 0, or ``None`` on every other rank.

    Bug fixed: the original only assigned ``writer`` inside the rank-0
    branch, so non-zero ranks hit ``return writer`` with the name unbound
    and raised ``UnboundLocalError``.
    """
    if get_rank() != 0:
        # Only the main process writes summaries.
        return None
    save_path = make_path()
    filewriter_path = os.path.join(config['visual_base'], save_path)
    # Delete summaries from previous runs so old curves don't overlap new ones.
    if opts.clear_visualizer and os.path.exists(filewriter_path):
        shutil.rmtree(filewriter_path)
    makedir(filewriter_path)
    return SummaryWriter(filewriter_path, comment='visualizer')
|