import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
from einops import rearrange, repeat
from flash_attn import flash_attn_func
from .liger_rope import LigerRopeFunction
from .rms_norm import LlamaRMSNorm
from .config import LlamaConfig


class CPLinear(nn.Module):
    """Factorized (tensor-product) projection producing per-head q, k and v.

    For every token, each of q, k and v is the product of a per-head factor
    ``A`` of shape ``(n_head, rank)`` and a head-shared factor ``B`` of shape
    ``(rank, head_dim)``, divided by ``rank``. Both factors are bias-free
    linear functions of the input token.

    Args:
        in_features: Size of the last dimension of the input.
        n_head: Number of attention heads.
        head_dim: Per-head feature size.
        kv_rank: Factorization rank used for k and v.
        q_rank: Factorization rank used for q.
    """

    def __init__(self, in_features, n_head, head_dim, kv_rank=2, q_rank=6):
        super().__init__()
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_k = nn.Linear(in_features, kv_rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_v = nn.Linear(in_features, kv_rank * head_dim, bias=False)

        # Same re-initialisation order as before, so RNG consumption is unchanged.
        for proj in (self.W_A_q, self.W_B_q, self.W_A_k, self.W_B_k, self.W_A_v, self.W_B_v):
            nn.init.xavier_uniform_(proj.weight)

        self.n_head = n_head
        self.q_rank = q_rank
        self.head_dim = head_dim
        self.kv_rank = kv_rank

    def _factorized(self, x, proj_a, proj_b, rank):
        """Return ``(A @ B) / rank`` for every token, shaped (batch, seq, n_head, head_dim)."""
        batch_size, seq_len, _ = x.size()
        tokens = batch_size * seq_len
        # Fold batch and sequence into the bmm batch dimension: one
        # (n_head, rank) @ (rank, head_dim) product per token.
        a = proj_a(x).view(tokens, self.n_head, rank)
        b = proj_b(x).view(tokens, rank, self.head_dim)
        return torch.bmm(a, b).div_(rank).view(batch_size, seq_len, self.n_head, self.head_dim)

    def forward(self, x):
        """Project ``x`` of shape (batch, seq, in_features) to ``(q, k, v)``.

        Returns:
            Three tensors, each of shape (batch, seq, n_head, head_dim).
        """
        q = self._factorized(x, self.W_A_q, self.W_B_q, self.q_rank)
        k = self._factorized(x, self.W_A_k, self.W_B_k, self.kv_rank)
        v = self._factorized(x, self.W_A_v, self.W_B_v, self.kv_rank)
        return q, k, v


class CausalTensorProductSelfAttn(nn.Module):
    """Self-attention whose q/k/v come from a factorized ``CPLinear`` projection.

    Pipeline: ``CPLinear`` -> RoPE on q and k (``LigerRopeFunction``) ->
    ``flash_attn_func`` -> per-head ``LlamaRMSNorm`` -> merge heads -> ``o_proj``.

    Args:
        config: Model config. Reads ``num_attention_heads``, ``hidden_size``,
            ``max_position_embeddings``, ``rope_theta`` and, optionally,
            ``using_groupnorm``.
        kv_rank: Factorization rank for k and v.
        q_rank: Factorization rank for q.
    """

    def __init__(self, config, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.n_embd = config.hidden_size
        self.rank = kv_rank
        self.q_rank = q_rank
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.rank, self.q_rank)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=False)

        # Build the RoPE tables once and register both halves. Previously the
        # full table computation ran twice, once per buffer.
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.o_proj.weight.device,
        )
        # persistent=False: the tables are rebuilt at construction and are not
        # written to the state_dict.
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

        # NOTE(review): stored but never read in this class; subln below is
        # applied unconditionally in forward. Confirm it is consumed elsewhere.
        self.using_groupnorm = getattr(config, 'using_groupnorm', False)
        self.subln = LlamaRMSNorm(self.head_dim, eps=1e-5)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        """Build RoPE cos/sin tables.

        Frequencies are ``base ** (-2i / head_dim)`` for ``i`` in
        ``range(ceil(head_dim / 2))``. They are duplicated along the last
        dimension, so each table has shape (1, max_position_embeddings, head_dim)
        when ``head_dim`` is even.

        Args:
            max_position_embeddings: Number of positions (rows) to precompute.
            head_dim: Per-head feature size.
            base: RoPE base (theta).
            dtype: Output dtype. ``None`` keeps the float32 computation dtype.
            device: Device to build the tables on.

        Returns:
            ``(cos, sin)``, each with a leading singleton dimension.
        """
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        positions = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos(), emb.sin()
        if dtype is not None:
            cos, sin = cos.to(dtype), sin.to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states`` of shape (batch, seq, hidden_size).

        Args:
            hidden_states: Input activations, (batch, seq, hidden_size).
            attention_mask: Only its presence is inspected. ``None`` selects
                causal attention; any tensor selects non-causal attention. Its
                values are never used.
            position_ids: Positions that index the RoPE tables, shaped
                (batch, seq) or (seq,). A 1-D tensor is repeated over the batch.
                Defaults to ``0 .. seq-1``.

        Returns:
            Tensor of shape (batch, seq, hidden_size).
        """
        bsz, seq_len, _ = hidden_states.size()
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
        if position_ids.dim() == 1:
            # Repeat a 1-D tensor over the batch so the gathered RoPE tables
            # below are (bsz, seq_len, head_dim).
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q, k, v = self.c_qkv(hidden_states)  # each (bsz, seq_len, n_head, head_dim)

        # Gather one table row per position: (position_ids.shape..., head_dim).
        cos = self.cos_cached[0, position_ids]
        sin = self.sin_cached[0, position_ids]

        # NOTE(review): upstream Liger documents LigerRopeFunction's q/k as
        # (bsz, n_head, seq_len, head_dim). CPLinear yields
        # (bsz, seq_len, n_head, head_dim). Confirm the local .liger_rope copy
        # expects the layout passed here.
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # NOTE(review): flash_attn_func receives no mask, so attention_mask only
        # toggles causality. Passing any mask, even all ones, makes this layer
        # bidirectional. Confirm callers rely on that.
        attn_out = flash_attn_func(q, k, v, dropout_p=0.0, causal=attention_mask is None)

        # Per-head norm (LlamaRMSNorm sized head_dim) before merging heads.
        attn_out = self.subln(attn_out)
        attn_out = rearrange(attn_out, "b s h d -> b s (h d)")
        return self.o_proj(attn_out)