import contextlib
import dataclasses
import datetime
import logging
import time
from typing import Optional, Protocol

import torch

logger = logging.getLogger("utils")


@dataclasses.dataclass
class TrainState:
    """Tracks step count, timing, and token throughput over a training run."""

    max_steps: int
    step: int = 0
    elapsed_time: float = 0.0
    n_seen_tokens: int = 0
    this_step_time: float = 0.0
    begin_step_time: float = 0.0
    this_step_tokens: int = 0
    this_eval_perplexity: Optional[float] = None
    this_eval_loss: Optional[float] = None

    def start_step(self):
        self.step += 1
        self.begin_step_time = time.time()

    def end_step(self, n_batch_tokens: int):
        self.this_step_time = time.time() - self.begin_step_time
        self.this_step_tokens = n_batch_tokens

        self.elapsed_time += self.this_step_time
        self.n_seen_tokens += self.this_step_tokens

        self.begin_step_time = time.time()

    @property
    def wps(self):
        # Tokens per second for the most recent step.
        return self.this_step_tokens / self.this_step_time

    @property
    def avg_wps(self):
        # Average tokens per second over the run so far.
        return self.n_seen_tokens / self.elapsed_time

    @property
    def eta(self):
        # Estimated seconds remaining, extrapolated from the average step time.
        steps_left = self.max_steps - self.step
        avg_time_per_step = self.elapsed_time / self.step

        return steps_left * avg_time_per_step
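
# Illustrative usage sketch (not part of the original module): a training loop
# would typically drive TrainState like this. `data_loader`, `batch_tokens`,
# and the elided forward/backward step are hypothetical stand-ins for the
# caller's own objects.
#
#     state = TrainState(max_steps=10_000)
#     for batch_tokens in data_loader:
#         state.start_step()
#         ...  # forward / backward / optimizer.step()
#         state.end_step(n_batch_tokens=batch_tokens.numel())
#         logger.info(f"step={state.step} wps={state.wps:.0f} eta={state.eta:.0f}s")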


def set_random_seed(seed: int) -> None:
    """Set random seed for reproducibility."""
    # Seeds torch's CPU and CUDA generators; other libraries' RNGs
    # (e.g. Python's `random`) are not touched here.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


class Closable(Protocol):
    def close(self):
        pass


@contextlib.contextmanager
def logged_closing(thing: Closable, name: str):
    """
    Log the closing of `thing` so we can be sure nothing is left hanging at exit time.
    """
    try:
        setattr(thing, "wrapped_by_closing", True)
        yield
    finally:
        logger.info(f"Closing: {name}")
        try:
            thing.close()
        except Exception:
            logger.error(f"Error while closing {name}!")
            raise
        logger.info(f"Closed: {name}")


def now_as_str() -> str:
    """Current local time as a filesystem-friendly string (YYYY-MM-DD_HH-MM-SS)."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")