Awell00's picture
feat!: add configuration for inference process
history blame
3.68 kB
from functools import wraps
from packaging import version
from collections import namedtuple
import os
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
# constants
# Toggle set for the three scaled-dot-product-attention backends; the field
# names double as keyword names when the tuple is expanded with `_asdict()`.
FlashAttentionConfig = namedtuple(
    'FlashAttentionConfig',
    (
        'enable_flash',
        'enable_math',
        'enable_mem_efficient',
    ),
)
# helpers
def exists(val):
    """Return True when *val* is anything other than None.

    Falsy values such as 0, '' or False still count as existing; only the
    None singleton is treated as "missing".
    """
    return val is not None
def default(v, d):
    """Return *v* unless it is None, in which case return the fallback *d*.

    The fallback is only chosen for None, so falsy values such as 0 or 0.0
    are passed through unchanged.
    """
    return v if exists(v) else d
def once(fn):
    """Wrap *fn* so that only the first call is forwarded to it.

    Args:
        fn: a callable taking a single positional argument.

    Returns:
        A wrapper with the same one-argument signature. The first call
        returns whatever ``fn(x)`` returns; every later call does nothing
        and returns None. Each wrapper keeps its own "already called" flag.
    """
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        # Later calls are swallowed: bail out before touching fn again.
        if called:
            return
        # Flip the flag before calling so a raising fn is still not retried.
        called = True
        return fn(x)

    return inner
# `print` passed through `once`; used below for the attention-backend notices
# emitted while an Attend module is being constructed.
print_once = once(print)
# main class
class Attend(nn.Module):
    """Scaled dot-product attention with an optional fused-kernel path.

    With ``flash = False`` the attention is computed explicitly with two
    einsums and a softmax. With ``flash = True`` the work is delegated to
    ``F.scaled_dot_product_attention`` under a backend configuration chosen
    once at construction time from the visible CUDA device and the host OS.
    """

    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
        """Set up dropout, scaling and the per-device backend configuration.

        Args:
            dropout: attention dropout probability.
            flash: route ``forward`` through ``F.scaled_dot_product_attention``.
            scale: multiplier applied to the similarity scores; None means
                ``d ** -0.5`` where ``d`` is the last dimension of ``q``.

        Raises:
            AssertionError: if ``flash`` is requested on PyTorch older than 2.0.
        """
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        # Nothing to probe without a CUDA device, and no reason to probe when
        # the fused path is disabled; cuda_config then stays None.
        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')

        if device_version >= version.parse('8.0'):
            if os.name == 'nt':
                # On Windows the flash backend is switched off even on capable GPUs.
                print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
                self.cuda_config = FlashAttentionConfig(False, True, True)
            else:
                print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
                self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        """Run attention through ``F.scaled_dot_product_attention``.

        Args:
            q, k, v: query, key and value tensors (see ``forward`` for the
                axis layout).

        Returns:
            The attention output produced by the fused call.
        """
        if exists(self.scale):
            # The fused call applies d ** -0.5 on its own, so a custom scale
            # is folded into q as the ratio between the two.
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # Check if there is a compatible device for flash attention
        config = self.cuda_config if q.is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
        # NOTE(review): newer PyTorch releases deprecate sdp_kernel in favour
        # of torch.nn.attention.sdpa_kernel - confirm before upgrading.
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                # The functional call does not know the module's mode, so
                # dropout has to be disabled by hand outside training.
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """Compute attention of ``q`` over ``k`` / ``v``.

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        Args:
            q: queries, laid out as ``b h i d``.
            k: keys, laid out as ``b h j d``.
            v: values, laid out as ``b h j d``.

        Returns:
            Tensor laid out as ``b h i d``.
        """
        if self.flash:
            return self.flash_attn(q, k, v)

        scale = default(self.scale, q.shape[-1] ** -0.5)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out