// flash-attention / decoder_masked_multihead_attention.h
// theonlyengine's picture
// Upload 421 files
// 3f9c425 verified
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
// CHECK_CUDA(call): evaluate a CUDA runtime API call exactly once. If the result is
// anything other than cudaSuccess, print "CUDA error (file:line): <message>" to stderr
// and terminate the process with exit status 1. Host-only (uses fprintf/exit).
//
// Macro hygiene: the argument is parenthesized and bound to a deliberately obscure local
// name. The previous form `cudaError_t status_ = call;` silently self-initialized when
// the caller's expression itself mentioned a variable named `status_` (e.g.
// CHECK_CUDA(status_)): the inner name then referred to the freshly declared,
// uninitialized local instead of the caller's variable.
#define CHECK_CUDA(call)                                                                           \
    do {                                                                                           \
        const cudaError_t check_cuda_status_ = (call);                                             \
        if (check_cuda_status_ != cudaSuccess) {                                                   \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,                        \
                    cudaGetErrorString(check_cuda_status_));                                       \
            exit(1);                                                                               \
        }                                                                                          \
    } while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// This base holds the fields shared by the masked (self) attention and the cross attention
// parameter structs declared below. Every member has a default initializer, so a
// default-constructed object is all null pointers / zeros / false; callers must fill in what
// they use (note that inv_sqrt_dh and rotary_base also default to 0.0f).
// NOTE(review): these structs are presumably consumed by separately compiled kernel
// translation units, so member order/layout should be kept stable — confirm before reordering.
template<typename T>
struct Multihead_attention_params_base {
// The output buffer. Dimensions B x D.
T* out = nullptr;
// The input Qs and the associated bias. Dimensions B x D and D, resp.
const T *q = nullptr, *q_bias = nullptr;
// The input Ks and the associated bias. Dimensions B x D and D, resp.
// NOTE(review): when num_heads_kv < num_heads the K/V extents are presumably
// B x (num_heads_kv * Dh) rather than B x D — confirm against the kernel.
const T *k = nullptr, *k_bias = nullptr;
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;
// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;
// Strides for reading q/k/v, used when Q, K and V live interleaved in a single buffer.
// Presumably the element distance between consecutive rows of q, k, v — TODO confirm.
int stride_q = 0;
int stride_k = 0;
int stride_v = 0;
// The batch size.
int batch_size = 0;
// The beam width.
int beam_width = 0;
// The sequence length. Presumably the maximum (allocated) length L of the K/V caches — TODO confirm.
int memory_max_len = 0;
// The number of heads (H).
int num_heads = 0;
// Number of key/value heads; presumably smaller than num_heads for multi-query /
// grouped-query attention — TODO confirm.
int num_heads_kv = 0;
// Presumably num_heads / num_heads_kv, precomputed by the caller — TODO confirm.
int num_heads_q_kv_ratio = 0;
// The hidden dimension per head (Dh).
int hidden_size_per_head = 0;
// The per-head latent space reserved for rotary embeddings.
int rotary_embedding_dim = 0;
// Selects the rotary layout; presumably GPT-NeoX "rotate half" style when true — TODO confirm.
bool neox_rotary_style = false;
// Base of the rotary frequencies. Defaults to 0.0f, so it must be set when rotary is used.
float rotary_base = 0.0f;
// The maximum length of input sentences.
int max_input_length = 0;
// The current timestep. TODO(bhsueh): check whether this param is only needed in cross attention.
int timestep = 0;
// (Per-sequence timesteps, for batches whose sentences are at different steps, are presumably
// carried by `length_per_sample` in the derived parameter structs.)
// The 1.f / sqrt(Dh). Computed on the host.
float inv_sqrt_dh = 0.0f;
// Used when we have some input context like gpt
const int* total_padding_tokens = nullptr;
const bool* masked_tokens = nullptr;
const int* prefix_prompt_lengths = nullptr;
int max_prefix_prompt_length = 0;
// Optional relative attention bias and its row stride (layout not visible here — TODO confirm).
const T* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
// The slope per head of linear position bias to attention score (H).
const T* linear_bias_slopes = nullptr;
// IA^3 adapter inputs; presumably per-task key/value scaling vectors selected by ia3_tasks
// (one task id per sequence) — TODO confirm.
const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;
// Quantization scales; presumably only read when int8_mode != 0 — TODO confirm.
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
// Optional precomputed rotary cos/sin tables (layout not visible here — TODO confirm).
const T *rotary_cos = nullptr;
const T *rotary_sin = nullptr;
// Presumably a list of nnz_heads head indices to compute (head sparsity) — TODO confirm.
const int *nnz_head_idx = nullptr;
int nnz_heads = 0;
};
// Attention parameters: the shared base plus the fields needed by the decoder step.
// The generic template (CROSS_ATTENTION == false) is used for masked self-attention.
template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    // Output of cross-attention probabilities (returned only when requested).
    float* cross_attention_out{nullptr};
    int    max_decoder_seq_len{0};
    bool   is_return_cross_attentions{false};

    // Per-sequence flags that allow attention to exit early.
    bool* finished{nullptr};

    // Required in case of cross attention. It is kept in the generic template as well,
    // presumably so code shared by both instantiations can name it without C++17 `if constexpr`.
    int* memory_length_per_sample{nullptr};

    // Required in case of masked attention with different per-sequence lengths.
    const int* length_per_sample{nullptr};
};
// Cross-attention specialization. Its members mirror the generic template one-for-one.
template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
    // Output of cross-attention probabilities (returned only when requested).
    float* cross_attention_out{nullptr};
    int    max_decoder_seq_len{0};
    bool   is_return_cross_attentions{false};

    // Per-sequence flags that allow attention to exit early.
    bool* finished{nullptr};

    // Required in case of cross attention.
    int* memory_length_per_sample{nullptr};

    // Required in case of masked attention with different per-sequence lengths.
    const int* length_per_sample{nullptr};
};
// Convenience aliases: the `false` instantiation holds the masked (self) attention
// parameters; the `true` instantiation holds the cross attention parameters.
template<typename T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<typename T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
// Standalone bundle of the cross-attention output settings.
template<typename T>
struct outputCrossAttentionParam {
    int  max_decoder_seq_len{0};             // max decoder output length
    T*   cross_attention_out{nullptr};       // output buffer for cross attentions
    bool is_return_cross_attentions{false};  // whether cross attentions are returned
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Host entry points (defined elsewhere). Each takes a populated parameter struct and the CUDA
// stream to use. There are overloads for float and uint16_t, plus __nv_bfloat16 when
// ENABLE_BF16 is defined. uint16_t presumably carries fp16 bit patterns — TODO confirm.

// Masked (self) attention.
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params,
                                const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params,
                                const cudaStream_t& stream);

// Cross attention.
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params,
                               const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params,
                               const cudaStream_t& stream);

#ifdef ENABLE_BF16
// bfloat16 variants, only available when bf16 support is compiled in.
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////