import math

import torch
import torch.nn as nn

def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Scale into the FP8 range, clamp to the float8_e4m3fnuz representable maximum
    # (+/-240), then round by casting to float8 and back to the original dtype.
    dtype = tensor.dtype
    clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
    quant_tensor = torch.clamp(tensor / scale, clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
    return quant_tensor

def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Symmetric (zero-point-free) dequantization: simply rescale.
    return tensor * scale
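
# A minimal round-trip sketch of the two helpers above (illustrative only; the
# scale choice below is an assumption, not part of this module):
#
#   x = torch.randn(4, 8)
#   s = x.abs().max() / 240.        # one possible per-tensor absmax scale
#   x_q = quantize_fp8(x, s)        # values snapped to the float8_e4m3fnuz grid
#   x_dq = dequantize_fp8(x_q, s)   # back to the original range, with rounding error
#   err = (x - x_dq).abs().max()    # the error introduced by the round trip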

# Based on: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Fake quantization (QDQ): quantize and immediately dequantize Q, K, and V, so the
    # rest of the math runs in the original dtype but observes the quantization error.
    query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale)
    key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale)
    value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale)
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    # Fake-quantize the softmax output with its own scale before the final matmul.
    attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
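
# Why the scale folding in qop_scaled_dot_product_attention below is valid, as a
# sketch (ignoring clamping and rounding): with q = Q / s_q and k = K / s_k quantized,
#   (Q @ K^T) * sf == (q @ k^T) * (sf * s_q * s_k),
# so dequantizing Q and K can be folded into the attention scale factor. Likewise,
#   softmax(...) @ V == (p_q @ v) * (s_sm * s_v),
# where p_q and v are the quantized softmax output and value tensors.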

def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    query = quantize_fp8(query, query_scale)
    key = quantize_fp8(key, key_scale)
    value = quantize_fp8(value, value_scale)
    # Your quantized kernel starts here
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Fold the Q and K dequantization scales into the attention scale factor, so the
    # matmul below can run directly on the quantized tensors.
    scale_factor *= (query_scale * key_scale)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor # or, attn_weight = dequantize_fp8(query @ key.transpose(-2, -1), scale_factor)
    attn_weight += attn_bias
    attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return (attn_weight @ value) * (softmax_scale * value_scale) # or, return dequantize_fp8(attn_weight @ value, softmax_scale * value_scale)

# Module that implements a quantized version of `torch.nn.functional.scaled_dot_product_attention`
class QuantScaledDotProductAttention(nn.Module):
    def __init__(self, quant_param):
        super().__init__()
        q_scale = torch.tensor(quant_param['out_q']['act_scale']).view(quant_param['out_q']['act_scale_shape'])
        k_scale = torch.tensor(quant_param['out_k']['act_scale']).view(quant_param['out_k']['act_scale_shape'])
        v_scale = torch.tensor(quant_param['out_v']['act_scale']).view(quant_param['out_v']['act_scale_shape'])
        sm_scale = torch.tensor(quant_param['output_softmax_quant']['act_scale']).view(quant_param['output_softmax_quant']['act_scale_shape'])
        # Zero-points are not used (quantization is symmetric) but are included in the
        # model's quant_param; the commented-out lines are kept since act_zp_dtype is
        # used as a dtype hint in the asserts below.
        #q_zp = torch.tensor(quant_param['out_q']['act_zp']).view(quant_param['out_q']['act_zp_shape'])
        #k_zp = torch.tensor(quant_param['out_k']['act_zp']).view(quant_param['out_k']['act_zp_shape'])
        #v_zp = torch.tensor(quant_param['out_v']['act_zp']).view(quant_param['out_v']['act_zp_shape'])
        assert quant_param['out_q']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Q Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_q']['act_zp_dtype']}"
        assert quant_param['out_k']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"K Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_k']['act_zp_dtype']}"
        assert quant_param['out_v']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"V Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_v']['act_zp_dtype']}"
        assert quant_param['output_softmax_quant']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"SoftMax Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['output_softmax_quant']['act_zp_dtype']}"
        self.register_buffer('q_scale', q_scale)
        self.register_buffer('k_scale', k_scale)
        self.register_buffer('v_scale', v_scale)
        self.register_buffer('sm_scale', sm_scale)


    # I.e., "fake quantization"
    def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        return qdq_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)

    # QOp path: operates on the quantized tensors directly, as an accelerated kernel would
    def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        return qop_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False):
        if qop:
            return self.qop_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
        else:
            return self.qdq_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
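
# Minimal smoke test (illustrative only): the quant_param layout below is inferred
# from the constructor above, and the scale values, shapes, and tensor sizes are
# made-up assumptions for the sketch.
if __name__ == "__main__":
    def _act_param(scale):
        return {
            'act_scale': scale,
            'act_scale_shape': (),
            'act_zp': 0.,
            'act_zp_shape': (),
            'act_zp_dtype': 'torch.float8_e4m3fnuz',
        }

    quant_param = {
        'out_q': _act_param(0.02),
        'out_k': _act_param(0.02),
        'out_v': _act_param(0.02),
        'output_softmax_quant': _act_param(1. / 240.),
    }
    sdpa = QuantScaledDotProductAttention(quant_param)

    torch.manual_seed(0)
    q = torch.randn(1, 8, 64, 32)
    k = torch.randn(1, 8, 64, 32)
    v = torch.randn(1, 8, 64, 32)

    out_qdq = sdpa(q, k, v, is_causal=True)             # fake-quantized reference path
    out_qop = sdpa(q, k, v, is_causal=True, qop=True)   # quantized-kernel path
    # The two paths should agree up to FP8 rounding/accumulation differences.
    print("max |qdq - qop|:", (out_qdq - out_qop).abs().max().item())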