|
""" Summary utilities |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import csv |
|
import os |
|
from collections import OrderedDict |
|
try: |
|
import wandb |
|
except ImportError: |
|
pass |
|
|
|
|
|
def get_outdir(path, *paths, inc=False):
    """Return an output directory path, creating the directory when needed.

    The directory path is ``os.path.join(path, *paths)``.

    * If it does not exist it is created (including parents) and returned.
    * If it exists and ``inc`` is false it is returned unchanged.
    * If it exists and ``inc`` is true, a sibling named ``<outdir>-<n>`` is
      created and returned, where ``n`` is the smallest integer in 1..99 for
      which no such path exists yet.

    Args:
        path: Base path component.
        *paths: Further path components joined onto ``path``.
        inc: When true, never reuse an existing directory; create a
            numbered sibling instead.

    Returns:
        The path of the directory to use.

    Raises:
        AssertionError: If ``inc`` is true and every suffix ``-1`` .. ``-99``
            is already taken.
    """
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
        return outdir
    if not inc:
        return outdir

    # Probe '-1' .. '-99' for the first unused name. The limit is enforced
    # with an explicit raise rather than an `assert` statement so that it
    # still applies when Python runs with -O (which strips asserts). The
    # exception type is kept as AssertionError for backward compatibility.
    for count in range(1, 100):
        outdir_inc = outdir + '-' + str(count)
        if not os.path.exists(outdir_inc):
            break
    else:
        raise AssertionError(
            'no free output directory suffix below 100 for ' + outdir)

    os.makedirs(outdir_inc)
    return outdir_inc
|
|
|
|
|
def update_summary(
        epoch,
        train_metrics,
        eval_metrics,
        filename,
        lr=None,
        write_header=False,
        log_wandb=False,
):
    """Append one epoch's metrics as a row to a CSV summary file.

    The row starts with ``epoch``, followed by every entry of
    ``train_metrics`` with its key prefixed by ``'train_'``, then every entry
    of ``eval_metrics`` prefixed by ``'eval_'`` (when given), and finally
    ``lr`` (when given).

    Args:
        epoch: Value written to the ``epoch`` column.
        train_metrics: Mapping of metric name (str) to value.
        eval_metrics: Mapping of metric name (str) to value; skipped when
            falsy (``None`` or empty).
        filename: Path of the CSV file, opened in append mode.
        lr: Optional value written to the ``lr`` column; the column is
            omitted when ``None``.
        write_header: When true, write a header row with the column names
            before the data row.
        log_wandb: When true, also pass the row to ``wandb.log``.
    """
    rowd = OrderedDict(epoch=epoch)
    rowd.update(('train_' + k, v) for k, v in train_metrics.items())
    if eval_metrics:
        rowd.update(('eval_' + k, v) for k, v in eval_metrics.items())
    if lr is not None:
        rowd['lr'] = lr
    if log_wandb:
        # `wandb` comes from the guarded module-level import; if that import
        # failed, this line raises NameError.
        wandb.log(rowd)
    # newline='' is required by the csv module: the writer emits its own
    # '\r\n' terminators, and letting the text layer translate newlines as
    # well yields '\r\r\n' (a blank line after every row) on Windows.
    with open(filename, mode='a', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=list(rowd))
        if write_header:
            dw.writeheader()
        dw.writerow(rowd)
|
|