/* * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #include "cutlass/gemm/device/gemm_universal_base.h" #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" #include "cutlass/epilogue/thread/linear_combination_residual_block.h" #include "cutlass_extensions/compute_occupancy.h" #include "cutlass_extensions/epilogue_helpers.h" #include "cutlass_extensions/ft_gemm_configs.h" #include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" #include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" #include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h" #include "cutlass_extensions/gemm/threadblock/default_mma.h" #pragma GCC diagnostic pop #include "../cutlass_heuristic.h" #include "fpA_intB_gemm.h" #include "cuda_utils.h" namespace fastertransformer { template void generic_mixed_gemm_kernelLauncher(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, CutlassGemmConfig gemm_config, char *workspace, size_t workspace_bytes, cudaStream_t stream, int *occupancy = nullptr) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, "Specialized for half, float"); static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, ""); // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. using ElementType_ = typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; using ElementType = ElementType_; using CutlassWeightType_ = typename cutlass::platform:: conditional::value, cutlass::half_t, WeightType>::type; using CutlassWeightType = CutlassWeightType_; // We need separate config for each architecture since we will target different tensorcore instructions. For float, // we do not target TCs. using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; using ElementAccumulator = typename MixedGemmArchTraits::AccType; using EpilogueOp = typename Epilogue::Op; using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< ElementType, cutlass::layout::RowMajor, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB, MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, arch, ThreadblockShape, WarpShape, typename MixedGemmArchTraits::InstructionShape, EpilogueOp, typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, Stages, true, typename MixedGemmArchTraits::Operator>::GemmKernel; using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; if (occupancy != nullptr) { *occupancy = compute_occupancy_for_kernel(); return; } using Gemm = cutlass::gemm::device::GemmUniversalBase; const int ldb = cutlass::platform::is_same::value ? n : k * GemmKernel::kInterleave; typename Gemm::Arguments args({m, n, k}, {reinterpret_cast(const_cast(A)), k}, {reinterpret_cast(const_cast(B)), ldb}, {reinterpret_cast(const_cast(weight_scales)), 0}, // TODO: Support more general bias shape {reinterpret_cast(const_cast(biases)), bias_stride}, {reinterpret_cast(C), n}, gemm_config.split_k_factor, {ElementAccumulator(1.f), ElementAccumulator(0.f)}); // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write // our own predicated iterator in order to relax this limitation. if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); } Gemm gemm; if (gemm.get_workspace_size(args) > workspace_bytes) { FT_LOG_WARNING( "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); // If requested split-k factor will require more workspace bytes, revert to standard gemm. args.batch_count = 1; } auto can_implement = gemm.can_implement(args); if (can_implement != cutlass::Status::kSuccess) { std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); } auto init_status = gemm.initialize(args, workspace, stream); if (init_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); } auto run_status = gemm.run(stream); if (run_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); } } template struct dispatch_stages { static void dispatch(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, CutlassGemmConfig gemm_config, char *workspace, size_t workspace_bytes, cudaStream_t stream, int *occupancy = nullptr) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg); } }; template struct dispatch_stages { static void dispatch(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, CutlassGemmConfig gemm_config, char *workspace, size_t workspace_bytes, cudaStream_t stream, int *occupancy = nullptr) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); generic_mixed_gemm_kernelLauncher( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); } }; template struct dispatch_stages 2)>::type> { static void dispatch(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, CutlassGemmConfig gemm_config, char *workspace, size_t workspace_bytes, cudaStream_t stream, int *occupancy = nullptr) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); generic_mixed_gemm_kernelLauncher( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); } }; template void dispatch_gemm_config(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, CutlassGemmConfig gemm_config, char *workspace, size_t workspace_bytes, cudaStream_t stream, int *occupancy = nullptr) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); switch (gemm_config.stages) { case 2: using DispatcherStages2 = dispatch_stages; DispatcherStages2::dispatch( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); break; case 3: using DispatcherStages3 = dispatch_stages; DispatcherStages3::dispatch( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); break; case 4: using DispatcherStages4 = dispatch_stages; DispatcherStages4::dispatch( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); break; default: std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg); break; } } template void dispatch_gemm_to_cutlass(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, char *workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, cudaStream_t stream, int *occupancy = nullptr) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); // Note that SIMT configs are omitted here since they are not supported for fpA_intB. // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best // for mixed type gemms. switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, cutlass::gemm::GemmShape<32, 32, 64>>( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); break; case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: dispatch_gemm_config, cutlass::gemm::GemmShape<64, 32, 64>>( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); break; case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: dispatch_gemm_config, cutlass::gemm::GemmShape<128, 32, 64>>( A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); break; case CutlassTileConfig::Undefined: throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic."); break; default: throw std::runtime_error( "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); break; } } template CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); int device{-1}; check_cuda_error(cudaGetDevice(&device)); sm_ = getSMVersion(); check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); } template CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); } template template void CutlassFpAIntBGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, int k, int bias_stride, CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (sm_ >= 70 && sm_ < 75) { dispatch_gemm_to_cutlass( A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } else if (sm_ >= 75 && sm_ < 80) { dispatch_gemm_to_cutlass( A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } else if (sm_ >= 80 && sm_ < 90) { dispatch_gemm_to_cutlass( A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } else { throw std::runtime_error( "[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM"); } } template template void CutlassFpAIntBGemmRunner::run_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, int k, int bias_stride, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); static constexpr bool is_weight_only = !std::is_same::value; std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, false); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { dispatch_to_arch(A, B, weight_scales, biases, C, m, n, k, bias_stride, candidate_configs[ii], workspace_ptr, workspace_bytes, stream, &occupancies[ii]); } // Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN. static constexpr int num_experts = 1; CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs, occupancies, m, n, k, num_experts, split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only); dispatch_to_arch( A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream); } template void CutlassFpAIntBGemmRunner::gemm_bias_act(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, int m, int n, int k, int bias_stride, ActivationType activation_type, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); switch (activation_type) { case ActivationType::Relu: run_gemm( A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); break; case ActivationType::Gelu: run_gemm( A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); break; case ActivationType::Silu: run_gemm( A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); break; case ActivationType::Identity: run_gemm(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); break; case ActivationType::InvalidType: FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid."); break; default: { if (isGatedActivation(activation_type)) { FT_CHECK_WITH_INFO(false, "Fused gated activations not supported"); } else { FT_CHECK_WITH_INFO(false, "Invalid activation type."); } } } } template void CutlassFpAIntBGemmRunner::gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int m, int n, int k, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); run_gemm(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream); } template void dispatch_gemm_residual(const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { using ElementType = typename cutlass::platform::conditional< cutlass::platform::is_same::value, cutlass::half_t, T>::type; using ElementOutput = ElementType; using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; using ElementAccumulator = typename EpilogueOp::ElementAccumulator; using Swizzle = typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; using InstructionShape = typename MixedGemmArchTraits::InstructionShape; using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, WeightType, typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, InstructionShape, EpilogueOp, Swizzle, stages, typename MixedGemmArchTraits::Operator>::Epilogue; using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< ElementType, cutlass::layout::RowMajor, MixedGemmArchTraits::ElementsPerAccessA, WeightType, typename MixedGemmArchTraits::LayoutB, MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, InstructionShape, EpilogueOp, Swizzle, stages, true, typename MixedGemmArchTraits::Operator>::GemmKernel; using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast< typename GemmKernel_::Mma, Epilogue, typename GemmKernel_::ThreadblockSwizzle, Arch>; using Gemm = cutlass::gemm::device::GemmUniversalBase; // TODO: Support batch const int batch_count = 1; const auto lda = k; const int ldb = cutlass::platform::is_same::value ? n : k * GemmKernel::kInterleave; const int ldc = n; typename Gemm::Arguments args( {m, n, k}, batch_count, {ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales, residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0); if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || (k % MixedGemmArchTraits::ThreadblockK))) { throw std::runtime_error( "Temp assertion: k must be multiple of threadblockK"); } Gemm gemm; auto can_implement = gemm.can_implement(args); if (can_implement != cutlass::Status::kSuccess) { std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); } auto init_status = gemm.initialize(args, workspace_ptr, stream); if (init_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); } auto run_status = gemm.run(stream); if (run_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); } } template void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) { dispatch_gemm_residual< T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>( A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else if (tile_config == CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) { dispatch_gemm_residual< T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>( A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: dispatch_gemm_residual< T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>( A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } } template void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { if constexpr (std::is_same::value) { dispatch_gemm_residual( config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else if constexpr (std::is_same::value) { dispatch_gemm_residual( config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else { if (config.stages == 3) { dispatch_gemm_residual( config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else if (config.stages == 4) { dispatch_gemm_residual( config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else { // 2 dispatch_gemm_residual( config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } } } template class ActivationOp, template class BinaryOp> inline void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, const std::string &unary_op, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { using ElementOutput = T; using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; using ElementAccumulator = typename MixedGemmArchTraits::AccType; if (unary_op == "identity") { using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< ElementOutput, ElementAccumulator, ElementAccumulator, ElementOutput, 128 / cutlass::sizeof_bits::value, ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>; dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else if (unary_op == "relu") { using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< ElementOutput, ElementAccumulator, ElementAccumulator, ElementOutput, 128 / cutlass::sizeof_bits::value, ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>; dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, workspace_bytes, stream); } else { throw std::runtime_error( "[FT Error][Unsupported unary op after residual block] " + unary_op); } } template class ActivationOp> void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, const std::string &binary_op, const std::string &unary_op, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { if (binary_op == "plus") { dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op, workspace_ptr, workspace_bytes, stream); } else if (binary_op == "multiply") { dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op, workspace_ptr, workspace_bytes, stream); } else { throw std::runtime_error( "[FT Error][Unsupported binary op for residual block] " + binary_op); } } template void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, const std::string &activation, const std::string &binary_op, const std::string &unary_op, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { if (activation == "identity") { dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else if ("silu") { dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else if ("relu") { dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else if ("gelu") { dispatch_gemm_residual( config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else { throw std::runtime_error( "[FT Error][Unsupported activation before residual binary op] " + activation); } } template void CutlassFpAIntBGemmRunner::gemm_bias_act_residual( const T *A, const WeightType *B, const T *weight_scales, const T *biases, const T *residual, T *C, int m, int n, int k, const std::string &activation, const std::string &binary_op, const std::string &unary_op, char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { std::vector candidate_configs = get_candidate_configs(sm_, true, false); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { dispatch_to_arch( A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii], workspace_ptr, workspace_bytes, stream, &occupancies[ii]); } CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies( candidate_configs, occupancies, m, n, k, 1, split_k_limit, workspace_bytes, multi_processor_count_, true); if (sm_ >= 80 && sm_ < 90) { dispatch_gemm_residual( chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else if (sm_ >= 75 && sm_ < 80) { dispatch_gemm_residual( chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else if (sm_ == 70) { dispatch_gemm_residual( chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); } else { throw std::runtime_error("[FT Error][Unsupported SM] " + sm_); } } template int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, const int n, const int k) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); // TODO(masahi): Shouldn't it be 0? // These are the min tile sizes for each config, which would launch the maximum number of blocks const int max_grid_m = (m + 31) / 32; const int max_grid_n = (n + 127) / 128; // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. return max_grid_m * max_grid_n * split_k_limit * 4; } } // namespace fastertransformer