master_thesis_models / model /logger /visualization.py
Hannes Kuchelmeister
add model to repository
b72a776
raw
history blame
2.89 kB
import importlib
from datetime import datetime
class TensorboardWriter():
def __init__(self, log_dir, logger, enabled):
self.writer = None
self.selected_module = ""
if enabled:
log_dir = str(log_dir)
# Retrieve vizualization writer.
succeeded = False
for module in ["torch.utils.tensorboard", "tensorboardX"]:
try:
self.writer = importlib.import_module(module).SummaryWriter(log_dir)
succeeded = True
break
except ImportError:
succeeded = False
self.selected_module = module
if not succeeded:
message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
"this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
"version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
logger.warning(message)
self.step = 0
self.mode = ''
self.tb_writer_ftns = {
'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
}
self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
self.timer = datetime.now()
def set_step(self, step, mode='train'):
self.mode = mode
self.step = step
if step == 0:
self.timer = datetime.now()
else:
duration = datetime.now() - self.timer
self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
self.timer = datetime.now()
def __getattr__(self, name):
"""
If visualization is configured to use:
return add_data() methods of tensorboard with additional information (step, tag) added.
Otherwise:
return a blank function handle that does nothing
"""
if name in self.tb_writer_ftns:
add_data = getattr(self.writer, name, None)
def wrapper(tag, data, *args, **kwargs):
if add_data is not None:
# add mode(train/valid) tag
if name not in self.tag_mode_exceptions:
tag = '{}/{}'.format(tag, self.mode)
add_data(tag, data, self.step, *args, **kwargs)
return wrapper
else:
# default action for returning methods defined in this class, set_step() for instance.
try:
attr = object.__getattr__(name)
except AttributeError:
raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
return attr