|
import datetime |
|
import glob |
|
import os |
|
import shutil |
|
import subprocess |
|
import sys |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
|
|
def get_git_branch():
    """Return the name of the currently checked-out git branch.

    Runs ``git branch`` and picks the line marked with a leading ``*``,
    returning it without that marker.

    Returns:
        str: the current branch name, or ``"inside_docker"`` when the git
        command fails, the git executable cannot be found, or the output
        contains no starred line.
    """
    fallback = "inside_docker"
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: git exited non-zero (e.g. not a repository).
        # FileNotFoundError: no git executable on PATH.
        return fallback
    for line in out.split("\n"):
        if line.startswith("*"):
            # str is immutable: the stripped name must be returned, not
            # computed and thrown away.
            return line[1:].strip()
    # No starred line in the output (nothing marked as current).
    return fallback
|
|
|
|
|
def get_commit_hash():
    """Return the abbreviated hash of the current git commit.

    See https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    The hash is also printed to stdout.

    Returns:
        str: stripped output of ``git rev-parse --short HEAD``, or the
        placeholder ``"0000000"`` when the command fails or the git
        executable cannot be found.
    """
    try:
        commit = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: git exited non-zero (e.g. no repository here).
        # FileNotFoundError: git is not installed / not on PATH.
        commit = "0000000"
    print(' > Git Hash: {}'.format(commit))
    return commit
|
|
|
|
|
def create_experiment_folder(root_path, model_name, debug):
    """Create and return an output folder named after the model and the time.

    The folder name is ``<model_name>-<timestamp>-<suffix>`` where the
    timestamp is formatted as ``%B-%d-%Y_%I+%M%p`` and the suffix is the
    literal ``debug`` when ``debug`` is truthy, otherwise the value returned
    by ``get_commit_hash()``.

    Args:
        root_path: directory under which the folder is created.
        model_name: prefix of the folder name.
        debug: when truthy, skip the commit-hash lookup.

    Returns:
        str: path of the created (or already existing) folder.
    """
    timestamp = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    suffix = 'debug' if debug else get_commit_hash()
    folder_name = '-'.join([model_name, timestamp, suffix])
    experiment_dir = os.path.join(root_path, folder_name)
    # exist_ok: reuse the folder if one with the same name already exists
    os.makedirs(experiment_dir, exist_ok=True)
    print(" > Experiment folder: {}".format(experiment_dir))
    return experiment_dir
|
|
|
|
|
def remove_experiment_folder(experiment_path):
    """Delete an experiment folder unless it holds a ``*.pth.tar`` checkpoint.

    Args:
        experiment_path (str): path of the experiment folder.

    A folder with at least one ``*.pth.tar`` file directly inside it is kept.
    Otherwise the folder is removed if it exists; removal errors are ignored.
    A message is printed in the kept case and in the removed case.
    """
    has_checkpoint = bool(glob.glob(experiment_path + "/*.pth.tar"))
    if has_checkpoint:
        print(" ! Run is kept in {}".format(experiment_path))
        return
    if not os.path.exists(experiment_path):
        # Nothing to remove and nothing to report.
        return
    shutil.rmtree(experiment_path, ignore_errors=True)
    print(" ! Run is removed from {}".format(experiment_path))
|
|
|
|
|
def count_parameters(model):
    r"""Return the total element count of the model's trainable parameters.

    Only parameters whose ``requires_grad`` is truthy are counted; each
    contributes ``numel()`` elements.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
|
|
|
def get_user_data_dir(appname):
    """Return the per-user data directory for ``appname``.

    The base directory depends on ``sys.platform``:

    * ``win32``: the "Local AppData" value read from the current user's
      "Shell Folders" registry key.
    * ``darwin``: ``~/Library/Application Support/``.
    * anything else: ``~/.local/share``.

    Args:
        appname (str): name of the sub-directory appended to the base.

    Returns:
        pathlib.Path: base directory joined with ``appname``. The directory
        is not created.
    """
    if sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel
        # The context manager closes the registry handle once the value has
        # been read instead of leaving it open until garbage collection.
        with winreg.OpenKey(
                winreg.HKEY_CURRENT_USER,
                r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
        ) as key:
            dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == 'darwin':
        ans = Path('~/Library/Application Support/').expanduser()
    else:
        ans = Path.home().joinpath('.local/share')
    return ans.joinpath(appname)
|
|
|
|
|
def set_init_dict(model_dict, checkpoint_state, c):
    """Copy compatible checkpoint entries into ``model_dict``.

    A checkpoint entry is restored when all of the following hold:

    * its key exists in ``model_dict`` (missing keys are reported on stdout),
    * its ``numel()`` equals that of the entry already in ``model_dict``,
    * its key contains none of the substrings in ``c.reinit_layers``
      (skipped when ``c.reinit_layers`` is ``None``).

    Args:
        model_dict: mapping updated in place with the restored entries.
        checkpoint_state: mapping of candidate entries to restore.
        c: object exposing a ``reinit_layers`` attribute.

    Returns:
        The same ``model_dict`` object, after the update.
    """
    # Report checkpoint layers the model does not define.
    for layer_name in checkpoint_state:
        if layer_name not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(layer_name))

    # Keep only entries known to the model with a matching element count.
    restorable = {}
    for layer_name, weights in checkpoint_state.items():
        if layer_name not in model_dict:
            continue
        if weights.numel() == model_dict[layer_name].numel():
            restorable[layer_name] = weights

    # Drop the layers that were explicitly requested to be re-initialized.
    if c.reinit_layers is not None:
        for pattern in c.reinit_layers:
            for layer_name in [key for key in restorable if pattern in key]:
                del restorable[layer_name]

    model_dict.update(restorable)
    print(" | > {} / {} layers are restored.".format(len(restorable),
                                                     len(model_dict)))
    return model_dict
|
|
|
|
|
class KeepAverage():
    """Keep running averages of named scalar values.

    Each name has a current average in ``avg_values`` and a sample counter
    in ``iters``. Averages are read with ``keeper[name]`` or ``items()``.
    """

    def __init__(self):
        # name -> current average
        self.avg_values = {}
        # name -> number of samples already folded into the average
        self.iters = {}

    def __getitem__(self, key):
        """Return the current average stored under ``key``."""
        return self.avg_values[key]

    def items(self):
        """Return the ``(name, average)`` view of all tracked values."""
        return self.avg_values.items()

    def add_value(self, name, init_val=0, init_iter=0):
        """Register (or reset) ``name`` with a starting average and counter.

        With the default ``init_iter=0`` the starting value is a placeholder:
        it carries no weight in the arithmetic mean computed later.
        """
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the average stored under ``name``.

        Args:
            name: key of the tracked value; registered on first use.
            value: new sample.
            weighted_avg: when True, use the exponential update
                ``0.99 * avg + 0.01 * value``; otherwise keep the
                arithmetic mean of the samples seen so far.
        """
        if name not in self.avg_values:
            # The first sample is a real observation, so it counts as one
            # iteration; starting at 0 would erase it on the next update.
            self.add_value(name, init_val=value, init_iter=1)
            return

        if weighted_avg:
            self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
            self.iters[name] += 1
        else:
            # Rebuild the running sum, add the sample, divide by the new count.
            running_sum = self.avg_values[name] * self.iters[name] + value
            self.iters[name] += 1
            self.avg_values[name] = running_sum / self.iters[name]

    def add_values(self, name_dict):
        """Register every ``name -> initial value`` pair of ``name_dict``."""
        for key, value in name_dict.items():
            self.add_value(key, init_val=value)

    def update_values(self, value_dict):
        """Call ``update_value`` for every ``name -> sample`` pair."""
        for key, value in value_dict.items():
            self.update_value(key, value)
|
|
|
|
|
def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
    """Validate the config entry ``c[name]`` against the given constraints.

    Args:
        name: key to validate.
        c: mapping-like config supporting ``keys()`` and item access.
        enum_list: allowed values; ``c[name].lower()`` must be one of them.
        max_val: inclusive upper bound. A bound of 0 is enforced.
        min_val: inclusive lower bound. A bound of 0 is enforced.
        restricted: when True, ``name`` must be present in ``c``.
        val_type: a type, or a list of types, the value must be an instance
            of. A ``None`` value is always accepted by this check.
        alternative: key that, when present in ``c`` with a non-``None``
            value, disables every check for ``name``.

    Raises:
        AssertionError: when a constraint is violated. Raised explicitly so
            validation stays active under ``python -O``.
    """
    if alternative in c.keys() and c[alternative] is not None:
        return
    if restricted and name not in c.keys():
        raise AssertionError(f' [!] {name} not defined in config.json')
    if name not in c.keys():
        return

    value = c[name]

    # Bounds are compared against None, not truthiness, so 0 is a real bound.
    # A None value is skipped here, consistent with the type check below.
    if value is not None:
        if max_val is not None and not value <= max_val:
            raise AssertionError(f' [!] {name} is larger than max value {max_val}')
        if min_val is not None and not value >= min_val:
            raise AssertionError(f' [!] {name} is smaller than min value {min_val}')

    if enum_list and value.lower() not in enum_list:
        raise AssertionError(f' [!] {name} is not a valid value')

    if isinstance(val_type, list):
        # isinstance accepts a tuple of candidate types.
        type_ok = isinstance(value, tuple(val_type))
    elif val_type:
        type_ok = isinstance(value, val_type)
    else:
        type_ok = True
    if not type_ok and value is not None:
        raise AssertionError(f' [!] {name} has wrong type - {type(value)} vs {val_type}')
|
|