import json
import logging
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union

from torch.utils.tensorboard import SummaryWriter

from finetune.args import MLFlowArgs, TrainArgs, WandbArgs
from finetune.utils import TrainState

logger = logging.getLogger("metrics_logger")

GB = 1024**3


def get_train_logs(
    state: TrainState,
    loss: float,
    lr: float,
    peak_allocated_mem: float,
    allocated_mem: float,
    train_args: TrainArgs,
) -> Dict[str, Union[float, int]]:
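    """Collect per-step training metrics; memory figures are converted from bytes to GB
    and progress is reported as a percentage of max_steps."""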
    metrics = {
        "lr": lr,
        "step": state.step,
        "loss": loss,
        "percent_done": 100 * state.step / train_args.max_steps,
        "peak_allocated_mem": peak_allocated_mem / GB,
        "allocated_mem": allocated_mem / GB,
        "wps": state.wps,
        "avg_wps": state.avg_wps,
        "eta_in_seconds": state.eta,
    }

    return metrics


def get_eval_logs(
    step: int,
    train_loss: float,
    perplexity: Optional[float],
    eval_loss: Optional[float],
) -> Dict[str, Union[float, int]]:
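    """Build the evaluation log dict; perplexity and eval_loss are included only when not None."""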
    eval_dict = {"step": step, "train_loss": train_loss}

    if perplexity is not None:
        eval_dict["perplexity"] = perplexity

    if eval_loss is not None:
        eval_dict["eval_loss"] = eval_loss
    return eval_dict


def train_log_msg(
    state: TrainState, logs: Dict[str, Union[float, int]], loss: float
) -> str:
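    """Render training metrics as a single human-readable line, replacing eta_in_seconds
    with an absolute ETA timestamp."""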
    metrics: Dict[str, Union[float, int, datetime]] = dict(logs)  # shallow copy
    metrics.pop("eta_in_seconds")

    metrics["eta"] = datetime.now() + timedelta(seconds=state.eta)
    metrics["step"] = state.step
    metrics["loss"] = loss

    parts = []
    for key, fmt, new_name in [
        ("step", "06", None),
        ("percent_done", "03.1f", "done (%)"),
        ("loss", ".3f", None),
        ("lr", ".1e", None),
        ("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"),
        ("allocated_mem", ".1f", "alloc_mem (GB)"),
        ("wps", ".1f", "words_per_second"),
        ("avg_wps", ".1f", "avg_words_per_second"),
        ("eta", "%Y-%m-%d %H:%M:%S", "ETA"),
    ]:
        name = key if new_name is None else new_name
        try:
            parts.append(f"{name}: {metrics[key]:>{fmt}}")
        except KeyError:
            logger.error(f"{key} not found in {sorted(metrics.keys())}")
            raise

    return " - ".join(parts)


def eval_log_msg(logs: Dict[str, Union[float, int]]) -> str:
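    """Render evaluation metrics as a single line, skipping any keys absent from logs."""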
    parts = []
    for key, fmt, new_name in [
        ("step", "06", None),
        ("perplexity", ".3f", "eval_perplexity"),
        ("eval_loss", ".3f", None),
        ("train_loss", ".3f", None),
    ]:
        name = key if new_name is None else new_name
        if key in logs:
            parts.append(f"{name}: {logs[key]:>{fmt}}")

    return " - ".join(parts)


class MetricsLogger:
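    """Log metrics to TensorBoard, a JSONL file and, when configured, Weights & Biases
    and MLflow. log() and close() are no-ops on non-master ranks."""
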
    def __init__(
        self,
        dst_dir: Path,
        tag: str,
        is_master: bool,
        wandb_args: WandbArgs,
        mlflow_args: MLFlowArgs,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.dst_dir = dst_dir
        self.tag = tag
        self.is_master = is_master
        self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl"
        self.tb_dir = dst_dir / "tb"
        self.summary_writer: Optional[SummaryWriter] = None

        if not self.is_master:
            return

        filename_suffix = f".{tag}"
        self.tb_dir.mkdir(exist_ok=True)
        self.summary_writer = SummaryWriter(
            log_dir=str(self.tb_dir),
            max_queue=1000,
            filename_suffix=filename_suffix,
        )
        self.is_wandb = wandb_args.project is not None
        self.is_mlflow = mlflow_args.tracking_uri is not None

        if self.is_wandb:
            import wandb

            if wandb_args.key is not None:
                wandb.login(key=wandb_args.key)  # authenticate with the provided API key
            if wandb_args.offline:
                os.environ["WANDB_MODE"] = "offline"
            if wandb.run is None:
                logger.info("initializing wandb")
                wandb.init(
                    config=config,
                    dir=dst_dir,
                    project=wandb_args.project,
                    job_type="training",
                    name=wandb_args.run_name or dst_dir.name,
                    resume=False,
                )

            self.wandb_log = wandb.log

        if self.is_mlflow:
            import mlflow

            mlflow.set_tracking_uri(mlflow_args.tracking_uri)
            mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name)

            if tag == "train":
                mlflow.start_run()

            self.mlflow_log = mlflow.log_metric

    def log(self, metrics: Dict[str, Union[float, int]], step: int):
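        """Forward the metrics to every configured backend and append them, with a UTC
        timestamp, to the JSONL file."""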
        if not self.is_master:
            return

        metrics_to_ignore = {"step"}
        assert self.summary_writer is not None
        for key, value in metrics.items():
            if key in metrics_to_ignore:
                continue
            assert isinstance(value, (int, float)), (key, value)
            self.summary_writer.add_scalar(
                tag=f"{self.tag}.{key}", scalar_value=value, global_step=step
            )

            if self.is_mlflow:
                self.mlflow_log(f"{self.tag}.{key}", value, step=step)

        if self.is_wandb:
            # grouping in wandb is done with /
            self.wandb_log(
                {
                    f"{self.tag}/{key}": value
                    for key, value in metrics.items()
                    if key not in metrics_to_ignore
                },
                step=step,
            )

        metrics_: Dict[str, Any] = dict(metrics)  # shallow copy
        if "step" in metrics_:
            assert step == metrics_["step"]
        else:
            metrics_["step"] = step
        metrics_["at"] = datetime.utcnow().isoformat()
        with self.jsonl_path.open("a") as fp:
            fp.write(f"{json.dumps(metrics_)}\n")

    def close(self):
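        """Flush and close every backend; call this explicitly, since __del__ raises if
        the TensorBoard writer is still open."""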
        if not self.is_master:
            return

        if self.summary_writer is not None:
            self.summary_writer.close()
            self.summary_writer = None

        if self.is_wandb:
            import wandb

            # finish the wandb run explicitly so we do not hang during shutdown
            wandb.finish()

        if self.is_mlflow:
            import mlflow

            mlflow.end_run()

    def __del__(self):
        if self.summary_writer is not None:
            raise RuntimeError(
                "MetricsLogger not closed properly! You should "
                "make sure the close() method is called!"
            )