import json
import logging
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
from torch.utils.tensorboard import SummaryWriter
from finetune.args import MLFlowArgs, TrainArgs, WandbArgs
from finetune.utils import TrainState

logger = logging.getLogger("metrics_logger")

GB = 1024**3  # one gibibyte in bytes; memory metrics below are logged in GB


def get_train_logs(
state: TrainState,
loss: float,
lr: float,
peak_allocated_mem: float,
allocated_mem: float,
train_args: TrainArgs,
) -> Dict[str, Union[float, int]]:
metrics = {
"lr": lr,
"step": state.step,
"loss": loss,
"percent_done": 100 * state.step / train_args.max_steps,
"peak_allocated_mem": peak_allocated_mem / GB,
"allocated_mem": allocated_mem / GB,
"wps": state.wps,
"avg_wps": state.avg_wps,
"eta_in_seconds": state.eta,
}
return metrics
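
# For reference, the returned metrics dict is flat, e.g. (values illustrative):
#   {"lr": 3.0e-05, "step": 120, "loss": 1.832, "percent_done": 12.0,
#    "peak_allocated_mem": 41.2, "allocated_mem": 38.7, "wps": 5120.0,
#    "avg_wps": 5050.0, "eta_in_seconds": 5400}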


def get_eval_logs(
step: int,
train_loss: float,
perplexity: Optional[float],
eval_loss: Optional[float],
) -> Dict[str, Union[float, int]]:
eval_dict = {"step": step, "train_loss": train_loss}
if perplexity is not None:
eval_dict["perplexity"] = perplexity
if eval_loss is not None:
eval_dict["eval_loss"] = eval_loss
return eval_dict
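
# Note: perplexity and eval_loss are omitted when evaluation was skipped, so
# downstream formatting (eval_log_msg below) must treat them as optional.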


def train_log_msg(
state: TrainState, logs: Dict[str, Union[float, int]], loss: float
) -> str:
metrics: Dict[str, Union[float, int, datetime]] = dict(logs) # shallow copy
metrics.pop("eta_in_seconds")
metrics["eta"] = datetime.now() + timedelta(seconds=state.eta)
metrics["step"] = state.step
metrics["loss"] = loss
parts = []
for key, fmt, new_name in [
("step", "06", None),
("percent_done", "03.1f", "done (%)"),
("loss", ".3f", None),
("lr", ".1e", None),
("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"),
("allocated_mem", ".1f", "alloc_mem (GB)"),
("wps", ".1f", "words_per_second"),
("avg_wps", ".1f", "avg_words_per_second"),
("eta", "%Y-%m-%d %H:%M:%S", "ETA"),
]:
name = key if new_name is None else new_name
        try:
            value = metrics[key]
            # datetime delegates its format spec to strftime, so skip the ">"
            # alignment prefix when rendering the ETA timestamp
            fmt_spec = fmt if isinstance(value, datetime) else f">{fmt}"
            parts.append(f"{name}: {value:{fmt_spec}}")
        except KeyError:
            logger.error(f"{key} not found in {sorted(metrics.keys())}")
            raise
return " - ".join(parts)


def eval_log_msg(
parts = []
for key, fmt, new_name in [
("step", "06", None),
("perplexity", ".3f", "eval_perplexity"),
("eval_loss", ".3f", None),
("train_loss", ".3f", None),
]:
name = key if new_name is None else new_name
if key in logs:
parts.append(f"{name}: {logs[key]:>{fmt}}")
return " - ".join(parts)


class MetricsLogger:
def __init__(
self,
dst_dir: Path,
tag: str,
is_master: bool,
wandb_args: WandbArgs,
mlflow_args: MLFlowArgs,
config: Optional[Dict[str, Any]] = None,
):
self.dst_dir = dst_dir
self.tag = tag
self.is_master = is_master
self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl"
self.tb_dir = dst_dir / "tb"
self.summary_writer: Optional[SummaryWriter] = None
if not self.is_master:
return
filename_suffix = f".{tag}"
self.tb_dir.mkdir(exist_ok=True)
self.summary_writer = SummaryWriter(
log_dir=str(self.tb_dir),
max_queue=1000,
filename_suffix=filename_suffix,
)
self.is_wandb = wandb_args.project is not None
self.is_mlflow = mlflow_args.tracking_uri is not None
if self.is_wandb:
import wandb
if wandb_args.key is not None:
                wandb.login(key=wandb_args.key)
if wandb_args.offline:
os.environ["WANDB_MODE"] = "offline"
if wandb.run is None:
logger.info("initializing wandb")
wandb.init(
config=config,
dir=dst_dir,
project=wandb_args.project,
job_type="training",
name=wandb_args.run_name or dst_dir.name,
resume=False,
)
self.wandb_log = wandb.log
if self.is_mlflow:
import mlflow
mlflow.set_tracking_uri(mlflow_args.tracking_uri)
mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name)
if tag == "train":
mlflow.start_run()
self.mlflow_log = mlflow.log_metric

    def log(self, metrics: Dict[str, Union[float, int]], step: int):
if not self.is_master:
return
metrics_to_ignore = {"step"}
assert self.summary_writer is not None
for key, value in metrics.items():
if key in metrics_to_ignore:
continue
assert isinstance(value, (int, float)), (key, value)
self.summary_writer.add_scalar(
tag=f"{self.tag}.{key}", scalar_value=value, global_step=step
)
if self.is_mlflow:
self.mlflow_log(f"{self.tag}.{key}", value, step=step)
if self.is_wandb:
# grouping in wandb is done with /
self.wandb_log(
{
f"{self.tag}/{key}": value
for key, value in metrics.items()
if key not in metrics_to_ignore
},
step=step,
)
metrics_: Dict[str, Any] = dict(metrics) # shallow copy
if "step" in metrics_:
assert step == metrics_["step"]
else:
metrics_["step"] = step
metrics_["at"] = datetime.utcnow().isoformat()
with self.jsonl_path.open("a") as fp:
fp.write(f"{json.dumps(metrics_)}\n")

    def close(self):
if not self.is_master:
return
if self.summary_writer is not None:
self.summary_writer.close()
self.summary_writer = None
if self.is_wandb:
import wandb
# to be sure we are not hanging while finishing
wandb.finish()
if self.is_mlflow:
import mlflow
mlflow.end_run()

    def __del__(self):
if self.summary_writer is not None:
raise RuntimeError(
"MetricsLogger not closed properly! You should "
"make sure the close() method is called!"
)
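
# Usage sketch (illustrative only; the real wiring lives in the training
# entrypoint). Attribute names on `train_args` below are hypothetical:
#
#   metrics_logger = MetricsLogger(
#       dst_dir=run_dir,                # Path holding this run's outputs
#       tag="train",
#       is_master=True,
#       wandb_args=train_args.wandb,    # hypothetical attribute
#       mlflow_args=train_args.mlflow,  # hypothetical attribute
#       config=vars(train_args),        # assuming a plain dataclass-like config
#   )
#   logs = get_train_logs(state, loss, lr, peak_mem, alloc_mem, train_args)
#   logger.info(train_log_msg(state, logs, loss))
#   metrics_logger.log(logs, step=state.step)
#   ...
#   metrics_logger.close()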