/* coding=utf-8 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #ifndef __HIP_PLATFORM_HCC__ #include #endif #include #include #include "scaled_upper_triang_masked_softmax.h" #include "type_shim.h" namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = input.size(0); const int seq_len = input.size(1); TORCH_INTERNAL_ASSERT(seq_len <= 2048); // Output auto act_options = input.options().requires_grad(false); torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); DISPATCH_HALF_AND_BFLOAT( input.scalar_type(), "dispatch_scaled_upper_triang_masked_softmax_forward", dispatch_scaled_upper_triang_masked_softmax_forward( reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), scale_factor, seq_len, seq_len, attn_batches);); return softmax_results; } torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = output_grads.size(0); const int seq_len = output_grads.size(1); TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); void* output_grads_ptr = static_cast(output_grads.data_ptr()); // Softmax Grad DISPATCH_HALF_AND_BFLOAT( output_grads_.scalar_type(), "dispatch_scaled_upper_triang_masked_softmax_backward", dispatch_scaled_upper_triang_masked_softmax_backward( reinterpret_cast(output_grads_ptr), reinterpret_cast(output_grads_ptr), reinterpret_cast(softmax_results.data_ptr()), scale_factor, seq_len, seq_len, attn_batches);); // backward pass is completely in-place return output_grads; } } // namespace scaled_upper_triang_masked_softmax } // namespace fused_softmax } // namespace multihead_attn