File size: 789 Bytes
cb9e677 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from dataclasses import dataclass, field
from typing import Optional
from simple_parsing.helpers import Serializable
@dataclass
class LoraArgs(Serializable):
    """LoRA hyper-parameters.

    Validation in ``__post_init__`` runs only when ``enable`` is True.
    """

    # Toggle for LoRA; when False, ``rank`` and ``scaling`` are not checked.
    enable: bool = True
    # Adapter rank; must be strictly positive when enabled.
    rank: int = 16
    # Dropout value; not validated here.
    dropout: float = 0.0
    # Scaling factor; must be strictly positive when enabled.
    scaling: float = 2.0

    def __post_init__(self) -> None:
        """Validate the LoRA settings after dataclass initialisation.

        Raises:
            ValueError: if LoRA is enabled and ``rank`` or ``scaling`` is not
                strictly positive (NaN is rejected as well).
        """
        if not self.enable:
            return
        # Explicit raises instead of ``assert``: assertions are removed when
        # Python runs with ``-O``, which would silently skip this validation.
        # ``not (x > 0)`` (rather than ``x <= 0``) keeps rejecting NaN, as the
        # original asserts did.
        if not self.rank > 0:
            raise ValueError(f"LoRA rank must be > 0, got {self.rank!r}")
        if not self.scaling > 0.0:
            raise ValueError(f"LoRA scaling must be > 0, got {self.scaling!r}")
@dataclass
class MoeArgs(Serializable):
    """Mixture-of-experts settings (value holder only; no validation)."""

    # Total expert count.
    num_experts: int = 8
    # Expert count per token, as the field name indicates; how it is applied
    # is decided by the model code, not here.
    num_experts_per_tok: int = 2
@dataclass
class ModelArgs(Serializable):
    """Model architecture hyper-parameters.

    The first eight fields have no default and must be supplied (in this
    order when passed positionally). ``lora`` defaults to a fresh
    ``LoraArgs()`` per instance; ``moe`` defaults to ``None``.
    """

    # --- required fields (no defaults) ---
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int

    # --- optional fields ---
    rope_theta: float = 10_000.0
    # default_factory ensures each ModelArgs gets its own LoraArgs object
    # rather than sharing one mutable default.
    lora: LoraArgs = field(default_factory=LoraArgs)
    # NOTE(review): presumably None means "no mixture-of-experts" — confirm
    # against the model code that consumes this field.
    moe: Optional[MoeArgs] = None
|