Spaces:
Runtime error
Runtime error
File size: 3,746 Bytes
51a61da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
import yaml
import json
import pickle
import torch
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under ``root_dir`` by file extension.

    Args:
        root_dir: Directory to walk with ``os.walk``.
        extensions: Iterable of extensions without the leading dot
            (e.g. ``['wav', 'flac']``).
        amount: If not ``None``, stop as soon as this many paths have been
            collected and another file with a matching extension is seen.
        str_include: If not ``None``, keep only paths containing this
            substring.
        str_exclude: If not ``None``, drop paths containing this substring.
        is_pure: If true, return paths relative to ``root_dir``; otherwise
            return ``os.path.join(root, file)`` as produced by the walk.
        is_sort: If true, sort the collected list before returning it.
        is_ext: If false, strip the trailing ``.ext`` from every path.

    Returns:
        A list of path strings. The include/exclude substrings are tested
        against the path in the form it is returned (relative when
        ``is_pure`` is true), before the extension is stripped.
    """
    # str.endswith accepts a tuple, so build the suffixes once up front
    # instead of re-creating a list for every file.
    suffixes = tuple(f".{ext}" for ext in extensions)
    file_list = []
    for root, _, files in os.walk(root_dir):
        if is_pure:
            # relpath is independent of how root_dir was spelled; slicing by
            # len(root_dir) + 1 cut off the first character of every result
            # when root_dir ended with a path separator.
            rel_root = os.path.relpath(root, root_dir)
            prefix = '' if rel_root == os.curdir else rel_root
        else:
            prefix = root
        for file in files:
            if not file.endswith(suffixes):
                continue
            pure_path = os.path.join(prefix, file)

            # amount: checked before the substring filters, so the walk
            # stops at the first extension match after the quota is full.
            if (amount is not None) and (len(file_list) == amount):
                if is_sort:
                    file_list.sort()
                return file_list

            # check string
            if (str_include is not None) and (str_include not in pure_path):
                continue
            if (str_exclude is not None) and (str_exclude in pure_path):
                continue

            if not is_ext:
                # The file name is known to end with ".<ext>", so the last
                # dot always belongs to the file name itself.
                pure_path = pure_path.rsplit('.', 1)[0]
            file_list.append(pure_path)
    if is_sort:
        file_list.sort()
    return file_list
class DotDict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    Reading a missing key through attribute access returns ``None`` instead
    of raising. A value that is a plain ``dict`` is returned wrapped in a new
    ``DotDict`` so nested keys can be chained (``cfg.model.dim``); the wrapper
    is a copy, so assigning through it does not modify the parent mapping.
    """

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup fails. Protocol hooks
        # such as __deepcopy__ or __getstate__ must then be reported as
        # absent; answering None makes copy/pickle try to call None.
        # Keys spelled like dunders remain reachable with d['__key__'].
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        val = dict.get(self, name)
        # Exact type check on purpose: only plain dicts are wrapped.
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def get_network_paras_amount(model_dict):
    """Count the trainable parameters of each model in ``model_dict``.

    Args:
        model_dict: Mapping from a model name to an object exposing
            ``parameters()``, whose items provide ``numel()`` and
            ``requires_grad``.

    Returns:
        A dict mapping each model name to the summed ``numel()`` of the
        parameters whose ``requires_grad`` is true.
    """
    return {
        model_name: sum(
            param.numel()
            for param in model.parameters()
            if param.requires_grad
        )
        for model_name, model in model_dict.items()
    }
def load_config(path_config):
    """Load a YAML configuration file into a ``DotDict``.

    Args:
        path_config: Path of the YAML file to read.

    Returns:
        A ``DotDict`` built from the document parsed by ``yaml.safe_load``.
    """
    # Decode explicitly as UTF-8: the platform default encoding would
    # mis-decode non-ASCII values on systems whose locale is not UTF-8.
    with open(path_config, "r", encoding="utf-8") as config:
        args = yaml.safe_load(config)
    return DotDict(args)
def save_config(path_config, config):
    """Write ``config`` to ``path_config`` as a YAML document.

    Args:
        path_config: Destination file path; the file is opened in write mode.
        config: Mapping to serialise. It is first converted with ``dict()``
            (a shallow conversion) and the result is passed to ``yaml.dump``.
    """
    # Convert before opening so a failing conversion leaves the file alone.
    plain_mapping = dict(config)
    with open(path_config, "w") as stream:
        yaml.dump(plain_mapping, stream)
def to_json(path_params, path_json):
    """Dump a saved parameter mapping to a JSON file of flat lists.

    Args:
        path_params: File readable by ``torch.load``; the loaded object is
            iterated with ``.items()`` and each value is flattened.
        path_json: Destination path of the JSON file.

    The JSON object maps every key to the flattened value as a list and is
    indented with tab characters.
    """
    loaded = torch.load(path_params, map_location=torch.device('cpu'))
    # Build the complete mapping before the output file is opened.
    flattened = {
        key: value.flatten().numpy().tolist()
        for key, value in loaded.items()
    }
    with open(path_json, 'w') as outfile:
        json.dump(flattened, outfile, indent="\t")
def convert_tensor_to_numpy(tensor, is_squeeze=True):
    """Convert a torch tensor to a NumPy array.

    Args:
        tensor: The tensor to convert.
        is_squeeze: If true, call ``squeeze()`` on the tensor before
            conversion.

    Returns:
        The result of ``numpy()`` on the detached, host-resident tensor.
    """
    if is_squeeze:
        tensor = tensor.squeeze()
    # detach() drops the autograd graph and cpu() returns the tensor itself
    # when it already lives on the CPU, so both can be applied always.
    # Testing only is_cuda left tensors on other accelerators (e.g. MPS)
    # on-device, and numpy() then failed for them.
    return tensor.detach().cpu().numpy()
def load_model(
        expdir,
        model,
        optimizer,
        name='model',
        postfix='',
        device='cpu'):
    """Restore ``model`` (and ``optimizer``) from a checkpoint in ``expdir``.

    Checkpoints are looked up as ``<expdir>/<name><postfix><suffix>.pt``.
    The file with the largest all-digit suffix is loaded; if matching files
    exist but none has a numeric suffix, ``...best.pt`` is tried instead.
    When no file matches the prefix nothing is loaded.

    Args:
        expdir: Directory searched (recursively) for ``.pt`` files.
        model: Object whose ``load_state_dict`` receives ``ckpt['model']``
            with ``strict=False``.
        optimizer: Object whose ``load_state_dict`` receives
            ``ckpt['optimizer']`` when that entry is present and not ``None``.
        name: Base name of the checkpoint files.
        postfix: Text appended to ``name``; an empty value becomes ``'_'``.
        device: Device passed to ``torch.load`` as ``map_location``.

    Returns:
        Tuple ``(global_step, model, optimizer)``; ``global_step`` is ``0``
        when no checkpoint was loaded, otherwise ``ckpt['global_step']``.

    Raises:
        FileNotFoundError: Propagated from ``torch.load`` when the selected
            checkpoint path does not exist (e.g. missing ``best.pt``).
    """
    # NOTE(review): the underscore is only inserted when postfix is empty; a
    # non-empty postfix is appended to name without a separator. Kept as is
    # because the intent cannot be confirmed from this file.
    if postfix == '':
        postfix = '_' + postfix
    path = os.path.join(expdir, name + postfix)

    # Keep only files that really start with the checkpoint prefix. Slicing
    # every path by len(path) let unrelated .pt files (other names, files in
    # sub-directories) contribute a bogus step number.
    suffixes = [
        s[len(path):]
        for s in traverse_dir(expdir, ['pt'], is_ext=False)
        if s.startswith(path)
    ]

    global_step = 0
    if suffixes:
        numeric = [s for s in suffixes if s.isdigit()]
        if numeric:
            # Reuse the suffix exactly as it appears on disk so names such
            # as "007" still resolve to an existing file.
            path_pt = path + max(numeric, key=int) + '.pt'
        else:
            # Previously unreachable: non-numeric suffixes were mapped to 0,
            # so a non-existent "...0.pt" was requested instead.
            path_pt = path + 'best.pt'
        print(' [*] restoring model from', path_pt)
        # torch.load may unpickle arbitrary objects; only point this at
        # checkpoints from a trusted source.
        ckpt = torch.load(path_pt, map_location=torch.device(device))
        global_step = ckpt['global_step']
        model.load_state_dict(ckpt['model'], strict=False)
        if ckpt.get('optimizer') is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
    return global_step, model, optimizer
|