# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

import torch


class FusedRoPEFunc(torch.autograd.Function):
"""
Fused RoPE function
This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive
`.contiguous()` calls, thus it may not achieve the best memory access pattern.
"""
@staticmethod
def forward(
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
transpose_output_memory: bool = False,
) -> torch.Tensor:
import fused_rotary_positional_embedding
output = fused_rotary_positional_embedding.forward(
t, freqs, transpose_output_memory
)
ctx.save_for_backward(freqs)
ctx.transpose_output_memory = transpose_output_memory
return output
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
import fused_rotary_positional_embedding
(freqs,) = ctx.saved_tensors
grad_input = fused_rotary_positional_embedding.backward(
grad_output, freqs, ctx.transpose_output_memory
)
return grad_input, None, None
def fused_apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    transpose_output_memory: bool = False,
) -> torch.Tensor:
    """Apply rotary positional embedding to the input tensor `t`.

    Args:
        t (Tensor): Input tensor of shape [s, b, h, d].
        freqs (Tensor): Rotary positional embedding tensor of shape [s, 1, 1, d] and
            `float` dtype.
        transpose_output_memory (bool): Defaults to False. Whether to transpose the 's' and 'b'
            dimensions of the output's underlying memory format. This is very helpful when you
            want to get a contiguous tensor after calling `output.transpose(0, 1)`.

    Returns:
        Tensor: The input tensor after applying RoPE.
    """
    return FusedRoPEFunc.apply(t, freqs, transpose_output_memory)
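

# A minimal usage sketch of `fused_apply_rotary_pos_emb` (illustration only, not part of the
# original API). It assumes a CUDA device, the compiled `fused_rotary_positional_embedding`
# extension, and the common RoPE recipe with base 10000 for building the angle tensor; the
# helper name below is hypothetical.
def _example_fused_apply_rotary_pos_emb():  # pragma: no cover
    s, b, h, d = 128, 2, 16, 64
    device = torch.device("cuda")
    t = torch.randn(s, b, h, d, dtype=torch.float16, device=device, requires_grad=True)
    # Build the RoPE angle tensor in float32 with shape [s, 1, 1, d].
    inv_freq = 1.0 / (
        10000 ** (torch.arange(0, d, 2, dtype=torch.float32, device=device) / d)
    )
    angles = torch.outer(torch.arange(s, dtype=torch.float32, device=device), inv_freq)
    freqs = torch.cat((angles, angles), dim=-1)[:, None, None, :]
    out = fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
    # With transpose_output_memory=True, the transposed (b, s, ...) view is contiguous in memory.
    assert out.transpose(0, 1).is_contiguous()
    out.sum().backward()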


class FusedRoPECachedFunc(torch.autograd.Function):
    """
    Fused RoPE function (cached cos/sin variant).

    This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
    of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive
    `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        cos_: torch.Tensor,
        sin_: torch.Tensor,
        transpose_output_memory: bool = False,
    ) -> torch.Tensor:
        # Lazy import: the compiled extension is only needed when the function is called.
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward_cached(
            t, cos_, sin_, transpose_output_memory
        )
        # Save the cached cos/sin tensors for the backward pass.
        ctx.save_for_backward(cos_, sin_)
        ctx.transpose_output_memory = transpose_output_memory
        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        cos_, sin_ = ctx.saved_tensors
        grad_input = fused_rotary_positional_embedding.backward_cached(
            grad_output, cos_, sin_, ctx.transpose_output_memory
        )
        # No gradients flow to `cos_`, `sin_`, or `transpose_output_memory`.
        return grad_input, None, None, None


def fused_apply_rotary_pos_emb_cached(
    t: torch.Tensor,
    cos_: torch.Tensor,
    sin_: torch.Tensor,
    transpose_output_memory: bool = False,
) -> torch.Tensor:
    """Apply rotary positional embedding to the input tensor `t`, using cached cos/sin tensors.

    Args:
        t (Tensor): Input tensor of shape [s, b, h, d].
        cos_ (Tensor): Cached cosine of the rotary positional embedding, of shape [s, 1, 1, d]
            and dtype either `float` or the same as `t`.
        sin_ (Tensor): Cached sine of the rotary positional embedding, of shape [s, 1, 1, d]
            and dtype either `float` or the same as `t`.
        transpose_output_memory (bool): Defaults to False. Whether to transpose the 's' and 'b'
            dimensions of the output's underlying memory format. This is very helpful when you
            want to get a contiguous tensor after calling `output.transpose(0, 1)`.

    Returns:
        Tensor: The input tensor after applying RoPE.
    """
    return FusedRoPECachedFunc.apply(t, cos_, sin_, transpose_output_memory)
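

# A minimal usage sketch of `fused_apply_rotary_pos_emb_cached` (illustration only). Same
# assumptions as the sketch above: CUDA device, compiled extension, base-10000 RoPE recipe,
# hypothetical helper name. The cached variant lets `cos_`/`sin_` be computed once and reused
# across layers and iterations.
def _example_fused_apply_rotary_pos_emb_cached():  # pragma: no cover
    s, b, h, d = 128, 2, 16, 64
    device = torch.device("cuda")
    t = torch.randn(s, b, h, d, dtype=torch.bfloat16, device=device, requires_grad=True)
    inv_freq = 1.0 / (
        10000 ** (torch.arange(0, d, 2, dtype=torch.float32, device=device) / d)
    )
    angles = torch.outer(torch.arange(s, dtype=torch.float32, device=device), inv_freq)
    emb = torch.cat((angles, angles), dim=-1)[:, None, None, :]  # [s, 1, 1, d]
    cos_, sin_ = emb.cos(), emb.sin()  # precompute once, reuse many times
    out = fused_apply_rotary_pos_emb_cached(t, cos_, sin_)
    out.sum().backward()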