import torch

from torch.nn.attention.flex_attention import _mask_mod_signature


def causal_mask(
    batch_size: int,
    num_heads: int,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor
) -> torch.Tensor:
    """
    Returns a boolean tensor indicating which positions in the attention matrix
    are valid for causal (autoregressive) attention: True for positions (i, j)
    where i >= j.

    Args:
        batch_size (int): Batch size (unused here; kept to match the
            flex_attention mask_mod signature).
        num_heads (int): Number of heads (unused here; kept to match the
            flex_attention mask_mod signature).
        q_idx (torch.Tensor): Tensor indexing the query positions.
        kv_idx (torch.Tensor): Tensor indexing the key/value positions.

    Returns:
        torch.Tensor: A boolean tensor where True indicates that the query at
        position i can attend to the key at position j, respecting i >= j.
    """
    return q_idx >= kv_idx


def generate_sliding_window_mask(window_size: int, causal: bool = True) -> _mask_mod_signature:
    """
    Creates a sliding window mask function.

    If `causal=True`, each query token at position i can attend only to tokens j
    in [i - window_size, i].
    If `causal=False`, each query token i can attend to any token j in
    [i - window_size, i + window_size], i.e. a symmetric window of size `window_size`.

    Args:
        window_size (int): The maximum distance |i - j| at which query i may
            attend to key j.
        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.

    Returns:
        _mask_mod_signature: A callable mask function that takes
        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
        indicating allowed attention connections.
    """
    def sliding_window_mask(
        batch_size: int,
        num_heads: int,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor
    ) -> torch.Tensor:
        """
        If causal is True:
            allowed when 0 <= q_idx - kv_idx <= window_size.
        If causal is False:
            allowed when |q_idx - kv_idx| <= window_size.
        """
        if causal:
            distance = q_idx - kv_idx
            within_window = (distance >= 0) & (distance <= window_size)
        else:
            distance = (q_idx - kv_idx).abs()
            within_window = distance <= window_size
        return within_window

    name_ext = "causal" if causal else "noncausal"
    sliding_window_mask.__name__ = f"sliding_window_{window_size}_{name_ext}"
    return sliding_window_mask
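

# Hedged sketch, not called by main(): how a mask function built above would
# typically be consumed by FlexAttention. `create_block_mask` and `flex_attention`
# are the public entry points in torch.nn.attention.flex_attention (PyTorch >= 2.5);
# defaults such as the target device vary by version, so treat this as an
# illustration rather than a drop-in recipe.
def _flex_attention_usage_sketch(window_size: int = 2, seq_len: int = 8, head_dim: int = 16):
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    mask_mod = generate_sliding_window_mask(window_size=window_size, causal=True)
    # Compile the mask function into a BlockMask for a single-batch, single-head case.
    block_mask = create_block_mask(mask_mod, B=1, H=1, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu")

    q = torch.randn(1, 1, seq_len, head_dim)
    k = torch.randn(1, 1, seq_len, head_dim)
    v = torch.randn(1, 1, seq_len, head_dim)
    return flex_attention(q, k, v, block_mask=block_mask)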


def generate_dilated_sliding_window_mask(
    window_size: int,
    dilation: int = 2,
    causal: bool = True
) -> _mask_mod_signature:
    """
    Creates a dilated sliding window mask function.

    If `causal=True`, each query token i can attend to tokens j in [i - window_size, i]
    such that (i - j) % dilation == 0.
    If `causal=False`, each query token i can attend to tokens j in [i - window_size,
    i + window_size] for which |i - j| % dilation == 0.

    Args:
        window_size (int): The maximum distance from i to j (backwards if causal=True,
            otherwise symmetric around i).
        dilation (int): The stride for skipping positions within the window.
        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.

    Returns:
        _mask_mod_signature: A callable mask function that takes
        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
        indicating allowed attention connections.
    """
    def dilated_sliding_window_mask(
        batch_size: int,
        num_heads: int,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor
    ) -> torch.Tensor:
        """
        If causal is True:
            distance = q_idx - kv_idx
            allowed when 0 <= distance <= window_size and distance % dilation == 0.
        If causal is False:
            distance = |q_idx - kv_idx|
            allowed when distance <= window_size and distance % dilation == 0.
        """
        if causal:
            distance = q_idx - kv_idx
            within_window = (distance >= 0) & (distance <= window_size)
        else:
            distance = (q_idx - kv_idx).abs()
            within_window = distance <= window_size

        meets_dilation = (distance % dilation) == 0
        return within_window & meets_dilation

    mode_str = "causal" if causal else "noncausal"
    dilated_sliding_window_mask.__name__ = (
        f"dilated_sliding_window_{window_size}_dilation_{dilation}_{mode_str}"
    )
    return dilated_sliding_window_mask
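

# Hedged sketch, not called by main(): the dilated mask above can equivalently be
# built by composing two simpler mask functions with `and_masks` from
# torch.nn.attention.flex_attention, which AND-combines mask_mods elementwise.
# Shown for the causal case only; treat it as an illustration of the composition
# pattern rather than a replacement for the generator above.
def _composed_dilated_mask_sketch(window_size: int = 4, dilation: int = 2) -> _mask_mod_signature:
    from torch.nn.attention.flex_attention import and_masks

    def dilation_mask(batch_size, num_heads, q_idx, kv_idx):
        # Keep only offsets that are exact multiples of `dilation`.
        return (q_idx - kv_idx) % dilation == 0

    window_mask = generate_sliding_window_mask(window_size=window_size, causal=True)
    return and_masks(window_mask, dilation_mask)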


def main():
    """
    Demonstrates each mask by printing attention grids, then runs a few basic
    checks to confirm the masks behave as expected. Both the causal and
    non-causal variants of the sliding window and dilated masks are shown.
    """
    B, H = 1, 1
    Q_LEN, KV_LEN = 8, 8

    # Broadcast index grids: q_idx[i, j] = i and kv_idx[i, j] = j.
    q_idx = torch.arange(Q_LEN).unsqueeze(-1).expand(Q_LEN, KV_LEN)
    kv_idx = torch.arange(KV_LEN).unsqueeze(0).expand(Q_LEN, KV_LEN)

    print("= Causal Mask =")
    c_mask = causal_mask(B, H, q_idx, kv_idx)
    print(c_mask.int(), "\n")

    print("= Sliding Window (window_size=2, causal=True) =")
    sw_causal_fn = generate_sliding_window_mask(window_size=2, causal=True)
    sw_causal = sw_causal_fn(B, H, q_idx, kv_idx)
    print(sw_causal.int(), "\n")

    print("= Sliding Window (window_size=2, causal=False) =")
    sw_noncausal_fn = generate_sliding_window_mask(window_size=2, causal=False)
    sw_noncausal = sw_noncausal_fn(B, H, q_idx, kv_idx)
    print(sw_noncausal.int(), "\n")

    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=True) =")
    ds_causal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=True)
    ds_causal = ds_causal_fn(B, H, q_idx, kv_idx)
    print(ds_causal.int(), "\n")

    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=False) =")
    ds_noncausal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=False)
    ds_noncausal = ds_noncausal_fn(B, H, q_idx, kv_idx)
    print(ds_noncausal.int(), "\n")

    # Check 1: the causal mask is exactly the lower-triangular condition i >= j.
    assert torch.all(c_mask == (q_idx >= kv_idx)), "Causal mask mismatch!"

    # Check 2: in the causal sliding window, every allowed key lies within
    # [i - window_size, i] of the query.
    i = 5
    allowed_js = torch.where(sw_causal[i])[0]
    offsets = i - allowed_js
    assert torch.all((offsets >= 0) & (offsets <= 2)), \
        "Window mismatch for sliding_window_mask(causal=True)."

    # Check 3: in the causal dilated window, every allowed offset lies within the
    # window and is a multiple of the dilation.
    i = 6
    allowed_js = torch.where(ds_causal[i])[0]
    for j in allowed_js:
        diff = i - int(j)
        assert 0 <= diff <= 4 and diff % 2 == 0, f"Dilation mismatch: got diff={diff}."

    print("All checks passed.")


if __name__ == "__main__":
    main()