|
from flash_attn import flash_attn_func |
|
from typing import Optional, Tuple |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
from .liger_rope import LigerRopeFunction |
|
from .config import LlamaConfig |
|
|
|
class LlamaAttention(nn.Module): |
|
def __init__(self, config: LlamaConfig): |
|
super().__init__() |
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.head_dim = config.hidden_size // config.num_attention_heads |
|
self.max_position_embeddings = config.max_position_embeddings |
|
self.rope_theta = config.rope_theta |
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size: |
|
raise ValueError( |
|
f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}" |
|
f" and `num_attention_heads`: {self.num_heads})." |
|
) |
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
|
self.register_buffer( |
|
"cos_cached", |
|
self._compute_rope_embeddings( |
|
self.max_position_embeddings, |
|
self.head_dim, |
|
self.rope_theta, |
|
dtype=torch.float32, |
|
device=self.q_proj.weight.device, |
|
)[0], |
|
persistent=False, |
|
) |
|
self.register_buffer( |
|
"sin_cached", |
|
self._compute_rope_embeddings( |
|
self.max_position_embeddings, |
|
self.head_dim, |
|
self.rope_theta, |
|
dtype=torch.float32, |
|
device=self.q_proj.weight.device, |
|
)[1], |
|
persistent=False, |
|
) |
|
|
|
def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None): |
|
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) |
|
t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32) |
|
freqs = torch.einsum("i,j->ij", t, inv_freq) |
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
cos = emb.cos().to(dtype) |
|
sin = emb.sin().to(dtype) |
|
return cos.unsqueeze(0), sin.unsqueeze(0) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
) -> torch.Tensor: |
|
|
|
bsz, seq_len, _ = hidden_states.size() |
|
|
|
if position_ids is None: |
|
position_ids = torch.arange(seq_len, device=hidden_states.device) |
|
position_ids = repeat(position_ids, 'l -> b l', b=bsz) |
|
|
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim) |
|
key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim) |
|
value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim) |
|
|
|
|
|
cos = self.cos_cached[:, position_ids] |
|
sin = self.sin_cached[:, position_ids] |
|
|
|
query_states, key_states = LigerRopeFunction.apply( |
|
query_states, |
|
key_states, |
|
cos.squeeze(0), |
|
sin.squeeze(0), |
|
position_ids |
|
) |
|
|
|
attn_output = flash_attn_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
dropout_p=0.0, |
|
causal=attention_mask is None |
|
) |
|
|
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)") |
|
return self.o_proj(attn_output) |