|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from typing import Optional, Tuple |
|
from dataclasses import dataclass |
|
from einops import rearrange, repeat |
|
|
|
from flash_attn import flash_attn_func |
|
from .liger_rope import LigerRopeFunction |
|
from .rms_norm import LlamaRMSNorm |
|
from .config import LlamaConfig |
|
|
|
class CPLinear(nn.Module):
    """Input-dependent low-rank projection that produces q, k and v.

    For each of q, k and v two bias-free linear maps are applied to every
    token: one yields a ``(n_head, rank)`` factor, the other a
    ``(rank, head_dim)`` factor.  The per-token result is the matrix product
    of the two factors divided by ``rank``, i.e. for q::

        q[b, s, h, :] = (1 / q_rank) * sum_r A_q[b, s, h, r] * B_q[b, s, r, :]

    Args:
        in_features: Size of the last dimension of the input.
        n_head: Number of heads in each output.
        head_dim: Per-head feature size of each output.
        kv_rank: Rank of the factorisation used for k and v.
        q_rank: Rank of the factorisation used for q.
    """

    def __init__(self, in_features, n_head, head_dim, kv_rank=2, q_rank=6):
        super().__init__()

        self.n_head = n_head
        self.q_rank = q_rank
        self.head_dim = head_dim
        self.kv_rank = kv_rank

        # Layers are created in this exact order so that the random-number
        # stream consumed during construction (and therefore the seeded
        # initial weights) is unchanged.
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_k = nn.Linear(in_features, kv_rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_v = nn.Linear(in_features, kv_rank * head_dim, bias=False)

        # Xavier-uniform overrides the default nn.Linear initialisation.
        for proj in (
            self.W_A_q,
            self.W_B_q,
            self.W_A_k,
            self.W_B_k,
            self.W_A_v,
            self.W_B_v,
        ):
            nn.init.xavier_uniform_(proj.weight)

    def _contract(self, head_proj, dim_proj, x, rank):
        """Return the rank-averaged product of the two factors computed from ``x``."""
        batch_size, seq_len, _ = x.size()
        n_tokens = batch_size * seq_len

        # One (n_head, rank) @ (rank, head_dim) product per token.
        head_factor = head_proj(x).view(n_tokens, self.n_head, rank)
        dim_factor = dim_proj(x).view(n_tokens, rank, self.head_dim)

        # The bmm result is a fresh tensor, so scaling it in place is safe.
        mixed = torch.bmm(head_factor, dim_factor).div_(rank)
        return mixed.view(batch_size, seq_len, self.n_head, self.head_dim)

    def forward(self, x):
        """Project ``x`` to query, key and value tensors.

        Args:
            x: Tensor with three dimensions ``(batch, seq_len, in_features)``.

        Returns:
            Tuple ``(q, k, v)``; each tensor has shape
            ``(batch, seq_len, n_head, head_dim)``.
        """
        q = self._contract(self.W_A_q, self.W_B_q, x, self.q_rank)
        k = self._contract(self.W_A_k, self.W_B_k, x, self.kv_rank)
        v = self._contract(self.W_A_v, self.W_B_v, x, self.kv_rank)
        return q, k, v
|
|
|
class CausalTensorProductSelfAttn(nn.Module):
    """Self-attention layer whose q/k/v come from a ``CPLinear`` projection.

    The forward pass projects the hidden states with ``CPLinear``, applies
    rotary position embeddings to q and k through ``LigerRopeFunction``, runs
    ``flash_attn_func``, normalises each head's output with ``LlamaRMSNorm``
    and finally mixes the heads with a bias-free output projection.

    Args:
        config: Object providing ``num_attention_heads``, ``hidden_size``,
            ``max_position_embeddings`` and ``rope_theta``; an optional
            ``using_groupnorm`` attribute is also read (default ``False``).
        kv_rank: Factorisation rank used for k and v.
        q_rank: Factorisation rank used for q.
    """

    def __init__(self, config, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.n_embd = config.hidden_size
        self.rank = kv_rank
        self.q_rank = q_rank
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.rank, self.q_rank)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=False)

        # Build the rotary tables once and register both halves; they are
        # non-persistent, so they are rebuilt here rather than loaded from a
        # checkpoint.
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.o_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

        # NOTE(review): this flag is stored but not consulted anywhere in this
        # class -- `subln` is created and applied unconditionally. Confirm
        # whether the norm was meant to be gated by it.
        self.using_groupnorm = getattr(config, 'using_groupnorm', False)
        self.subln = LlamaRMSNorm(self.head_dim, eps=1e-5)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        """Build cosine and sine lookup tables for rotary position embeddings.

        Args:
            max_position_embeddings: Number of positions to tabulate.
            head_dim: Per-head feature size; frequencies are generated for
                every second index in ``range(0, head_dim, 2)``.
            base: Base of the geometric frequency progression.
            dtype: Optional dtype the tables are cast to; ``None`` keeps the
                float32 result.
            device: Device on which the tables are created.

        Returns:
            Tuple ``(cos, sin)``, each with a leading singleton dimension
            followed by ``max_position_embeddings`` rows whose last dimension
            is the frequency table concatenated with itself.
        """
        exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)

        # Outer product: one angle per (position, frequency) pair.
        freqs = torch.einsum("i,j->ij", positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos()
        sin = emb.sin()
        if dtype is not None:
            cos = cos.to(dtype)
            sin = sin.to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Apply tensor-product self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor with three dimensions
                ``(batch, seq_len, hidden)``.
            attention_mask: Only tested for ``None``; see the review note in
                the body.
            position_ids: Optional indices into the rotary tables. When
                omitted, ``0..seq_len-1`` is used for every batch row.

        Returns:
            The result of the output projection applied to the flattened,
            normalised attention heads.
        """
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q, k, v = self.c_qkv(hidden_states)

        # Gather the table rows for the requested positions, then drop the
        # leading singleton dimension of the cached tables.
        cos = self.cos_cached[:, position_ids].squeeze(0)
        sin = self.sin_cached[:, position_ids].squeeze(0)

        # assumes the project-local LigerRopeFunction accepts q/k laid out as
        # (batch, seq, head, dim), which is what CPLinear returns -- TODO confirm
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # NOTE(review): `attention_mask` is never handed to the attention
        # kernel; supplying one only switches causal masking off. Confirm that
        # this is the intended contract for callers that pass a mask.
        attn_out = flash_attn_func(
            q,
            k,
            v,
            dropout_p=0.0,
            causal=attention_mask is None,
        )

        # Per-head normalisation over the last dimension before merging heads.
        attn_out = self.subln(attn_out)

        attn_out = rearrange(attn_out, "b s h d -> b s (h d)")
        return self.o_proj(attn_out)