import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch

from transformers.modeling_flash_attention_utils import _flash_attention_forward


@dataclass
class IterStep:
    """A helper class for the iteration plan"""
    # Use a factory for the default: slice objects are unhashable before Python 3.12,
    # and dataclasses on Python 3.11 reject unhashable defaults.
    layer_slice: slice = field(default_factory=lambda: slice(None))
    requires_grad: bool = True
    update: bool = True

@dataclass
class LayerType:
    """A helper class to collect the layer type information"""
    layer_idx: int
    use_sliding_window: bool
    attends_to: int
    attends_top: bool
    computes_kv: bool

class LayerTypeParser:
    """
    A helper class to parse the layer type string and provide some useful methods.

    Arguments:
        layer_type (str): A string of integers separated by underscores. The i-th integer
            is the index of the layer whose key-value pair the i-th layer uses as its kv cache.
            Special characters may be placed after the integers:
            - `s` means the layer will use sliding window attention.

    >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3]
    >>> layer_type.attends_to
    5
    >>> layer_type.attends_top
    True
    >>> layer_type.use_sliding_window
    True
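
    A configuration with no layer attending to a higher layer parses the same way
    (illustrative config, not taken from any particular model):

    >>> parser = LayerTypeParser("0_1_2s")
    >>> parser.attends_top()
    False
    >>> parser.use_sliding_window()
    True
    >>> parser[2].computes_kv
    True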
    """
    def __init__(self, layer_type: str):
        self._layer_type = layer_type

        # parse the layer type
        self.layer_indices = []
        self.sliding_window = []
        for s in layer_type.split("_"):
            layer_idx, sliding_window = re.match(r"^(\d+)(s)?$", s).groups()
            self.layer_indices.append(int(layer_idx))
            self.sliding_window.append(bool(sliding_window))

    def __len__(self):
        return len(self.layer_indices)

    def __getitem__(self, layer_idx: int) -> LayerType:
        """return the layer type information for the given layer index"""
        return LayerType(
            layer_idx=layer_idx,
            use_sliding_window=self.sliding_window[layer_idx],
            attends_to=self.layer_indices[layer_idx],
            attends_top=self.layer_indices[layer_idx] > layer_idx,
            computes_kv=layer_idx in self.layer_indices,
        )

    def use_sliding_window(self) -> bool:
        """whether there exists a layer that uses sliding window attention"""
        return any(self.sliding_window)

    def attends_top(self) -> bool:
        """whether there exists a layer that attends to layers above it"""
        return any(self.layer_indices[i] > i for i in range(len(self)))

    def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]:
        """
        Return an iteration plan for the layer types. The plan is a list of IterStep
        objects. Groups of layers with cyclic (upward) dependencies are iterated
        `forward_passes` times with `requires_grad=False` and `update=False`, then
        `backward_passes` times with gradients enabled, with `update=True` only on
        the final pass.
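
        As an illustration (passes reduced for brevity; the layer types follow the
        class docstring example):

        >>> plan = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8").iteration_plan(
        ...     forward_passes=1, backward_passes=1)
        >>> len(plan)
        5
        >>> plan[0].layer_slice, plan[0].requires_grad
        (slice(0, 3, None), True)
        >>> plan[1].layer_slice, plan[1].requires_grad
        (slice(3, 6, None), False)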
        """
        # if there is no cyclic dependency, return the default plan
        if not self.attends_top():
            return [IterStep()]

        # otherwise, return the plan for the cyclic dependency
        plan = []
        i = 0
        while i < len(self):

            # if the layer attends to top layers, resolve the cyclic dependency
            if self[i].attends_top:

                # find the top layer in the cyclic dependency
                top = self[i].attends_to
                while top < max(self.layer_indices[i: top + 1]):
                    top = max(self.layer_indices[i: top + 1])
                top += 1

                # create iteration plan for this group
                layer_slice = slice(i, top)
                plan.extend([
                    *forward_passes * [IterStep(layer_slice, requires_grad=False, update=False)],
                    *(backward_passes - 1) * [IterStep(layer_slice, update=False)],
                    IterStep(layer_slice)
                ])

            # otherwise, create a default plan
            else:

                top = i + 1
                while top < len(self) and not self[top].attends_top:
                    top += 1
                plan.append(IterStep(slice(i, top)))

            # update the index
            i = top

        return plan

    def check(self, num_hidden_layers: int):
        """Check if the layer type is valid"""
        if len(self.layer_indices) != num_hidden_layers:
            raise ValueError("The number of layer types should be equal to the number of hidden layers.")
        for i in range(num_hidden_layers):
            if self.layer_indices[i] not in range(num_hidden_layers):
                raise ValueError("The layer type should be in the range of the number of hidden layers.")


def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    no_diag: bool = False,
):
    """
    This function is a wrapper around the _flash_attention_forward function in the
    transformers library. It adds support for masking the diagonal elements of the
    attention matrix. The diagonal mask is used to resolve the cyclic dependencies
    in the LCKV model.
    """
    prune_query = False
    if no_diag:
        # Only one key/value token is available and it sits on the (to-be-masked)
        # diagonal, so there is nothing left to attend to: return an all-zero output.
        if key_states.size(1) == 1:
            b, l, _, d = value_states.size()
            _, _, h, _ = query_states.size()
            return value_states.new_zeros((b, l, h, d))

        # Query and key lengths match (e.g. prefill/training): drop the first query,
        # whose only visible key is its own diagonal entry. A zero output row is
        # restored for it after the attention call.
        if key_states.size(1) == query_states.size(1):
            prune_query = True
            query_states = query_states[:, 1:, :, :]
            query_length -= 1

            if attention_mask is not None:
                attention_mask = attention_mask[:, 1:]

        # Drop the last key/value token so that, together with the causal mask, each
        # query attends only to keys strictly before its own position (no diagonal).
        key_states = key_states[:, :-1, :, :]
        value_states = value_states[:, :-1, :, :]

        # The sliding window no longer needs to cover the removed diagonal position.
        if sliding_window is not None:
            sliding_window = sliding_window - 1

    result: torch.Tensor = _flash_attention_forward(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        attention_mask=attention_mask,
        query_length=query_length,
        is_causal=is_causal,
        dropout=dropout,
        position_ids=position_ids,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        use_top_left_mask=use_top_left_mask,
        softcap=softcap,
        deterministic=deterministic,
    )

    if prune_query:
        # Restore a zero output row for the pruned first query token.
        b, _, h, d = result.size()
        result = torch.cat([result.new_zeros((b, 1, h, d)), result], dim=1)

    return result
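

# A minimal usage sketch (illustrative only): `flash_attention_forward` expects
# (batch, seq_len, num_heads, head_dim) tensors and, like the underlying flash
# attention kernel, half-precision inputs on a CUDA device.
#
#     q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
#     k = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
#     v = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
#     out = flash_attention_forward(q, k, v, attention_mask=None, query_length=16,
#                                   is_causal=True, no_diag=True)
#     # out[:, 0] is all zeros: with the diagonal masked, the first query token
#     # has no keys left to attend to.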