// NEOX/megatron/fused_kernels/fused_rotary_positional_embedding_cuda.cu
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/ATen.h>
#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"
namespace fused_rope {
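
// Applies rotary positional embedding (RoPE) to `input` on the GPU.
// `input` is laid out as (s, b, h, d); `freqs` supplies the rotation angles
// for the first d2 (<= d) elements of each head, and any remaining d - d2
// elements are copied through unrotated. When `transpose_output` is set, the
// result is allocated contiguously in (b, s, h, d) order and then viewed back
// as (s, b, h, d), so a subsequent transpose in the caller costs nothing.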
torch::Tensor fwd_cuda(const torch::Tensor& input,
                       const torch::Tensor& freqs,
                       const bool transpose_output)
{
    // input sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: number of attention heads
    // d: dim of each head
    const int s = input.size(0);
    const int b = input.size(1);
    const int h = input.size(2);
    const int d = input.size(3);
    // input strides
    const int stride_s = input.stride(0);
    const int stride_b = input.stride(1);
    const int stride_h = input.stride(2);
    const int stride_d = input.stride(3);
    // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
    // different memory formats
    const int d2 = freqs.size(3);
    // output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor output;
    if (transpose_output) {
        output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        output = torch::empty({s, b, h, d}, act_options);
    }
    // output strides
    const int o_stride_s = output.stride(0);
    const int o_stride_b = output.stride(1);
    const int o_stride_h = output.stride(2);
    const int o_stride_d = output.stride(3);
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        0,
        "dispatch_fused_rope_forward",
        dispatch_fused_rope_forward(s, b, h, d, d2,
                                    stride_s, stride_b, stride_h, stride_d,
                                    o_stride_s, o_stride_b, o_stride_h, o_stride_d,
                                    input.data_ptr<scalar_t_0>(),
                                    freqs.data_ptr<float>(),
                                    output.data_ptr<scalar_t_0>()););
    return output;
}
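
// Backward pass of fwd_cuda: rotating `output_grads` by the negated angles
// in `freqs` yields the gradient with respect to the input. The
// `transpose_output` memory-layout trick mirrors the forward pass.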
torch::Tensor bwd_cuda(const torch::Tensor& output_grads,
                       const torch::Tensor& freqs,
                       const bool transpose_output)
{
    // output_grads sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: number of attention heads
    // d: dim of each head
    const int s = output_grads.size(0);
    const int b = output_grads.size(1);
    const int h = output_grads.size(2);
    const int d = output_grads.size(3);
    // output_grads strides
    const int stride_s = output_grads.stride(0);
    const int stride_b = output_grads.stride(1);
    const int stride_h = output_grads.stride(2);
    const int stride_d = output_grads.stride(3);
    // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
    // different memory formats
    const int d2 = freqs.size(3);
    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads;
    if (transpose_output) {
        input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        input_grads = torch::empty({s, b, h, d}, act_options);
    }
    const int o_stride_s = input_grads.stride(0);
    const int o_stride_b = input_grads.stride(1);
    const int o_stride_h = input_grads.stride(2);
    const int o_stride_d = input_grads.stride(3);
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        output_grads.scalar_type(),
        0,
        "dispatch_fused_rope_backward",
        dispatch_fused_rope_backward(s, b, h, d, d2,
                                     stride_s, stride_b, stride_h, stride_d,
                                     o_stride_s, o_stride_b, o_stride_h, o_stride_d,
                                     output_grads.data_ptr<scalar_t_0>(),
                                     freqs.data_ptr<float>(),
                                     input_grads.data_ptr<scalar_t_0>()););
    return input_grads;
}
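
// Dispatches on the (activation dtype, cos/sin dtype) pair: cos/sin may be
// fp32 regardless of the activation type, or may match the activation's own
// half-precision type; any other combination fails with TORCH_CHECK.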
#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...)    \
    switch (TYPE1) {                                          \
        case at::ScalarType::Float: {                         \
            using scalar_t_0 = float;                         \
            switch (TYPE2) {                                  \
                case at::ScalarType::Float: {                 \
                    using scalar_t_1 = float;                 \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                default:                                      \
                    TORCH_CHECK(false,                        \
                                #NAME,                        \
                                " not supported for '",       \
                                toString(TYPE1),              \
                                "' with '",                   \
                                toString(TYPE2),              \
                                "'");                         \
            }                                                 \
            break;                                            \
        }                                                     \
        case at::ScalarType::Half: {                          \
            using scalar_t_0 = at::Half;                      \
            switch (TYPE2) {                                  \
                case at::ScalarType::Float: {                 \
                    using scalar_t_1 = float;                 \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                case at::ScalarType::Half: {                  \
                    using scalar_t_1 = at::Half;              \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                default:                                      \
                    TORCH_CHECK(false,                        \
                                #NAME,                        \
                                " not supported for '",       \
                                toString(TYPE1),              \
                                "' with '",                   \
                                toString(TYPE2),              \
                                "'");                         \
            }                                                 \
            break;                                            \
        }                                                     \
        case at::ScalarType::BFloat16: {                      \
            using scalar_t_0 = at::BFloat16;                  \
            switch (TYPE2) {                                  \
                case at::ScalarType::Float: {                 \
                    using scalar_t_1 = float;                 \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                case at::ScalarType::BFloat16: {              \
                    using scalar_t_1 = at::BFloat16;          \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                default:                                      \
                    TORCH_CHECK(false,                        \
                                #NAME,                        \
                                " not supported for '",       \
                                toString(TYPE1),              \
                                "' with '",                   \
                                toString(TYPE2),              \
                                "'");                         \
            }                                                 \
            break;                                            \
        }                                                     \
        default:                                              \
            TORCH_CHECK(false,                                \
                        #NAME,                                \
                        " not supported for '",               \
                        toString(TYPE1),                      \
                        "' with '",                           \
                        toString(TYPE2),                      \
                        "'");                                 \
    }
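
// Same as fwd_cuda, but takes precomputed cos/sin tables instead of raw
// angles, so the kernel skips the per-element trigonometry.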
torch::Tensor fwd_cached_cuda(const torch::Tensor& input,
                              const torch::Tensor& cos,
                              const torch::Tensor& sin,
                              const bool transpose_output)
{
    // input sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: number of attention heads
    // d: dim of each head
    const int s = input.size(0);
    const int b = input.size(1);
    const int h = input.size(2);
    const int d = input.size(3);
    // input strides
    const int stride_s = input.stride(0);
    const int stride_b = input.stride(1);
    const int stride_h = input.stride(2);
    const int stride_d = input.stride(3);
    // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
    // under different memory formats
    const int d2 = cos.size(3);
    // output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor output;
    if (transpose_output) {
        output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        output = torch::empty({s, b, h, d}, act_options);
    }
    // output strides
    const int o_stride_s = output.stride(0);
    const int o_stride_b = output.stride(1);
    const int o_stride_h = output.stride(2);
    const int o_stride_d = output.stride(3);
    DISPATCH_FUSED_ROPE_TYPES(
        input.scalar_type(),
        cos.scalar_type(),
        "dispatch_fused_rope_cached_forward",
        dispatch_fused_rope_cached_forward(s, b, h, d, d2,
                                           stride_s, stride_b, stride_h, stride_d,
                                           o_stride_s, o_stride_b, o_stride_h, o_stride_d,
                                           input.data_ptr<scalar_t_0>(),
                                           cos.data_ptr<scalar_t_1>(),
                                           sin.data_ptr<scalar_t_1>(),
                                           output.data_ptr<scalar_t_0>()););
    return output;
}
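
// Backward pass of fwd_cached_cuda: applies the inverse rotation built from
// the same cached cos/sin tables to `output_grads`.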
torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads,
                              const torch::Tensor& cos,
                              const torch::Tensor& sin,
                              const bool transpose_output)
{
    // output_grads sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: number of attention heads
    // d: dim of each head
    const int s = output_grads.size(0);
    const int b = output_grads.size(1);
    const int h = output_grads.size(2);
    const int d = output_grads.size(3);
    // output_grads strides
    const int stride_s = output_grads.stride(0);
    const int stride_b = output_grads.stride(1);
    const int stride_h = output_grads.stride(2);
    const int stride_d = output_grads.stride(3);
    // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
    // under different memory formats
    const int d2 = cos.size(3);
    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads;
    if (transpose_output) {
        input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        input_grads = torch::empty({s, b, h, d}, act_options);
    }
    const int o_stride_s = input_grads.stride(0);
    const int o_stride_b = input_grads.stride(1);
    const int o_stride_h = input_grads.stride(2);
    const int o_stride_d = input_grads.stride(3);
    DISPATCH_FUSED_ROPE_TYPES(
        output_grads.scalar_type(),
        cos.scalar_type(),
        "dispatch_fused_rope_cached_backward",
        dispatch_fused_rope_cached_backward(s, b, h, d, d2,
                                            stride_s, stride_b, stride_h, stride_d,
                                            o_stride_s, o_stride_b, o_stride_h, o_stride_d,
                                            output_grads.data_ptr<scalar_t_0>(),
                                            cos.data_ptr<scalar_t_1>(),
                                            sin.data_ptr<scalar_t_1>(),
                                            input_grads.data_ptr<scalar_t_0>()););
    return input_grads;
}
} // end namespace fused_rope
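
// For context, a minimal sketch of how these entry points are typically
// exposed to Python through a pybind11 extension module. The real bindings
// live in a separate .cpp file; the module and function names below are
// illustrative assumptions, not taken from that file.
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward", &fused_rope::fwd_cuda, "fused RoPE forward (CUDA)");
//       m.def("backward", &fused_rope::bwd_cuda, "fused RoPE backward (CUDA)");
//       m.def("forward_cached", &fused_rope::fwd_cached_cuda,
//             "fused RoPE forward with cached cos/sin (CUDA)");
//       m.def("backward_cached", &fused_rope::bwd_cached_cuda,
//             "fused RoPE backward with cached cos/sin (CUDA)");
//   }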