"""Training helpers: wandb run setup, boolean parsing and a running-average meter."""
import os
import sys
from time import sleep

try:
    import wandb
except ImportError as _import_error:
    # Only init_wandb needs wandb; keep the other helpers importable and
    # surface the original failure when init_wandb is actually called.
    wandb = None
    _WANDB_IMPORT_ERROR = _import_error
else:
    _WANDB_IMPORT_ERROR = None

# Case-insensitive spellings recognised by str2bool.
_TRUE_STRINGS = frozenset(('yes', 'true', 't', 'y', '1'))
_FALSE_STRINGS = frozenset(('no', 'false', 'f', 'n', '0'))


def init_wandb(project_name, model_name, config, *, max_retries=None,
               retry_delay=1, **wandb_kwargs):
    """Start a wandb run, retrying when ``wandb.init`` raises.

    Args:
        project_name: Passed to ``wandb.init`` as ``project``.
        model_name: Passed to ``wandb.init`` as ``name``.
        config: Passed to ``wandb.init`` as ``config``.
        max_retries: Maximum number of retries after the first failed
            attempt. ``None`` (the default) retries indefinitely.
        retry_delay: Seconds to sleep between attempts.
        **wandb_kwargs: Extra keyword arguments forwarded to ``wandb.init``.

    Returns:
        The object returned by ``wandb.init``.

    Raises:
        ImportError: If wandb could not be imported.
        Exception: The last error from ``wandb.init`` once ``max_retries``
            retries have been used up.
    """
    if wandb is None:
        raise ImportError(
            'wandb is required by init_wandb but could not be imported'
        ) from _WANDB_IMPORT_ERROR

    # NOTE(review): presumably the number of seconds wandb waits for its
    # background service to start -- confirm against the wandb docs.
    os.environ['WANDB__SERVICE_WAIT'] = '300'

    failures = 0
    while True:
        try:
            return wandb.init(
                project=project_name,
                name=model_name,
                save_code=True,
                config=config,
                **wandb_kwargs,
            )
        except Exception as e:
            failures += 1
            print('wandb connection error', file=sys.stderr)
            print(f'error: {e}', file=sys.stderr)
            if max_retries is not None and failures > max_retries:
                # Give up instead of looping forever on a permanent error.
                raise
            sleep(retry_delay)
            print('retrying..', file=sys.stderr)


def str2bool(v):
    """Convert a textual flag to a ``bool``.

    Booleans are returned unchanged. Any other value is converted with
    ``str()``, stripped of surrounding whitespace and compared
    case-insensitively against the accepted spellings.

    Args:
        v: Value to interpret, e.g. ``'yes'``, ``'False'``, ``'1'`` or ``0``.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0.

    Raises:
        ValueError: If ``v`` is not one of the accepted spellings.
    """
    if isinstance(v, bool):
        return v

    text = str(v).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise ValueError(
        f'cannot interpret {v!r} as a boolean; expected one of '
        f'{sorted(_TRUE_STRINGS | _FALSE_STRINGS)}'
    )


class AverageMeter:
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self, name, fmt=':f'):
        """Create a meter.

        Args:
            name: Label printed by ``__str__``.
            fmt: Format spec, including the leading colon, applied to both
                ``val`` and ``avg`` in ``__str__``.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Set the latest value, average, sum and count back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n``.

        Args:
            val: New value.
            n: Weight of ``val``, e.g. the number of samples it summarises.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) on a fresh meter) would
        # otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` formatted with ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)