# CHEMISTral7Bv0.3 / model /transformer.py
# Clemspace's picture
# Initial model upload
# cb9e677
import operator
from functools import partial, reduce
from typing import Iterable, List, Optional
import torch
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt
import torch.nn as nn
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import AttentionBias, BlockDiagonalCausalMask
from .args import ModelArgs
from .lora import LoRALinear
from .moe import MoeLayer
from .rope import apply_rotary_emb, precompute_freqs_cis
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    """Expand ``keys`` and ``values`` by repeating each slice along ``dim``.

    Both tensors go through ``torch.repeat_interleave``, so every slice along
    ``dim`` appears ``repeats`` times in a row (``a, a, b, b`` for
    ``repeats=2``) rather than the whole tensor being tiled.

    Args:
        keys: Key tensor to expand.
        values: Value tensor to expand.
        repeats: How many times each slice is repeated.
        dim: Dimension along which the repetition happens.

    Returns:
        A ``(keys, values)`` tuple holding the expanded tensors.
    """

    def _expand(tensor: torch.Tensor) -> torch.Tensor:
        return torch.repeat_interleave(tensor, repeats=repeats, dim=dim)

    expanded_keys, expanded_values = _expand(keys), _expand(values)
    return expanded_keys, expanded_values
def maybe_lora_layer(
    args: ModelArgs, rank: Optional[int] = None
) -> partial[LoRALinear] | type[nn.Linear]:
    """Select the linear-layer factory used for the model's projections.

    Args:
        args: Model configuration; only ``args.lora`` is read here
            (``enable``, ``rank``, ``scaling``, ``dropout``).
        rank: Optional override for the LoRA rank. A falsy value (``None`` or
            ``0``) falls back to ``args.lora.rank``.

    Returns:
        ``nn.Linear`` itself when LoRA is disabled; otherwise a
        ``functools.partial`` of ``LoRALinear`` with ``rank``, ``scaling`` and
        ``dropout`` already bound, so both can be called the same way.
    """
    lora_cfg = args.lora

    # Guard clause: without LoRA the plain linear class is the factory.
    if not lora_cfg.enable:
        return nn.Linear

    effective_rank = rank if rank else lora_cfg.rank
    return partial(
        LoRALinear,
        rank=effective_rank,
        scaling=lora_cfg.scaling,
        dropout=lora_cfg.dropout,
    )
class Attention(nn.Module):
    """Self-attention over a flat, packed sequence of tokens.

    Queries use ``args.n_heads`` heads and keys/values use
    ``args.n_kv_heads`` heads; key/value heads are repeated
    ``n_heads // n_kv_heads`` times before attention so the head counts match.
    The attention itself is computed by xformers'
    ``memory_efficient_attention``.
    """

    def __init__(self, args: ModelArgs):
        """Create the four projections (optionally LoRA-wrapped) from ``args``."""
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads
        self.head_dim: int = args.head_dim

        # Number of query heads served by each key/value head.
        self.repeats = self.n_heads // self.n_kv_heads

        # NOTE: stored on the module but not passed to the attention call in
        # ``forward``.
        self.scale = self.args.head_dim**-0.5

        make_linear = maybe_lora_layer(args)
        query_width = args.n_heads * args.head_dim
        kv_width = args.n_kv_heads * args.head_dim

        self.wq = make_linear(args.dim, query_width, bias=False)
        self.wk = make_linear(args.dim, kv_width, bias=False)
        self.wv = make_linear(args.dim, kv_width, bias=False)
        self.wo = make_linear(query_width, args.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: AttentionBias,
    ) -> torch.Tensor:
        """Attend over the packed tokens in ``x``.

        Args:
            x: 2-D tensor; the first dimension is the total token count of
                all packed sequences.
            freqs_cis: Rotary-embedding factors forwarded to
                ``apply_rotary_emb``.
            mask: xformers attention bias passed to
                ``memory_efficient_attention``.

        Returns:
            The output projection of the attention result, with one row per
            input token.
        """
        total_tokens, _ = x.shape
        per_head = self.args.head_dim

        # Project first (same order as the attributes), then split into heads.
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(total_tokens, self.n_heads, per_head)
        keys = keys.view(total_tokens, self.n_kv_heads, per_head)
        values = values.view(total_tokens, self.n_kv_heads, per_head)

        queries, keys = apply_rotary_emb(queries, keys, freqs_cis=freqs_cis)

        # Repeat key/value heads so their count matches the query heads.
        keys, values = repeat_kv(keys, values, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D): add a leading batch dimension.
        attended = memory_efficient_attention(
            queries[None, ...],
            keys[None, ...],
            values[None, ...],
            mask,
        )

        merged_heads = attended.view(total_tokens, -1)
        return self.wo(merged_heads)
class FeedForward(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` map ``args.dim`` to ``args.hidden_dim``; ``w2`` maps
    back to ``args.dim``. All three are bias-free and built through
    ``maybe_lora_layer``, so they may be LoRA-wrapped.
    """

    def __init__(self, args: ModelArgs):
        """Build the three projections from ``args.dim`` / ``args.hidden_dim``."""
        super().__init__()
        make_linear = maybe_lora_layer(args)
        model_dim, inner_dim = args.dim, args.hidden_dim

        self.w1 = make_linear(model_dim, inner_dim, bias=False)
        self.w2 = make_linear(inner_dim, model_dim, bias=False)
        self.w3 = make_linear(model_dim, inner_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        """Apply the SiLU-gated projection to ``x``."""
        gate = nn.functional.silu(self.w1(x))
        candidate = self.w3(x)
        return self.w2(gate * candidate)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The normalisation is computed in float32, cast back to the input dtype,
    and then multiplied by ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the norm.

        Args:
            dim: Size of the learned ``weight`` vector.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return x * inverse_rms

    def forward(self, x):
        """Normalise ``x`` and apply the learned per-feature gain."""
        # Normalise in float32, then return to the caller's dtype before the
        # gain is applied.
        normalised = self._norm(x.float())
        normalised = normalised.type_as(x)
        return normalised * self.weight
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward, each with a residual.

    The feed-forward part is a ``MoeLayer`` of ``FeedForward`` experts when
    ``args.moe`` is set, and a single ``FeedForward`` otherwise.
    """

    def __init__(self, args: ModelArgs):
        """Build the attention, feed-forward and the two RMS norms from ``args``."""
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.attention = Attention(args)

        self.feed_forward: MoeLayer | FeedForward
        if args.moe is None:
            self.feed_forward = FeedForward(args=args)
        else:
            # Experts are created before the gate, matching parameter
            # initialisation order.
            experts = [FeedForward(args=args) for _ in range(args.moe.num_experts)]
            gate = nn.Linear(args.dim, args.moe.num_experts, bias=False)
            self.feed_forward = MoeLayer(
                experts=experts,
                gate=gate,
                moe_args=args.moe,
            )

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        att_mask: AttentionBias,
    ) -> torch.Tensor:
        """Run the layer on ``x``.

        Args:
            x: Input hidden states.
            freqs_cis: Rotary-embedding factors handed to the attention module.
            att_mask: Attention bias handed to the attention module.

        Returns:
            The hidden states after both residual sub-layers.
        """
        # Attention sub-layer: normalise, attend, add back the input.
        attended = self.attention(self.attention_norm(x), freqs_cis, att_mask)
        hidden = x + attended

        # Feed-forward sub-layer: normalise, transform, add back.
        transformed = self.feed_forward(self.ffn_norm(hidden))
        return hidden + transformed
class Transformer(nn.Module):
    """Decoder stack: token embedding, ``n_layers`` blocks, final norm and output head.

    Inputs are packed: all sequences of a batch are concatenated into one flat
    token tensor and ``seqlens`` gives the length of each sequence.
    """

    def __init__(self, args: ModelArgs, checkpoint: bool = False):
        """Build the model.

        Args:
            args: Model configuration.
            checkpoint: When true, every block is wrapped with the
                non-reentrant activation-checkpoint wrapper.
        """
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        self.layers = torch.nn.ModuleList()
        for _ in range(args.n_layers):
            layer: torch.nn.Module = TransformerBlock(args=args)
            if checkpoint:
                # activate gradient checkpointing, see:
                # https://pytorch.org/docs/stable/checkpoint.html
                layer = torch_ckpt.checkpoint_wrapper(
                    layer,
                    checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
                )
            self.layers.append(layer)

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = torch.nn.Linear(args.dim, args.vocab_size, bias=False)

        # Filled in on first access of ``freqs_cis``.
        self._freqs_cis = None

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the token-embedding weight."""
        return self.tok_embeddings.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the token-embedding weight."""
        return self.tok_embeddings.weight.device

    @property
    def freqs_cis(self):
        """Rotary-embedding table, computed once on first access and then cached.

        The table is created on the device of the module's first parameter at
        the time of that first access. The literal ``128_000`` is passed as the
        second argument of ``precompute_freqs_cis`` (presumably the number of
        positions covered; confirm against that helper).
        """
        if self._freqs_cis is not None:
            return self._freqs_cis

        first_param = next(iter(self.parameters()))
        self._freqs_cis = precompute_freqs_cis(
            self.args.head_dim,
            128_000,
            theta=self.args.rope_theta,
            device=first_param.device,
        )
        return self._freqs_cis

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
    ) -> torch.Tensor:
        """Compute float32 logits for a packed batch.

        Args:
            input_ids: Flat tensor of token ids; its first dimension must equal
                ``sum(seqlens)``.
            seqlens: Length of each sequence packed into ``input_ids``.

        Returns:
            The output-head logits cast to float32, one row per input token.
        """
        total_tokens = input_ids.shape[0]
        assert sum(seqlens) == total_tokens, (sum(seqlens), total_tokens)

        hidden = self.tok_embeddings(input_ids)

        # Positions restart at 0 for every packed sequence; the mask is built
        # from the same lengths.
        positions = positions_from_sizes(seqlens, self.freqs_cis.device)
        att_mask = BlockDiagonalCausalMask.from_seqlens(seqlens)
        rotary = self.freqs_cis[positions].to(device=hidden.device)

        for block in self.layers:
            hidden = block(hidden, rotary, att_mask)

        logits = self.output(self.norm(hidden))
        return logits.float()
def positions_from_sizes(sizes: Iterable[int], device):
    """Build per-sequence position indices for a packed batch.

    For sizes ``[2, 3]`` the result is ``[0, 1, 0, 1, 2]``: each sequence
    contributes ``0 .. size - 1`` and the runs are concatenated in order.

    Args:
        sizes: Length of each sequence.
        device: Device on which the resulting tensor is created.

    Returns:
        A 1-D ``torch.long`` tensor with ``sum(sizes)`` entries.
    """
    flat_positions = []
    for size in sizes:
        flat_positions.extend(range(size))
    return torch.tensor(flat_positions, dtype=torch.long, device=device)