import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from simple_parsing.helpers import Serializable

from model.args import LoraArgs

from .data.args import DataArgs


@dataclass
class OptimArgs(Serializable):
    lr: float = 3e-4
    weight_decay: float = 0.1
    pct_start: float = 0.3  # fraction of steps spent increasing the LR (as in torch's OneCycleLR)


@dataclass
class WandbArgs(Serializable):
    project: Optional[str] = None  # set this to enable wandb logging
    offline: bool = False
    key: Optional[str] = None
    run_name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.project is not None:
            try:
                import wandb  # noqa: F401
            except ImportError:
                raise ImportError(
                    "`wandb` not installed. Either make sure `wandb` is installed "
                    "or set `wandb.project` to None."
                )

            if len(self.project) == 0:
                raise ValueError("`wandb.project` must not be an empty string.")


@dataclass
class MLFlowArgs(Serializable):
    tracking_uri: Optional[str] = None
    experiment_name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.tracking_uri is not None:
            try:
                import mlflow  # noqa: F401
            except ImportError:
                raise ImportError(
                    "`mlflow` not installed. Either make sure `mlflow` is installed "
                    "or set `mlflow.tracking_uri` to None."
                )

            if self.experiment_name is None:
                raise ValueError(
                    "If `mlflow.tracking_uri` is set, `mlflow.experiment_name` "
                    "must be set as well."
                )


@dataclass
class TrainArgs(Serializable):
    data: DataArgs

    # Path to the initial model directory, or a model id.
    model_id_or_path: str

    # Directory where checkpoints and logs for this run are written.
    run_dir: str

    optim: OptimArgs = field(default_factory=OptimArgs)
    seed: int = 0

    # Number of micro-batches to accumulate before each optimizer step.
    num_microbatches: int = 1

    seq_len: int = 2048
    batch_size: int = 1
    max_norm: float = 1.0  # gradient clipping norm
    max_steps: int = 100
    log_freq: int = 1

    # Checkpointing. `ckpt_only_lora` saves only the LoRA adapter weights;
    # if disabled, the adapter is merged into the base model at save time
    # (see the warning in __post_init__).
    ckpt_freq: int = 0
    ckpt_only_lora: bool = True
    no_ckpt: bool = False
    num_ckpt_keep: Optional[int] = 3
    eval_freq: int = 0
    no_eval: bool = True

    # Gradient (activation) checkpointing to save memory during training.
    checkpoint: bool = True

    # Filled in from the WORLD_SIZE environment variable in __post_init__.
    world_size: Optional[int] = field(init=False, default=None)

    # Logging backends.
    wandb: WandbArgs = field(default_factory=WandbArgs)
    mlflow: MLFlowArgs = field(default_factory=MLFlowArgs)

    lora: Optional[LoraArgs] = field(default_factory=LoraArgs)

    def __post_init__(self) -> None:
        assert getattr(self, "world_size", None) is None
        self.world_size = int(os.environ.get("WORLD_SIZE", -1))

        if self.wandb.offline:
            command = f"cd {self.run_dir}; wandb sync --sync-all"
            logging.info(f"to sync wandb offline, run: {command}")

        assert self.num_microbatches >= 1
        assert self.num_ckpt_keep is None or self.num_ckpt_keep >= 1

        if self.model_id_or_path is not None:
            # Fail fast if the configured model path does not exist.
            assert Path(self.model_id_or_path).exists(), (
                f"`model_id_or_path` does not exist: {self.model_id_or_path}"
            )

        if not self.ckpt_only_lora:
            logging.warning(
                "You have disabled `ckpt_only_lora` and are thus merging the "
                "trained LoRA checkpoint into the base model upon checkpointing. "
                "This might lead to OOM errors; make sure you have enough CPU "
                "and GPU memory."
            )
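

# Usage sketch (illustrative, not part of this module): `Serializable` from
# simple_parsing provides a `load` classmethod, so a training entrypoint can
# hydrate the full config from a YAML file. The file name "example.yaml" is
# hypothetical.
#
#     args = TrainArgs.load("example.yaml", drop_extra_fields=False)
#     print(args.optim.lr, args.lora)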