|
import os |
|
import wandb |
|
|
|
class TrainPlatform:
    """Base interface for training-metric reporting back-ends.

    Every hook is a no-op here, so an instance of the base class silently
    discards everything it is given. Subclasses override the hooks their
    back-end supports.
    """

    def __init__(self, save_dir):
        """Create the platform; ``save_dir`` is ignored by the base class."""
        pass

    def report_scalar(self, name, value, iteration, group_name=None):
        """Report one scalar ``value`` called ``name`` at step ``iteration``.

        ``group_name`` optionally groups related scalars. The base
        implementation does nothing.
        """
        pass

    def report_data(self, data, iteration, group_name=None):
        """Report every ``name -> value`` pair of the mapping ``data`` at ``iteration``.

        The default forwards each item to :meth:`report_scalar`, so a
        subclass that only implements ``report_scalar`` still supports
        batched reporting (and on the base class this is a no-op).
        """
        for name, value in data.items():
            self.report_scalar(name, value, iteration, group_name)

    def report_args(self, args, name):
        """Record the run configuration ``args`` under ``name``; no-op here."""
        pass

    def close(self):
        """Release back-end resources; no-op here."""
        pass
|
|
|
|
|
class WandbPlatform(TrainPlatform):
    """Report training scalars and configuration to Weights & Biases."""

    def __init__(self, save_dir, *, project='flowmdm', entity="TO_BE_FILLED"):
        """Start a W&B run named after the last path component of ``save_dir``.

        Args:
            save_dir: run output directory; its final component is the run name.
            project: W&B project passed to ``wandb.init``.
            entity: W&B entity passed to ``wandb.init``. The default is the
                placeholder "TO_BE_FILLED" — NOTE(review): presumably it has
                to be set to a real user/team for logging to work.
        """
        # normpath strips a trailing separator, so 'runs/exp1/' still yields
        # the run name 'exp1' instead of an empty string.
        name = os.path.basename(os.path.normpath(save_dir))
        wandb.init(project=project,
                   name=name,
                   entity=entity,
                   )
        # Step of the most recent wandb.log call; -1 means nothing logged yet.
        self.last_committed_iter = -1

    def report_scalar(self, name, value, iteration, group_name=None):
        """Log a single scalar at step ``iteration``.

        ``group_name`` is accepted for interface compatibility but is not
        used in the logged key.
        """
        wandb.log(data={name: value}, step=iteration, commit=True)
        self.last_committed_iter = iteration

    def report_data(self, data, iteration, group_name=None):
        """Log all entries of the mapping ``data`` in one call at step ``iteration``.

        ``group_name`` is accepted for interface compatibility but unused.
        """
        wandb.log(data=data, step=iteration, commit=True)
        self.last_committed_iter = iteration

    def report_args(self, args, name):
        """Merge ``args`` into the W&B run config; ``name`` is unused."""
        wandb.config.update(args)

    def close(self):
        """Finish the W&B run."""
        wandb.finish()
|
|
|
|
|
class ClearmlPlatform(TrainPlatform):
    """Report training scalars and configuration to a ClearML task."""

    def __init__(self, save_dir):
        """Create a ClearML task for ``save_dir``.

        The last path component of ``save_dir`` becomes the task name and the
        parent directory is passed as ``output_uri``.
        """
        # Imported here so clearml is only required when this platform is
        # actually instantiated.
        from clearml import Task

        # normpath strips a trailing separator, so 'runs/exp1/' splits into
        # ('runs', 'exp1') rather than ('runs/exp1', '').
        path, name = os.path.split(os.path.normpath(save_dir))
        self.task = Task.init(project_name='motion_diffusion',
                              task_name=name,
                              output_uri=path)
        self.logger = self.task.get_logger()

    @staticmethod
    def _title(name, group_name):
        """Plot title for a series: the group if given, else the series name."""
        return name if group_name is None else group_name

    def report_scalar(self, name, value, iteration, group_name=None):
        """Report one scalar as series ``name`` in the plot for ``group_name``."""
        self.logger.report_scalar(title=self._title(name, group_name), series=name,
                                  iteration=iteration, value=value)

    def report_data(self, data, iteration, group_name=None):
        """Report every ``name -> value`` pair of ``data`` at ``iteration``."""
        for name, value in data.items():
            self.report_scalar(name, value, iteration, group_name)

    def report_args(self, args, name):
        """Connect ``args`` to the task under the section ``name``."""
        self.task.connect(args, name=name)

    def close(self):
        """Close the ClearML task."""
        self.task.close()
|
|
|
|
|
class TensorboardPlatform(TrainPlatform):
    """Write training scalars to TensorBoard via ``SummaryWriter``.

    ``report_args`` is inherited from the base class and does nothing.
    """

    def __init__(self, save_dir):
        """Open a ``SummaryWriter`` logging into ``save_dir``."""
        # Imported here so torch/tensorboard are only required when this
        # platform is actually instantiated.
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=save_dir)

    @staticmethod
    def _tag(name, group_name):
        """Scalar tag: '<group>/<name>', or just '<name>' when there is no group.

        Without this, a missing group produced the literal tag 'None/<name>'.
        """
        return name if group_name is None else f'{group_name}/{name}'

    def report_scalar(self, name, value, iteration, group_name=None):
        """Add one scalar at global step ``iteration``."""
        self.writer.add_scalar(self._tag(name, group_name), value, iteration)

    def report_data(self, data, iteration, group_name=None):
        """Add every ``name -> value`` pair of ``data`` at step ``iteration``."""
        for name, value in data.items():
            self.report_scalar(name, value, iteration, group_name)

    def close(self):
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()
|
|
|
|
|
class NoPlatform(TrainPlatform):
    """Platform that discards everything it is asked to report."""

    def __init__(self, save_dir):
        """Accept and ignore ``save_dir``."""
        pass

    def report_scalar(self, name, value, iteration, group_name=None):
        """Ignore the scalar; ``group_name`` is optional as in the base class."""
        pass

    def report_data(self, data, iteration, group_name=None):
        """Ignore the batch of scalars."""
        pass

    def report_args(self, args, name):
        """Ignore the run configuration."""
        pass

    def close(self):
        """Nothing to release."""
        pass
|
|
|
|