|
import os |
|
from .base_logger import BaseLogger |
|
|
|
class WandbLogger(BaseLogger):
    """Logger that sends metrics and model checkpoints to Weights & Biases.

    A wandb run is started (or an already-active run is reused) as soon as
    the logger is constructed; see :attr:`run`.

    Args:
        project: Project name, forwarded to ``wandb.init``.
        name: Run name, forwarded to ``wandb.init``.
        id: Run id, forwarded to ``wandb.init`` together with
            ``resume="allow"``.
        entity: Entity, forwarded to ``wandb.init``.
        save_dir: Forwarded to ``wandb.init`` as ``dir``; also the directory
            from which :meth:`log_model` reads checkpoint files.
        config: Optional mapping merged into ``run.config`` once the run is
            available (skipped when empty or None).
        **kwargs: Extra keyword arguments for ``wandb.init``; they override
            the values built from the named parameters above.

    Raises:
        ModuleNotFoundError: If the ``wandb`` package cannot be imported.
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        try:
            import wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            ) from err
        self.wandb = wandb

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow",
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        self._wandb_init.update(kwargs)

        # Touch the property so the run is created (or attached) eagerly.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run, creating it on first access.

        If ``wandb.run`` is already set (a run is in progress in this
        process), that run is reused and an info message is logged;
        otherwise ``wandb.init`` is called with the stored init arguments.
        The result is cached on the instance.
        """
        if self._run is None:
            if self.wandb.run is not None:
                # `logger` was referenced here but never defined in this
                # module (NameError); use the stdlib logging module instead.
                import logging
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    " before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metrics to the run, namespaced by ``prefix``.

        With a prefix, each key is rewritten as ``"<prefix.lower()>/<key>"``.
        Without one (None or empty), keys are logged unchanged rather than
        gaining a leading ``"/"``.

        Args:
            metrics: Mapping of metric name to value.
            prefix: Optional namespace (e.g. ``"TRAIN"``); lowercased.
            step: Optional step, passed through to ``run.log``.
        """
        if prefix:
            namespace = prefix.lower()
            updated_metrics = {
                "{}/{}".format(namespace, key): value
                for key, value in metrics.items()
            }
        else:
            updated_metrics = dict(metrics)

        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a model artifact.

        The artifact is named ``model-<run id>`` and the file is stored in it
        as ``model_ckpt.pdparams``. It is tagged with ``prefix`` as an alias,
        plus ``"best"`` when ``is_best`` is true.

        Args:
            is_best: Whether to add the ``"best"`` alias.
            prefix: Checkpoint file stem; also used as an alias.
            metadata: Optional metadata attached to the artifact.

        Raises:
            ValueError: If the logger was created without ``save_dir``.
        """
        if self.save_dir is None:
            # Without this check os.path.join fails with an opaque TypeError.
            raise ValueError(
                "WandbLogger.log_model requires `save_dir` to be set")

        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")

        aliases = [prefix]
        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run held by this logger."""
        self.run.finish()