# Triton-InternVL2-2B / triton-test.py
# Uploaded by radna ("Upload 20 files"), commit 5c0cb68 (verified), 734 bytes.
import torch
from triton_flash_atn import _attention
# Smoke test for the Triton flash-attention kernel: run it on random inputs
# and verify the result against a plain PyTorch reference implementation.
#
# NOTE(review): `_attention` is invoked via `.apply`, which suggests a
# torch.autograd.Function whose forward takes (q, k, v, causal, sm_scale);
# the argument order is taken from the original call -- confirm against
# triton_flash_atn.

# The inputs are allocated on 'cuda'; fail early with a clear message instead
# of an opaque error from the first tensor allocation on a GPU-less machine.
if not torch.cuda.is_available():
    raise RuntimeError("triton-test.py requires a CUDA-capable GPU")

# Fixed seed so that any mismatch is reproducible from run to run.
torch.manual_seed(0)

# Define dimensions
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Create random input tensors for Q, K, V, each shaped
# (batch_size, num_heads, seq_len, head_dim).
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')

# Define whether the attention is causal and the scaling factor
# (the usual 1/sqrt(head_dim) softmax temperature).
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)


def _reference_attention(q, k, v, causal, sm_scale):
    """Return softmax(q @ k^T * sm_scale) @ v, computed in float32.

    Args:
        q, k, v: Tensors whose last two dims are (seq_len, head_dim).
        causal: If true, query position i may only attend to key
            positions <= i (lower-triangular mask).
        sm_scale: Multiplier applied to the attention scores before softmax.

    Returns:
        A float32 tensor with the same leading dims as ``q`` and last dim
        equal to ``v``'s.
    """
    # Upcast so the reference is more accurate than the fp16 kernel under test.
    q32, k32, v32 = q.float(), k.float(), v.float()
    scores = torch.matmul(q32, k32.transpose(-2, -1)) * sm_scale
    if causal:
        n_q, n_k = scores.shape[-2], scores.shape[-1]
        # Float ones + tril, then compare to 0: avoids relying on bool
        # support in tril on older PyTorch builds.
        keep = torch.tril(torch.ones(n_q, n_k, device=scores.device))
        scores = scores.masked_fill(keep == 0, float('-inf'))
    return torch.matmul(torch.softmax(scores, dim=-1), v32)


# Apply flash attention
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)
print(output)

# Verify against the reference. atol=1e-2 / rtol=0 mirrors the forward check
# in Triton's fused-attention tutorial for float16 inputs.
reference = _reference_attention(q, k, v, causal, sm_scale)
max_abs_err = (output.float() - reference).abs().max().item()
print(f"max abs error vs reference: {max_abs_err:.3e}")
torch.testing.assert_close(output.float(), reference, atol=1e-2, rtol=0)