File size: 4,161 Bytes
cb9e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from simple_parsing.helpers import Serializable

from model.args import LoraArgs

from .data.args import DataArgs


@dataclass
class OptimArgs(Serializable):
    """Optimizer hyper-parameters.

    NOTE(review): the per-field meanings below are inferred from the field
    names only; the code consuming these values is not part of this file.
    """

    lr: float = 3e-4  # presumably the (peak) learning rate -- confirm at use site
    weight_decay: float = 0.1  # presumably the optimizer weight-decay coefficient -- confirm
    pct_start: float = 0.3  # presumably the warm-up fraction of the LR schedule -- TODO confirm


@dataclass
class WandbArgs(Serializable):
    """Weights & Biases logging options.

    wandb is only requested when `project` is filled. In that case
    `__post_init__` verifies that the `wandb` package can be imported and that
    the project name is not empty.
    """

    project: Optional[str] = None  # Fill this argument to use wandb.
    offline: bool = False  # When True, `TrainArgs` logs the command to sync the offline run.
    key: Optional[str] = None  # presumably the wandb API key -- confirm at use site
    run_name: Optional[str] = None

    def __post_init__(self) -> None:
        """Validate the configuration when a wandb project is requested.

        Raises:
            ImportError: if `project` is set but `wandb` cannot be imported.
            ValueError: if `project` is an empty string.
        """
        # Nothing to validate when wandb is not requested.
        if self.project is None:
            return

        try:
            import wandb  # noqa: F401
        except ImportError as err:
            # Chain explicitly so the traceback shows the original import
            # failure as the cause rather than as an incidental context.
            raise ImportError(
                "`wandb` not installed. Either make sure `wandb` is installed or set `wandb:project` to None."
            ) from err

        if not self.project:
            raise ValueError("`wandb.project` must not be an empty string.")

@dataclass
class MLFlowArgs(Serializable):
    """MLflow logging options.

    MLflow is only requested when `tracking_uri` is set. In that case
    `__post_init__` verifies that the `mlflow` package can be imported and that
    an experiment name was provided.
    """

    tracking_uri: Optional[str] = None
    experiment_name: Optional[str] = None  # Required whenever `tracking_uri` is set.

    def __post_init__(self) -> None:
        """Validate the configuration when an MLflow tracking URI is given.

        Raises:
            ImportError: if `tracking_uri` is set but `mlflow` cannot be imported.
            ValueError: if `tracking_uri` is set but `experiment_name` is None.
        """
        # Nothing to validate when MLflow is not requested.
        if self.tracking_uri is None:
            return

        try:
            import mlflow  # noqa: F401
        except ImportError as err:
            # Chain explicitly so the traceback shows the original import
            # failure as the cause rather than as an incidental context.
            raise ImportError(
                "`mlflow` not installed. Either make sure `mlflow` is installed or set `mlflow.tracking_uri` to None."
            ) from err

        if self.experiment_name is None:
            raise ValueError("If `mlflow.tracking_uri` is set, `mlflow.experiment_name` must be set as well.")



@dataclass
class TrainArgs(Serializable):
    """Top-level training configuration.

    Groups the data, optimizer, checkpointing, evaluation, logging and LoRA
    settings. `world_size` is not an init argument: it is filled in
    `__post_init__` from the `WORLD_SIZE` environment variable.
    """

    data: DataArgs

    # if specified, instruct_tokenizer and model will be loaded
    model_id_or_path: str  # Path to the directory containing the initial model or model id: "mistral-small"

    run_dir: str  # Path to the directory where everything will be saved. It needs to be empty.
    # NOTE(review): the wandb run name is configured through `wandb.run_name`.
    # Whether it falls back to the name of `run_dir` when None is decided
    # outside this file -- confirm where it is consumed.

    optim: OptimArgs = field(default_factory=OptimArgs)
    seed: int = 0
    # Number of steps to accumulate gradients before doing an optimizer step.
    num_microbatches: int = 1

    seq_len: int = 2048  # Number of tokens per batch per device.
    batch_size: int = 1
    max_norm: float = 1.0  # Gradient clipping.
    max_steps: int = 100  # Number of training steps.
    log_freq: int = 1  # Number of steps between each logging.

    # Number of steps between each checkpoint saving. If less than 1, only the last checkpoint will be saved.
    ckpt_freq: int = 0
    ckpt_only_lora: bool = True
    # If True, no checkpoint will be saved. This is useful for development.
    no_ckpt: bool = False
    num_ckpt_keep: Optional[int] = 3  # Must be None or >= 1 (checked in `__post_init__`).
    eval_freq: int = 0
    no_eval: bool = True

    # Efficiency
    # Determines whether gradient checkpointing should be utilized or not during the training process. Gradient checkpointing can be beneficial in reducing memory usage at the cost of slightly longer training times.
    checkpoint: bool = True

    # Not settable through the constructor; populated from the environment in `__post_init__`.
    world_size: Optional[int] = field(init=False, default=None)

    # logging
    wandb: WandbArgs = field(default_factory=WandbArgs)
    mlflow: MLFlowArgs = field(default_factory=MLFlowArgs)

    # LoRA
    lora: Optional[LoraArgs] = field(default_factory=LoraArgs)

    def __post_init__(self) -> None:
        """Fill `world_size` from the environment and sanity-check the config.

        Raises:
            AssertionError: if `world_size` is already set, if
                `num_microbatches` is below 1, or if `num_ckpt_keep` is set
                but below 1.
            ValueError: if the `WORLD_SIZE` environment variable is not an
                integer.
        """
        assert getattr(self, "world_size", None) is None, "`world_size` must not be set before `__post_init__`."
        # -1 is the fallback when WORLD_SIZE is absent from the environment.
        self.world_size = int(os.environ.get("WORLD_SIZE", -1))

        if self.wandb.offline:
            # Only logged as a hint for the user; the command is never executed here.
            command = f"cd {self.run_dir}; wandb sync --sync-all"
            logging.info(f"to sync wandb offline, run: {command}")

        assert self.num_microbatches >= 1, "`num_microbatches` must be >= 1."

        assert self.num_ckpt_keep is None or self.num_ckpt_keep >= 1, "`num_ckpt_keep` must be None or >= 1."

        # `model_id_or_path` may be a model id rather than a directory, so a
        # missing path is reported but is not an error. (The result of
        # `exists()` used to be computed and silently discarded.)
        if self.model_id_or_path is not None and not Path(self.model_id_or_path).exists():
            logging.info(
                f"`model_id_or_path` ({self.model_id_or_path}) is not an existing local path; assuming it is a model id."
            )

        if not self.ckpt_only_lora:
            logging.warning(
                "You are have disabled `ckpt_only_lora` and are thus merging the trained LoRA checkpoint into the base model upon checkpointing. This might lead to OOM erros - make sure you have enough CPU and GPU memory."
            )