from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import flash_attn_func

from .config import LlamaConfig
from .liger_rope import LigerRopeFunction


class LlamaAttention(nn.Module):
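    """Llama-style self-attention: grouped-query key/value heads, rotary position
    embeddings applied with the Liger RoPE kernel, and FlashAttention for the
    attention computation itself.
    """
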
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
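        # Number of query heads that share each key/value head (grouped-query attention).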
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_attention_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
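        # Precompute RoPE cos/sin caches for all positions up to max_position_embeddings;
        # persistent=False keeps these tables out of the saved state dict.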
        cos_cached, sin_cached = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
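        # Standard RoPE tables: inverse frequencies over even channel indices, outer
        # product with positions, duplicated across the two half-dimensions, then
        # cos/sin with a leading broadcast dimension -> shapes [1, max_pos, head_dim].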
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
return cos.unsqueeze(0), sin.unsqueeze(0)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
        # hidden_states: [batch, seq_len, hidden_size] i.e. [b, s, h * d]
bsz, seq_len, _ = hidden_states.size()
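        # Default to contiguous positions 0..seq_len-1, broadcast across the batch.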
if position_ids is None:
position_ids = torch.arange(seq_len, device=hidden_states.device)
position_ids = repeat(position_ids, 'l -> b l', b=bsz)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
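        # Split the fused head dimension: [b, s, h*d] -> [b, s, h, d], the layout flash-attn expects.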
query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim)
key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)
value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)
# Slice off position specific rope freqs from the cached freqs
cos = self.cos_cached[:, position_ids] # [1, bsz, seq_len, dim]
sin = self.sin_cached[:, position_ids] # [1, bsz, seq_len, dim]
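        # Rotate queries and keys with the fused Liger RoPE kernel.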
query_states, key_states = LigerRopeFunction.apply(
query_states,
key_states,
cos.squeeze(0),
sin.squeeze(0),
position_ids
)
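        # flash_attn_func supports grouped-query attention directly when the KV tensors
        # have fewer heads than the queries; it takes no padding mask, so attention is
        # run causally whenever no attention_mask is supplied.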
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=0.0,
            causal=(attention_mask is None)
)
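        # Merge heads back: [b, s, h, d] -> [b, s, h*d], then project to hidden_size.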
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
return self.o_proj(attn_output)