File size: 2,892 Bytes
b72a776 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import importlib
from datetime import datetime
class TensorboardWriter():
    """Proxy around a TensorBoard-style ``SummaryWriter``.

    The proxy remembers the current global step and mode (set through
    :meth:`set_step`) and injects them into every ``add_*`` call that is
    forwarded to the underlying writer.  When visualization is disabled, or
    no writer backend could be imported, every ``add_*`` call is a silent
    no-op so calling code does not need to check.
    """

    def __init__(self, log_dir, logger, enabled):
        """Create the proxy and, if requested, the underlying writer.

        Args:
            log_dir: Directory handed to ``SummaryWriter`` (converted with
                ``str()`` first).
            logger: Object with a ``warning(message)`` method; only used
                when ``enabled`` is true and no backend can be imported.
            enabled: When false no backend is imported and ``self.writer``
                stays ``None``.
        """
        self.writer = None
        # Name of the backend module that was actually imported; stays ""
        # when visualization is disabled or no backend is available.
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: take the first backend that
            # can be imported.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend only once it has really been selected.
                self.selected_module = module
                break
            else:
                # Loop finished without ``break``: no backend is installed.
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        # Writer methods that are proxied through __getattr__.
        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # Proxied methods whose tag is NOT suffixed with the current mode.
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = datetime.now()

    def set_step(self, step, mode='train'):
        """Set the step and mode used by subsequent ``add_*`` calls.

        For any ``step`` other than 0 a ``steps_per_sec`` scalar is also
        logged, computed from the time elapsed since the previous call (or
        since construction).  The timer is restarted on every call.

        Args:
            step: Global step value passed to the writer.
            mode: Suffix appended to tags as ``'<tag>/<mode>'``.
        """
        self.mode = mode
        self.step = step
        if step != 0:
            elapsed = (datetime.now() - self.timer).total_seconds()
            # Two calls can land on the same clock tick; skip the scalar
            # instead of dividing by zero.
            if elapsed > 0:
                self.add_scalar('steps_per_sec', 1 / elapsed)
        self.timer = datetime.now()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing

        Any name that is not one of the proxied writer methods raises
        ``AttributeError``.
        """
        # __getattr__ only runs after normal lookup has failed.  Read our own
        # state straight from the instance dict so that an instance whose
        # __init__ has not run (copy / pickle / __new__) gets a clean
        # AttributeError instead of recursing back into this method.
        state = self.__dict__
        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        raise AttributeError(
            "type object '{}' has no attribute '{}'".format(state.get('selected_module', ''), name)
        )
|