import contextlib
import dataclasses
import datetime
import logging
import time
from typing import Iterator, Optional, Protocol

import torch

logger = logging.getLogger("utils")


@dataclasses.dataclass
class TrainState:
    """Mutable bookkeeping for the progress of a training run.

    Call ``start_step()`` at the beginning of every step and
    ``end_step(n_batch_tokens)`` once it is finished; the throughput and
    ETA properties are derived from the timings recorded by those calls.
    """

    max_steps: int
    step: int = 0
    elapsed_time: float = 0.0
    n_seen_tokens: int = 0
    this_step_time: float = 0.0
    begin_step_time: float = 0.0
    this_eval_perplexity: Optional[float] = None
    this_eval_loss: Optional[float] = None

    def __post_init__(self) -> None:
        # Deliberately a plain attribute rather than a dataclass field, so
        # the generated __init__/__repr__/__eq__/asdict() stay unchanged.
        # Initialising it here lets `wps` be read before the first
        # `end_step()` call without raising AttributeError.
        self.this_step_tokens: int = 0

    def start_step(self) -> None:
        """Advance the step counter and record when the step began."""
        self.step += 1
        self.begin_step_time = time.time()

    def end_step(self, n_batch_tokens: int) -> None:
        """Record the duration and token count of the step that just ended.

        Args:
            n_batch_tokens: Number of tokens processed during this step.
        """
        self.this_step_time = time.time() - self.begin_step_time
        self.this_step_tokens = n_batch_tokens

        self.elapsed_time += self.this_step_time
        self.n_seen_tokens += self.this_step_tokens

        # Restart the clock so that time spent between end_step() and the
        # next start_step() is not attributed to a stale timestamp.
        self.begin_step_time = time.time()

    @property
    def wps(self) -> float:
        """Tokens per second for the last completed step.

        Returns 0.0 when no step duration has been recorded yet (or the
        measured duration was exactly zero) instead of dividing by zero.
        """
        if self.this_step_time <= 0:
            return 0.0
        return self.this_step_tokens / self.this_step_time

    @property
    def avg_wps(self) -> float:
        """Tokens per second averaged over all completed steps.

        Returns 0.0 when no time has been accumulated yet.
        """
        if self.elapsed_time <= 0:
            return 0.0
        return self.n_seen_tokens / self.elapsed_time

    @property
    def eta(self) -> float:
        """Estimated remaining time in seconds.

        Computed as the average time per step multiplied by the number of
        steps left. Returns 0.0 when no step has been started, since there
        is no timing data to extrapolate from.
        """
        if self.step <= 0:
            return 0.0
        steps_left = self.max_steps - self.step
        avg_time_per_step = self.elapsed_time / self.step

        return steps_left * avg_time_per_step


def set_random_seed(seed: int) -> None:
    """Set random seed for reproducibility.

    Seeds the default torch generator and the CUDA generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


class Closable(Protocol):
    """Structural type for any object exposing a ``close()`` method."""

    def close(self) -> None:
        pass


@contextlib.contextmanager
def logged_closing(thing: Closable, name: str) -> Iterator[None]:
    """
    Logging the closing to be sure something is not hanging at exit time

    On entry the attribute ``wrapped_by_closing`` is set to ``True`` on
    ``thing``. On exit (normal or exceptional) ``thing.close()`` is called,
    with a log line before and after it.

    Args:
        thing: Object to close when the ``with`` block exits.
        name: Human-readable label used in the log messages.

    Raises:
        Exception: Whatever ``thing.close()`` raises is logged and re-raised.
    """
    try:
        setattr(thing, "wrapped_by_closing", True)
        yield
    finally:
        logger.info(f"Closing: {name}")
        try:
            thing.close()
        except Exception:
            # logger.exception logs at ERROR level like before, but also
            # attaches the traceback of the failure to the log record.
            logger.exception(f"Error while closing {name}!")
            raise
        logger.info(f"Closed: {name}")


def now_as_str() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")