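"""Causal tensor-product self-attention.

Q, K and V come from a CP-style low-rank factorization (CPLinear below) rather
than dense per-head projections; rotary position embeddings are applied with a
Liger kernel and the attention itself runs through flash-attn.
"""
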
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
from einops import rearrange, repeat
from flash_attn import flash_attn_func
from .liger_rope import LigerRopeFunction
from .rms_norm import LlamaRMSNorm
from .config import LlamaConfig


class CPLinear(nn.Module):
    """CP-factorized projection producing Q, K and V.

    Each projection is a low-rank contraction of two factors: A gives per-head
    rank coefficients and B gives shared rank-by-head_dim bases; the two are
    combined with a batched matmul and averaged over the rank.
    """

    def __init__(self, in_features, n_head, head_dim, kv_rank=2, q_rank=6):
        super().__init__()
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_k = nn.Linear(in_features, kv_rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_v = nn.Linear(in_features, kv_rank * head_dim, bias=False)

        nn.init.xavier_uniform_(self.W_A_q.weight)
        nn.init.xavier_uniform_(self.W_B_q.weight)
        nn.init.xavier_uniform_(self.W_A_k.weight)
        nn.init.xavier_uniform_(self.W_B_k.weight)
        nn.init.xavier_uniform_(self.W_A_v.weight)
        nn.init.xavier_uniform_(self.W_B_v.weight)

        self.n_head = n_head
        self.q_rank = q_rank
        self.head_dim = head_dim
        self.kv_rank = kv_rank

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Per-head rank coefficients: B S (H R) -> B S H R
        A_q = self.W_A_q(x).view(batch_size, seq_len, self.n_head, self.q_rank)
        A_k = self.W_A_k(x).view(batch_size, seq_len, self.n_head, self.kv_rank)
        A_v = self.W_A_v(x).view(batch_size, seq_len, self.n_head, self.kv_rank)

        # Shared rank bases: B S (R D) -> B S R D
        B_q = self.W_B_q(x).view(batch_size, seq_len, self.q_rank, self.head_dim)
        B_k = self.W_B_k(x).view(batch_size, seq_len, self.kv_rank, self.head_dim)
        B_v = self.W_B_v(x).view(batch_size, seq_len, self.kv_rank, self.head_dim)

        # Flatten batch and sequence so each contraction is a single bmm.
        A_q = A_q.view(batch_size * seq_len, self.n_head, self.q_rank)
        A_k = A_k.view(batch_size * seq_len, self.n_head, self.kv_rank)
        A_v = A_v.view(batch_size * seq_len, self.n_head, self.kv_rank)
        B_q = B_q.view(batch_size * seq_len, self.q_rank, self.head_dim)
        B_k = B_k.view(batch_size * seq_len, self.kv_rank, self.head_dim)
        B_v = B_v.view(batch_size * seq_len, self.kv_rank, self.head_dim)

        # (BS, H, R) @ (BS, R, D) -> (BS, H, D), averaged over the rank.
        q = torch.bmm(A_q, B_q).div_(self.q_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = torch.bmm(A_k, B_k).div_(self.kv_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = torch.bmm(A_v, B_v).div_(self.kv_rank).view(batch_size, seq_len, self.n_head, self.head_dim)

        return q, k, v


class CausalTensorProductSelfAttn(nn.Module):
    """Causal self-attention whose Q/K/V come from the CP-factorized CPLinear,
    with rotary position embeddings and a flash-attention kernel."""

    def __init__(self, config, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.n_embd = config.hidden_size
        self.rank = kv_rank
        self.q_rank = q_rank
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.rank, self.q_rank)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=False)

        # Precompute the RoPE tables once and cache them as non-persistent buffers.
        cos_cached, sin_cached = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.o_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

        # Read from the config for compatibility; the per-head RMSNorm below is always applied.
        self.using_groupnorm = getattr(config, 'using_groupnorm', False)
        self.subln = LlamaRMSNorm(self.head_dim, eps=1e-5)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        # Standard RoPE tables: frequencies over half the head dim, duplicated to full width.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        # Shapes: [1, max_position_embeddings, head_dim]
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # In: B S (H D)
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        # Factorized projections: B S (H D) -> B S H D for each of q, k, v.
        q, k, v = self.c_qkv(hidden_states)

        # Gather the cached RoPE tables for the requested positions.
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        q, k = LigerRopeFunction.apply(
            q,
            k,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids
        )

        # flash_attn_func takes no dense mask; causal masking is enabled only
        # when no attention_mask is supplied (the mask itself is not consumed here).
        attn_out = flash_attn_func(
            q,
            k,
            v,
            dropout_p=0.0,
            causal=attention_mask is None
        )

        # Per-head RMSNorm on the attention output, then merge heads and project out.
        attn_out = self.subln(attn_out)
        attn_out = rearrange(attn_out, "b s h d -> b s (h d)")
        attn_out = self.o_proj(attn_out)
        return attn_out
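

if __name__ == "__main__":
    # Minimal smoke-test sketch: it assumes a CUDA device with flash-attn
    # installed plus the sibling liger_rope / rms_norm modules, and must be run
    # as a module (python -m <package>.<this_file>) because of the relative
    # imports above. The config below carries only the attributes this file
    # actually reads.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        num_attention_heads=8,
        hidden_size=512,
        max_position_embeddings=2048,
        rope_theta=10000.0,
    )
    attn = CausalTensorProductSelfAttn(cfg).to(device="cuda", dtype=torch.bfloat16)
    x = torch.randn(2, 128, cfg.hidden_size, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        out = attn(x)
    print(out.shape)  # expected: torch.Size([2, 128, 512])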