|
import operator |
|
from functools import partial, reduce |
|
from typing import Iterable, List, Optional |
|
|
|
import torch |
|
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt |
|
import torch.nn as nn |
|
from xformers.ops.fmha import memory_efficient_attention |
|
from xformers.ops.fmha.attn_bias import AttentionBias, BlockDiagonalCausalMask |
|
|
|
from .args import ModelArgs |
|
from .lora import LoRALinear |
|
from .moe import MoeLayer |
|
from .rope import apply_rotary_emb, precompute_freqs_cis |
|
|
|
|
|
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    """Repeat each slice of ``keys`` and ``values`` ``repeats`` times along ``dim``.

    Both tensors go through ``torch.repeat_interleave``, so the copies of a
    slice are adjacent (``a, a, b, b``), not tiled (``a, b, a, b``).

    Args:
        keys: Key tensor.
        values: Value tensor.
        repeats: Number of consecutive copies to make of every slice.
        dim: Dimension along which the slices are repeated.

    Returns:
        A ``(keys, values)`` tuple holding the expanded tensors.
    """
    expanded_keys, expanded_values = (
        torch.repeat_interleave(tensor, repeats=repeats, dim=dim)
        for tensor in (keys, values)
    )
    return expanded_keys, expanded_values
|
|
|
|
|
def maybe_lora_layer(
    args: ModelArgs, rank: Optional[int] = None
) -> partial[LoRALinear] | type[nn.Linear]:
    """Pick the linear-layer factory dictated by the LoRA settings in ``args``.

    Args:
        args: Model configuration. ``args.lora`` supplies ``enable``, ``rank``,
            ``scaling`` and ``dropout``.
        rank: Optional rank override. Any falsy value (``None`` or ``0``)
            falls back to ``args.lora.rank``.

    Returns:
        ``nn.Linear`` itself when LoRA is disabled; otherwise a
        ``functools.partial`` of ``LoRALinear`` with ``rank``, ``scaling`` and
        ``dropout`` already bound, so it can be called with the same
        positional arguments as ``nn.Linear``.
    """
    lora_cfg = args.lora
    if lora_cfg.enable:
        return partial(
            LoRALinear,
            rank=rank if rank else lora_cfg.rank,
            scaling=lora_cfg.scaling,
            dropout=lora_cfg.dropout,
        )
    return nn.Linear
|
|
|
|
|
class Attention(nn.Module):
    """Self-attention over a packed token sequence with shared key/value heads.

    Queries are projected to ``n_heads`` heads and keys/values to
    ``n_kv_heads`` heads; each key/value head is then repeated
    ``n_heads // n_kv_heads`` times so the three tensors line up before the
    xformers ``memory_efficient_attention`` call. The four projections are
    built by ``maybe_lora_layer``, i.e. plain ``nn.Linear`` or LoRA layers.
    """

    def __init__(self, args: ModelArgs):
        """Create the q/k/v/o projections described by ``args``."""
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads
        self.head_dim: int = args.head_dim

        # Number of query heads that share a single key/value head.
        self.repeats = self.n_heads // self.n_kv_heads

        # Stored for reference; ``forward`` below does not pass it on.
        self.scale = self.args.head_dim**-0.5

        linear_factory = maybe_lora_layer(args)
        q_width = args.n_heads * args.head_dim
        kv_width = args.n_kv_heads * args.head_dim

        # Keep this construction order: it fixes parameter registration order.
        self.wq = linear_factory(args.dim, q_width, bias=False)
        self.wk = linear_factory(args.dim, kv_width, bias=False)
        self.wv = linear_factory(args.dim, kv_width, bias=False)

        self.wo = linear_factory(q_width, args.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: AttentionBias,
    ) -> torch.Tensor:
        """Apply attention to ``x``.

        Args:
            x: 2-D input whose first dimension is the total token count.
            freqs_cis: Rotary-embedding factors forwarded to
                ``apply_rotary_emb`` (presumably one row per token; confirm
                against ``rope.py``).
            mask: Attention bias handed unchanged to
                ``memory_efficient_attention``.

        Returns:
            The output projection of the attention result, with the same
            leading dimension as ``x``.
        """
        total_tokens, _ = x.shape
        per_head = self.args.head_dim

        # Project first (q, k, v order), then split the feature axis into heads.
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(total_tokens, self.n_heads, per_head)
        keys = keys.view(total_tokens, self.n_kv_heads, per_head)
        values = values.view(total_tokens, self.n_kv_heads, per_head)

        queries, keys = apply_rotary_emb(queries, keys, freqs_cis=freqs_cis)

        # Expand the head axis (dim 1) of keys/values to match the query heads.
        keys, values = repeat_kv(keys, values, self.repeats, dim=1)

        # The attention kernel receives a leading axis of size 1 on each input.
        attended = memory_efficient_attention(
            queries.unsqueeze(0), keys.unsqueeze(0), values.unsqueeze(0), mask
        )

        # Merge heads back into one feature axis before the output projection.
        return self.wo(attended.view(total_tokens, -1))
|
|
|
|
|
class FeedForward(nn.Module):
    """Gated feed-forward network computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` map ``args.dim -> args.hidden_dim`` and ``w2`` maps back.
    All three are bias-free layers produced by ``maybe_lora_layer``.
    """

    def __init__(self, args: ModelArgs):
        """Build the three projections from ``args.dim`` and ``args.hidden_dim``."""
        super().__init__()

        linear_factory = maybe_lora_layer(args)
        model_dim, inner_dim = args.dim, args.hidden_dim

        self.w1 = linear_factory(model_dim, inner_dim, bias=False)
        self.w2 = linear_factory(inner_dim, model_dim, bias=False)
        self.w3 = linear_factory(model_dim, inner_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        """Return the gated projection of ``x``."""
        gate = nn.functional.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
|
|
|
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed in float32,
    cast back to the input dtype, and multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the layer.

        Args:
            dim: Size of the normalised (last) dimension; ``weight`` has this
                many elements and starts at all ones.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise ``x`` in float32, cast back to its dtype, then apply the gain."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One transformer layer: normalised attention and feed-forward with residuals.

    The feed-forward part is a single ``FeedForward`` when ``args.moe`` is
    ``None``; otherwise it is a ``MoeLayer`` holding ``args.moe.num_experts``
    ``FeedForward`` experts and a bias-free linear gate.
    """

    def __init__(self, args: ModelArgs):
        """Build the attention, feed-forward and the two RMSNorm sub-modules."""
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)

        self.feed_forward: MoeLayer | FeedForward
        if args.moe is None:
            self.feed_forward = FeedForward(args=args)
        else:
            # Experts are created before the gate, as in the original call.
            expert_ffns = [FeedForward(args=args) for _ in range(args.moe.num_experts)]
            router = nn.Linear(args.dim, args.moe.num_experts, bias=False)
            self.feed_forward = MoeLayer(
                experts=expert_ffns,
                gate=router,
                moe_args=args.moe,
            )

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        att_mask: AttentionBias,
    ) -> torch.Tensor:
        """Run the layer on ``x``.

        Args:
            x: Input hidden states.
            freqs_cis: Rotary-embedding factors forwarded to the attention.
            att_mask: Attention bias forwarded to the attention.

        Returns:
            ``h + feed_forward(ffn_norm(h))`` where
            ``h = x + attention(attention_norm(x))``.
        """
        attn_out = self.attention(self.attention_norm(x), freqs_cis, att_mask)
        hidden = x + attn_out

        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out
|
|
|
|
|
class Transformer(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, final norm and head.

    ``forward`` works on a packed batch: several sequences concatenated along
    the first dimension, with their individual lengths given by ``seqlens``.
    """

    def __init__(self, args: ModelArgs, checkpoint: bool = False):
        """Build the model.

        Args:
            args: Model configuration (``vocab_size``, ``dim``, ``n_layers``,
                ``norm_eps``, ``head_dim``, ``rope_theta``, ...).
            checkpoint: When true, every block is wrapped with torch's
                non-reentrant ``checkpoint_wrapper``.
        """
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        self.layers = torch.nn.ModuleList()
        for _ in range(args.n_layers):
            layer: torch.nn.Module = TransformerBlock(args=args)
            if checkpoint:
                layer = torch_ckpt.checkpoint_wrapper(
                    layer,
                    checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
                )
            self.layers.append(layer)

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = torch.nn.Linear(args.dim, args.vocab_size, bias=False)

        # Filled on first access of the ``freqs_cis`` property.
        self._freqs_cis = None

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the token-embedding weight."""
        embedding_weight = self.tok_embeddings.weight
        return embedding_weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the token-embedding weight."""
        embedding_weight = self.tok_embeddings.weight
        return embedding_weight.device

    @property
    def freqs_cis(self):
        """Rotary-embedding table, computed once and then cached.

        It is created on the device of the module's first parameter. Later
        accesses return the cached tensor unchanged.
        """
        param_device = next(iter(self.parameters())).device
        if self._freqs_cis is not None:
            return self._freqs_cis

        # 128_000 is presumably the number of positions to precompute;
        # confirm against ``precompute_freqs_cis``.
        self._freqs_cis = precompute_freqs_cis(
            self.args.head_dim,
            128_000,
            theta=self.args.rope_theta,
            device=param_device,
        )
        return self._freqs_cis

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
    ) -> torch.Tensor:
        """Compute output-layer activations for a packed batch.

        Args:
            input_ids: Token ids of all sequences concatenated; the first
                dimension must equal ``sum(seqlens)``.
            seqlens: Length of each sequence in the batch, in order.

        Returns:
            The output projection of the normalised final hidden states,
            converted with ``.float()``.

        Raises:
            AssertionError: If ``sum(seqlens)`` differs from
                ``input_ids.shape[0]``.
        """
        total_tokens = sum(seqlens)
        assert total_tokens == input_ids.shape[0], (total_tokens, input_ids.shape[0])

        h = self.tok_embeddings(input_ids)

        rope_table = self.freqs_cis
        # Positions restart at 0 for every sequence in the packed batch.
        positions = positions_from_sizes(seqlens, rope_table.device)
        att_mask = BlockDiagonalCausalMask.from_seqlens(seqlens)

        # The table may sit on another device than the activations.
        freqs_cis = rope_table[positions].to(device=h.device)

        for layer in self.layers:
            h = layer(h, freqs_cis, att_mask)

        return self.output(self.norm(h)).float()
|
|
|
|
|
def positions_from_sizes(sizes: Iterable[int], device):
    """Build per-token position indices for a packed batch.

    For each length ``s`` in ``sizes`` the positions ``0 .. s - 1`` are
    emitted, and the runs are concatenated in order, for example
    ``[2, 3] -> [0, 1, 0, 1, 2]``.

    Args:
        sizes: Lengths of the individual sequences.
        device: Device on which the resulting tensor is created.

    Returns:
        A 1-D ``torch.long`` tensor with ``sum(sizes)`` elements.
    """
    flat_positions = [pos for size in sizes for pos in range(size)]
    return torch.tensor(flat_positions, dtype=torch.long, device=device)
|
|