from typing import Tuple, Union

import torch


class FusedRoPEFunc(torch.autograd.Function):
    """
    Fused RoPE function.

    This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
    of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid expensive
    `.contiguous()` calls, so it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        transpose_output_memory: bool = False,
    ) -> torch.Tensor:
        # Compiled CUDA extension providing the fused RoPE kernels.
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward(
            t, freqs, transpose_output_memory
        )
        ctx.save_for_backward(freqs)
        ctx.transpose_output_memory = transpose_output_memory

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        (freqs,) = ctx.saved_tensors
        grad_input = fused_rotary_positional_embedding.backward(
            grad_output, freqs, ctx.transpose_output_memory
        )

        # No gradients for `freqs` or `transpose_output_memory`.
        return grad_input, None, None


def fused_apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    transpose_output_memory: bool = False,
) -> torch.Tensor:
    """Apply rotary positional embedding to the input tensor `t`.

    Args:
        t (Tensor): Input tensor of shape [s, b, h, d]
        freqs (Tensor): Rotary positional embedding tensor of shape [s, 1, 1, d] and
            `float` dtype
        transpose_output_memory (bool): Defaults to False. Whether to transpose the 's' and 'b'
            dimensions of the output's underlying memory format. This is very helpful when you
            want to get a contiguous tensor after calling `output.transpose(0, 1)`.

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    return FusedRoPEFunc.apply(t, freqs, transpose_output_memory)
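

# A minimal usage sketch (illustration only, not part of the original module). It assumes a
# CUDA device and a built `fused_rotary_positional_embedding` extension; the sizes and dtypes
# below are arbitrary. It shows the expected `sbhd` input layout, the [s, 1, 1, d] freqs
# layout, and the effect of `transpose_output_memory` described in the docstring above.
def _example_fused_apply_rotary_pos_emb() -> None:
    s, b, h, d = 512, 2, 8, 64
    t = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda", requires_grad=True)
    freqs = torch.rand(s, 1, 1, d, dtype=torch.float32, device="cuda")
    out = fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
    # With transpose_output_memory=True, swapping 's' and 'b' yields a contiguous tensor.
    assert out.transpose(0, 1).is_contiguous()
    # The op is differentiable; gradients flow back to `t` through the fused backward kernel.
    out.sum().backward()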


class FusedRoPECachedFunc(torch.autograd.Function):
    """
    Fused RoPE function with cached cosine and sine tensors.

    This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
    of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid expensive
    `.contiguous()` calls, so it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        cos_: torch.Tensor,
        sin_: torch.Tensor,
        transpose_output_memory: bool = False,
    ) -> torch.Tensor:
        # Compiled CUDA extension providing the fused RoPE kernels.
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward_cached(
            t, cos_, sin_, transpose_output_memory
        )
        ctx.save_for_backward(cos_, sin_)
        ctx.transpose_output_memory = transpose_output_memory

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        cos_, sin_ = ctx.saved_tensors
        grad_input = fused_rotary_positional_embedding.backward_cached(
            grad_output, cos_, sin_, ctx.transpose_output_memory
        )

        # No gradients for `cos_`, `sin_`, or `transpose_output_memory`.
        return grad_input, None, None, None


def fused_apply_rotary_pos_emb_cached(
    t: torch.Tensor,
    cos_: torch.Tensor,
    sin_: torch.Tensor,
    transpose_output_memory: bool = False,
) -> torch.Tensor:
    """Apply rotary positional embedding to the input tensor `t`, using cached cos/sin tensors.

    Args:
        t (Tensor): Input tensor of shape [s, b, h, d]
        cos_ (Tensor): Cached cosine of the rotary positional embedding tensor, of
            shape [s, 1, 1, d] and dtype either `float` or the same as `t`.
        sin_ (Tensor): Cached sine of the rotary positional embedding tensor, of
            shape [s, 1, 1, d] and dtype either `float` or the same as `t`.
        transpose_output_memory (bool): Defaults to False. Whether to transpose the 's' and 'b'
            dimensions of the output's underlying memory format. This is very helpful when you
            want to get a contiguous tensor after calling `output.transpose(0, 1)`.

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    return FusedRoPECachedFunc.apply(t, cos_, sin_, transpose_output_memory)
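

# A minimal usage sketch for the cached variant (illustration only, not part of the original
# module). It assumes a CUDA device and a built `fused_rotary_positional_embedding` extension;
# the sizes are arbitrary, and deriving `cos_`/`sin_` directly from a `freqs` tensor is one
# plausible way to produce the cached tensors, not necessarily how callers build them.
def _example_fused_apply_rotary_pos_emb_cached() -> None:
    s, b, h, d = 512, 2, 8, 64
    t = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda", requires_grad=True)
    freqs = torch.rand(s, 1, 1, d, dtype=torch.float32, device="cuda")
    # Precompute cos/sin once so they can be reused across layers and iterations.
    cos_, sin_ = torch.cos(freqs), torch.sin(freqs)
    out = fused_apply_rotary_pos_emb_cached(t, cos_, sin_)
    out.sum().backward()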