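"""Metrics logging utilities for fine-tuning: helpers that assemble train/eval
metric dicts and format them for console output, plus a `MetricsLogger` that
fans metrics out to TensorBoard, and optionally Weights & Biases and MLflow,
while always appending them to a local JSONL file."""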
import json
import logging
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union

from torch.utils.tensorboard import SummaryWriter

from finetune.args import MLFlowArgs, TrainArgs, WandbArgs
from finetune.utils import TrainState

logger = logging.getLogger("metrics_logger")

GB = 1024**3


def get_train_logs(
    state: TrainState,
    loss: float,
    lr: float,
    peak_allocated_mem: float,
    allocated_mem: float,
    train_args: TrainArgs,
) -> Dict[str, Union[float, int]]:
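    """Assemble the per-step training metrics: loss, learning rate, progress,
    memory usage (converted from bytes to GB), throughput and ETA."""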
    metrics = {
        "lr": lr,
        "step": state.step,
        "loss": loss,
        "percent_done": 100 * state.step / train_args.max_steps,
        "peak_allocated_mem": peak_allocated_mem / GB,
        "allocated_mem": allocated_mem / GB,
        "wps": state.wps,
        "avg_wps": state.avg_wps,
        "eta_in_seconds": state.eta,
    }

    return metrics


def get_eval_logs(
    step: int,
    train_loss: float,
    perplexity: Optional[float],
    eval_loss: Optional[float],
) -> Dict[str, Union[float, int]]:
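    """Assemble the evaluation metrics dict; `perplexity` and `eval_loss` are
    only included when they were actually computed."""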
    eval_dict = {"step": step, "train_loss": train_loss}

    if perplexity is not None:
        eval_dict["perplexity"] = perplexity
    if eval_loss is not None:
        eval_dict["eval_loss"] = eval_loss

    return eval_dict


def train_log_msg(
    state: TrainState, logs: Dict[str, Union[float, int]], loss: float
) -> str:
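    """Format the training metrics as a single human-readable log line,
    replacing `eta_in_seconds` with an absolute ETA timestamp."""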
    metrics: Dict[str, Union[float, int, datetime]] = dict(logs)  # shallow copy
    metrics.pop("eta_in_seconds")
    metrics["eta"] = datetime.now() + timedelta(seconds=state.eta)
    metrics["step"] = state.step
    metrics["loss"] = loss

    parts = []
    for key, fmt, new_name in [
        ("step", "06", None),
        ("percent_done", "03.1f", "done (%)"),
        ("loss", ".3f", None),
        ("lr", ".1e", None),
        ("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"),
        ("allocated_mem", ".1f", "alloc_mem (GB)"),
        ("wps", ".1f", "words_per_second"),
        ("avg_wps", ".1f", "avg_words_per_second"),
        ("eta", "%Y-%m-%d %H:%M:%S", "ETA"),
    ]:
        name = key if new_name is None else new_name
        try:
            parts.append(f"{name}: {metrics[key]:>{fmt}}")
        except KeyError:
            logger.error(f"{key} not found in {sorted(metrics.keys())}")
            raise

    return " - ".join(parts)


def eval_log_msg(logs: Dict[str, Union[float, int]]) -> str:
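    """Format the evaluation metrics as a single human-readable log line,
    skipping any key that is absent from `logs`."""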
    parts = []
    for key, fmt, new_name in [
        ("step", "06", None),
        ("perplexity", ".3f", "eval_perplexity"),
        ("eval_loss", ".3f", None),
        ("train_loss", ".3f", None),
    ]:
        name = key if new_name is None else new_name
        if key in logs:
            parts.append(f"{name}: {logs[key]:>{fmt}}")

    return " - ".join(parts)


class MetricsLogger:
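    """Write metrics to TensorBoard, optionally to Weights & Biases and MLflow,
    and always append them to a local `metrics.<tag>.jsonl` file.

    Only the master rank logs; on non-master ranks every method is a no-op."""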

    def __init__(
        self,
        dst_dir: Path,
        tag: str,
        is_master: bool,
        wandb_args: WandbArgs,
        mlflow_args: MLFlowArgs,
        config: Optional[Dict[str, Any]] = None,
    ):
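        """Set up the JSONL path and, on the master rank only, the TensorBoard
        writer plus the optional wandb and MLflow runs."""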
        self.dst_dir = dst_dir
        self.tag = tag
        self.is_master = is_master
        self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl"
        self.tb_dir = dst_dir / "tb"
        self.summary_writer: Optional[SummaryWriter] = None

        if not self.is_master:
            return

        filename_suffix = f".{tag}"
        self.tb_dir.mkdir(exist_ok=True)
        self.summary_writer = SummaryWriter(
            log_dir=str(self.tb_dir),
            max_queue=1000,
            filename_suffix=filename_suffix,
        )

        self.is_wandb = wandb_args.project is not None
        self.is_mlflow = mlflow_args.tracking_uri is not None

        if self.is_wandb:
            import wandb

            if wandb_args.key is not None:
                wandb.login(key=wandb_args.key)

            if wandb_args.offline:
                os.environ["WANDB_MODE"] = "offline"

            if wandb.run is None:
                logger.info("initializing wandb")
                wandb.init(
                    config=config,
                    dir=dst_dir,
                    project=wandb_args.project,
                    job_type="training",
                    name=wandb_args.run_name or dst_dir.name,
                    resume=False,
                )

            self.wandb_log = wandb.log

        if self.is_mlflow:
            import mlflow

            mlflow.set_tracking_uri(mlflow_args.tracking_uri)
            mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name)

            if tag == "train":
                mlflow.start_run()

            self.mlflow_log = mlflow.log_metric

    def log(self, metrics: Dict[str, Union[float, int]], step: int):
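        """Record one metrics dict at `step` in every configured backend and
        append it, with a UTC timestamp, to the JSONL file."""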
        if not self.is_master:
            return

        metrics_to_ignore = {"step"}
        assert self.summary_writer is not None
        for key, value in metrics.items():
            if key in metrics_to_ignore:
                continue
            assert isinstance(value, (int, float)), (key, value)
            self.summary_writer.add_scalar(
                tag=f"{self.tag}.{key}", scalar_value=value, global_step=step
            )
            if self.is_mlflow:
                self.mlflow_log(f"{self.tag}.{key}", value, step=step)

        if self.is_wandb:
            # grouping in wandb is done with /
            self.wandb_log(
                {
                    f"{self.tag}/{key}": value
                    for key, value in metrics.items()
                    if key not in metrics_to_ignore
                },
                step=step,
            )

        metrics_: Dict[str, Any] = dict(metrics)  # shallow copy
        if "step" in metrics_:
            assert step == metrics_["step"]
        else:
            metrics_["step"] = step
        metrics_["at"] = datetime.utcnow().isoformat()
        with self.jsonl_path.open("a") as fp:
            fp.write(f"{json.dumps(metrics_)}\n")

    def close(self):
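        """Flush and close every backend; must be called before the object is
        garbage-collected (see `__del__`)."""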
        if not self.is_master:
            return

        if self.summary_writer is not None:
            self.summary_writer.close()
            self.summary_writer = None

        if self.is_wandb:
            import wandb

            # to be sure we are not hanging while finishing
            wandb.finish()

        if self.is_mlflow:
            import mlflow

            mlflow.end_run()

    def __del__(self):
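        # Guard against silently losing buffered TensorBoard events: raise a
        # loud error if the logger is garbage-collected without close().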
        if self.summary_writer is not None:
            raise RuntimeError(
                "MetricsLogger not closed properly! You should "
                "make sure the close() method is called!"
            )
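

if __name__ == "__main__":
    # Minimal local smoke test, not part of the library API: a sketch that
    # assumes WandbArgs() and MLFlowArgs() are default-constructible with both
    # backends disabled (project / tracking_uri left as None), so only the
    # TensorBoard writer and the JSONL file are exercised.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        metrics_logger = MetricsLogger(
            dst_dir=Path(tmp),
            tag="eval",
            is_master=True,
            wandb_args=WandbArgs(),
            mlflow_args=MLFlowArgs(),
            config=None,
        )
        logs = get_eval_logs(step=10, train_loss=1.25, perplexity=3.49, eval_loss=1.25)
        metrics_logger.log(logs, step=10)
        print(eval_log_msg(logs))
        metrics_logger.close()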