# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

import torch


class FusedRoPEFunc(torch.autograd.Function):
    """
    Fused RoPE function.

    This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
    of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive
    `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        transpose_output_memory: bool = False,
    ) -> torch.Tensor:
        # Imported lazily so the fused CUDA extension is only required when
        # this code path is actually used.
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward(
            t, freqs, transpose_output_memory
        )
        ctx.save_for_backward(freqs)
        ctx.transpose_output_memory = transpose_output_memory

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        (freqs,) = ctx.saved_tensors
        grad_input = fused_rotary_positional_embedding.backward(
            grad_output, freqs, ctx.transpose_output_memory
        )

        # No gradients for `freqs` and `transpose_output_memory`.
        return grad_input, None, None


def fused_apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    transpose_output_memory: bool = False,
) -> torch.Tensor:
    """Apply rotary positional embedding to the input tensor `t`.

    Args:
        t (Tensor): Input tensor of shape [s, b, h, d].
        freqs (Tensor): Rotary positional embedding tensor of shape [s, 1, 1, d] and
            `float` dtype.
        transpose_output_memory (bool): Defaults to False. Whether to transpose the 's' and 'b'
            dimensions of the output's underlying memory format. This is helpful when you want
            a contiguous tensor after calling `output.transpose(0, 1)`.

    Returns:
        Tensor: The input tensor after applying RoPE.
    """
    return FusedRoPEFunc.apply(t, freqs, transpose_output_memory)
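
# Example (illustrative sketch only, kept as comments so importing this module
# stays side-effect free): the shapes, the base of 10000, and the way `freqs`
# is built below are assumptions for demonstration; the
# `fused_rotary_positional_embedding` CUDA extension must be compiled and a
# CUDA device available.
#
#   s, b, h, d = 2048, 2, 16, 64
#   t = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")
#   inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, device="cuda").float() / d))
#   angles = torch.outer(torch.arange(s, device="cuda").float(), inv_freq)
#   freqs = torch.cat((angles, angles), dim=-1).view(s, 1, 1, d)
#   out = fused_apply_rotary_pos_emb(t, freqs)   # out has shape [s, b, h, d]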


class FusedRoPECachedFunc(torch.autograd.Function):
    """
    Fused RoPE function (cached sin/cos variant).

    This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
    of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive
    `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        cos_: torch.Tensor,
        sin_: torch.Tensor,
        transpose_output_memory: bool = False,
    ) -> torch.Tensor:
        # Imported lazily so the fused CUDA extension is only required when
        # this code path is actually used.
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward_cached(
            t, cos_, sin_, transpose_output_memory
        )
        ctx.save_for_backward(cos_, sin_)
        ctx.transpose_output_memory = transpose_output_memory

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        cos_, sin_ = ctx.saved_tensors
        grad_input = fused_rotary_positional_embedding.backward_cached(
            grad_output, cos_, sin_, ctx.transpose_output_memory
        )

        # No gradients for `cos_`, `sin_`, and `transpose_output_memory`.
        return grad_input, None, None, None


def fused_apply_rotary_pos_emb_cached(
    t: torch.Tensor,
    cos_: torch.Tensor,
    sin_: torch.Tensor,
    transpose_output_memory: bool = False,
) -> torch.Tensor:
    """Apply rotary positional embedding to the input tensor `t`, using cached sin/cos.

    Args:
        t (Tensor): Input tensor of shape [s, b, h, d].
        cos_ (Tensor): Cached cosine of the rotary positional embedding tensor, of
            shape [s, 1, 1, d] and dtype either `float` or the same as `t`.
        sin_ (Tensor): Cached sine of the rotary positional embedding tensor, of
            shape [s, 1, 1, d] and dtype either `float` or the same as `t`.
        transpose_output_memory (bool): Defaults to False. Whether to transpose the 's' and 'b'
            dimensions of the output's underlying memory format. This is helpful when you want
            a contiguous tensor after calling `output.transpose(0, 1)`.

    Returns:
        Tensor: The input tensor after applying RoPE.
    """
    return FusedRoPECachedFunc.apply(t, cos_, sin_, transpose_output_memory)
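
# Example (illustrative sketch, under the same assumptions as the example
# above): with `freqs` of shape [s, 1, 1, d], the cached variant takes the
# precomputed cosine and sine instead of recomputing them on every call.
#
#   cos_ = freqs.cos()
#   sin_ = freqs.sin()
#   out = fused_apply_rotary_pos_emb_cached(t, cos_, sin_)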