import math

import torch

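# Positional-encoding variants collected in this module: fixed sinusoidal
# embeddings, rotary embeddings (RoPE) with their apply_* helpers, and the
# ALiBi relative-position bias.
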
class SinusoidalPositionalEmbedding(torch.nn.Module):
    """Fixed sinusoidal positional embeddings."""

    def __init__(self, dim, base=10000, precision=torch.half):
        super().__init__()
        # One inverse frequency per pair of embedding dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.precision = precision

    def forward(self, x, seq_dim=1):
        # Positions 0..seq_len-1, taken from the sequence dimension of x.
        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
        # Outer product: one row of phase angles per position.
        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
        if self.precision == torch.bfloat16:
            # Compute sin/cos in fp32 for accuracy, then cast back below.
            sinusoid_inp = sinusoid_inp.float()
        sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
        if self.precision == torch.bfloat16:
            sin, cos = sin.bfloat16(), cos.bfloat16()
        # [1, seq_len, dim]: broadcastable over the batch dimension.
        emb = torch.cat((sin, cos), dim=-1)
        return emb[None, :, :]

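# Illustrative usage sketch (not part of the original file; the sizes are
# made up). The returned table broadcasts against token embeddings laid out
# as [batch, seq_len, dim]:
#
#   pos_emb = SinusoidalPositionalEmbedding(dim=64)
#   tokens = torch.randn(2, 128, 64)        # [batch, seq_len, dim]
#   hidden = tokens + pos_emb(tokens)       # pos_emb(tokens) is [1, 128, 64]
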
class RotaryEmbedding(torch.nn.Module):
    """Rotary positional embeddings (RoPE) with precomputed cos/sin caches."""

    def __init__(
        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
    ):
        super().__init__()
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.precision = precision
        self.max_seq_len = max_seq_len
        self.base = base
        self.dim = dim

        # Precompute inverse frequencies and the cos/sin tables for the full
        # max_seq_len once, up front.
        cos_cached, sin_cached, inv_freq = self._prepare_cache(
            max_seq_len, precision, base
        )

        self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs)
        self.cos_cached = cos_cached
        self.sin_cached = sin_cached

    def _prepare_cache(self, seq_len, precision, base):
        # One inverse frequency per pair of rotary dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))

        t = torch.arange(seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the angles so channels i and i + dim/2 share a frequency.
        emb = torch.cat((freqs, freqs), dim=-1)

        # [seq_len, 1, 1, dim]: broadcastable over batch and head dimensions.
        cos_cached = emb.cos()[:, None, None, :]
        sin_cached = emb.sin()[:, None, None, :]

        return (
            cos_cached.to(precision),
            sin_cached.to(precision),
            inv_freq.to(precision),
        )

    def forward(self, x, seq_dim=0, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]

        assert seq_len <= self.max_seq_len

        if seq_len != self.max_seq_len:
            # Slice the precomputed cache down to the current sequence length.
            return (
                self.cos_cached[:seq_len, ...].to(x.device),
                self.sin_cached[:seq_len, ...].to(x.device),
            )
        else:
            return self.cos_cached.to(x.device), self.sin_cached.to(x.device)

def rotate_half(x):
    """Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # equivalent to dim=-1

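# The apply_* helpers below compute
#     q_rot = (q * cos) + (rotate_half(q) * sin)
# Because cos/sin were built from torch.cat((freqs, freqs), dim=-1), this
# rotates each channel pair (q[..., i], q[..., i + dim/2]) by the angle
# position * inv_freq[i], i.e. the "rotate-half" RoPE layout.
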
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    # Select the cos/sin rows covering q's positions; `offset` shifts the
    # window, e.g. when earlier positions already sit in a key/value cache.
    cos, sin = (
        cos[offset : q.shape[0] + offset, ...],
        sin[offset : q.shape[0] + offset, ...],
    )
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0):
    # Eager (non-TorchScript) variant of apply_rotary_pos_emb.
    cos, sin = (
        cos[offset : q.shape[0] + offset, ...],
        sin[offset : q.shape[0] + offset, ...],
    )
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

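# Illustrative usage sketch (not part of the original file; the layout
# [seq_len, batch, heads, head_dim] and all sizes below are assumptions):
#
#   rotary = RotaryEmbedding(dim=64, max_seq_len=2048, precision=torch.float32)
#   q = torch.randn(128, 2, 4, 64)
#   k = torch.randn(128, 2, 4, 64)
#   cos, sin = rotary(q, seq_dim=0)                 # each [128, 1, 1, 64]
#   q_rot, k_rot = apply_rotary_pos_emb_torch(q, k, cos, sin)
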
class AliBi(torch.nn.Module):
    """ALiBi attention bias, with slopes sliced across model-parallel ranks."""

    def __init__(self, num_heads, mp_size=1, mp_rank=1):
        super().__init__()
        # Model parallelism splits attention heads across ranks, so keep only
        # the slopes for this rank's slice of heads.
        assert mp_size <= num_heads and mp_rank <= mp_size
        self.mp_size = mp_size
        self.mp_rank = mp_rank
        self.num_heads = num_heads
        self.slice_size = num_heads // mp_size
        self.cached_matrix = None
        self.cached_seq_len = None
        slopes = torch.Tensor(self._get_slopes(num_heads))[
            mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size
        ]
        self.register_buffer("slopes", slopes)

    def _get_slopes(self, n):
        """
        Get slopes for the ALiBi positional bias.

        n : int = number of heads.
        For best performance, restrict n to a power of 2.
        """

        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            # For a non-power-of-two head count, take the slopes for the
            # nearest smaller power of two and append every other slope from
            # the next power of two, up to n slopes in total.
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : n - closest_power_of_2
                ]
            )

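    # Example: for n = 8 heads, _get_slopes returns 1/2, 1/4, ..., 1/256
    # (a geometric sequence with ratio 1/2), as in the ALiBi paper.
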
    def bias(self, seq_len_q, seq_len_k, device, dtype):
        # Build (or reuse) the cached per-head bias. The cache grows
        # geometrically, and always to at least the requested key length, so
        # incremental decoding rarely has to rebuild the matrix.
        if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
            a = self.cached_matrix
        else:
            target_seq_len = (
                seq_len_k
                if self.cached_seq_len is None
                else max(self.cached_seq_len * 4, seq_len_k)
            )
            # Distance matrix: entry (i, j) = -(i - j) for j <= i, 0 otherwise.
            a = -torch.tril(
                torch.arange(target_seq_len)
                .view(target_seq_len, 1)
                .repeat(1, target_seq_len)
                + torch.arange(0, -target_seq_len, -1)
            )
            a = a.to(device).to(dtype)
            slopes = self.slopes.to(a.device).to(a.dtype)
            a = a * slopes.view(self.slopes.shape[0], 1, 1)  # [heads, sk, sk]
            self.cached_seq_len = target_seq_len
            self.cached_matrix = a

        if self.cached_seq_len > seq_len_k:
            a = self.cached_matrix[:, :seq_len_k, :seq_len_k]

        if seq_len_q != seq_len_k:
            # During training sq == sk. At inference time with a key/value
            # cache, only the newest query token is scored, so return the
            # single bias row for the last key position.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
            a = a[:, seq_len_k - 1, :].view(a.shape[0], 1, a.shape[2])

        return a

    def forward(self, x):
        # x: attention scores of shape [batch, heads, seq_len_q, seq_len_k].
        seq_len_q = x.shape[-2]
        seq_len_k = x.shape[-1]

        # Same caching scheme as bias(): reuse the cached matrix when it is
        # large enough, otherwise rebuild it with geometric growth.
        if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
            a = self.cached_matrix
        else:
            target_seq_len = (
                seq_len_k
                if self.cached_seq_len is None
                else max(self.cached_seq_len * 4, seq_len_k)
            )
            a = -torch.tril(
                torch.arange(target_seq_len)
                .view(target_seq_len, 1)
                .repeat(1, target_seq_len)
                + torch.arange(0, -target_seq_len, -1)
            )
            a = a.to(x.device).to(x.dtype)
            slopes = self.slopes.to(a.device).to(a.dtype)
            a = a * slopes.view(self.slopes.shape[0], 1, 1)  # [heads, sk, sk]
            self.cached_seq_len = target_seq_len
            self.cached_matrix = a

        if self.cached_seq_len > seq_len_k:
            a = self.cached_matrix[:, :seq_len_k, :seq_len_k]

        if seq_len_q != seq_len_k:
            # See bias(): with a key/value cache only the newest query token
            # is scored, so take the single bias row for the last key position.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
            a = a[:, seq_len_k - 1, :].view(a.shape[0], 1, a.shape[2])

        return x + a
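# Illustrative usage sketch (not part of the original file; the head count,
# rank, and sizes are made-up examples). AliBi.forward adds the bias directly
# to raw attention scores of shape [batch, heads, seq_len_q, seq_len_k]; note
# that mp_rank is used as a 0-based index into the slope slices.
#
#   alibi = AliBi(num_heads=8, mp_size=1, mp_rank=0)
#   scores = torch.randn(2, 8, 128, 128)    # [b, np, sq, sk]
#   scores = alibi(scores)                  # the bias broadcasts over batch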