Commit a2fbb2f
Parent(s): 44d4c6b

transformer new

Files changed:
- attn.py +66 -0
- attn_masks.py +188 -0
- attn_mods.py +127 -0
- config.json +7 -11
- configuration_minitransformer.py +18 -13
- layers.py +11 -72
- modeling_minitransformer.py +30 -42
attn.py
CHANGED
@@ -123,3 +123,69 @@ class AttentionSDPA(nn.Module):
 
         y = self.resid_dropout(self.o_proj(y))
         return y
+
+
+class FlexAttention(nn.Module):
+    """
+    Generalized multi-head attention that supports various attention masks.
+    Supports Rotary Positional Embeddings.
+    """
+    def __init__(self, config, mask_mod, score_mod=None):
+        """
+        Initializes the attention module.
+
+        Args:
+            config: Model config providing `dim`, `num_heads`, `seq_len`, and `device`.
+            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
+            score_mod (Callable, optional): Score modifier, e.g. tanh softcap or ALiBi.
+        """
+        super().__init__()
+        self.dim, self.num_heads = config.dim, config.num_heads
+        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
+        self.head_dim = config.dim // config.num_heads
+
+        self.wq = nn.Linear(config.dim, config.dim)
+        self.wk = nn.Linear(config.dim, config.dim)
+        self.wv = nn.Linear(config.dim, config.dim)
+
+        self.mask_mod = mask_mod
+        self.score_mod = score_mod
+        self.block_mask = create_block_mask(
+            mask_mod=self.mask_mod,
+            B=None,  # Broadcast over batch
+            H=None,  # Broadcast over heads
+            Q_LEN=config.seq_len,
+            KV_LEN=config.seq_len,
+            device=config.device,
+        )
+
+        self.o_proj = nn.Linear(config.dim, config.dim)
+        self.o_proj.SCALE_INIT = 1
+
+    def forward(
+        self,
+        x: torch.Tensor = None,
+        q: torch.Tensor = None,
+        k: torch.Tensor = None,
+        v: torch.Tensor = None,
+        freqs_cis: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if x is not None:
+            q = k = v = x
+        if any(t is None for t in [q, k, v]):
+            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")
+
+        bsz, q_len, _ = q.shape
+        _, k_len, _ = k.shape
+        _, v_len, _ = v.shape
+
+        # Split heads: (bsz, len, dim) -> (bsz, num_heads, len, head_dim)
+        Q = self.wq(q).reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        K = self.wk(k).reshape(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
+        V = self.wv(v).reshape(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)
+
+        output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)
+        # Merge heads back: (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, dim)
+        output = output.transpose(1, 2).reshape(bsz, q_len, self.dim)
+        output = self.o_proj(output)
+        return output
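
The new FlexAttention module is a thin wrapper around PyTorch's flex_attention API (available in recent PyTorch releases). As an illustration of the pieces it composes, here is a minimal, self-contained sketch that is not part of the commit; the mask_mod and shapes mirror what `__init__` and `forward` do internally, while the real module additionally applies rotary embeddings and learned projections.

```python
# Illustrative sketch only (not part of the commit); requires a PyTorch build
# that ships torch.nn.attention.flex_attention.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # True where query i may attend to key j (i >= j), same rule as attn_masks.causal_mask
    return q_idx >= kv_idx

B, H, L, D = 2, 4, 256, 64  # batch, heads, seq len, head dim (small, illustrative values)
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# B=None / H=None broadcast one mask over batch and heads, as in FlexAttention.__init__
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=L, KV_LEN=L, device="cpu")
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 4, 256, 64])
```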
attn_masks.py
ADDED
@@ -0,0 +1,188 @@
+import torch
+from torch.nn.attention.flex_attention import _mask_mod_signature
+
+def causal_mask(
+    batch_size: int,
+    num_heads: int,
+    q_idx: torch.Tensor,
+    kv_idx: torch.Tensor
+) -> torch.Tensor:
+    """
+    Returns a boolean tensor indicating which positions in the attention matrix
+    are valid for causal (autoregressive) attention. By default, it's True for
+    positions (i, j) where i >= j.
+
+    Args:
+        batch_size (int): Batch size (unused here).
+        num_heads (int): Number of heads (unused here).
+        q_idx (torch.Tensor): Tensor indexing the query positions.
+        kv_idx (torch.Tensor): Tensor indexing the key/value positions.
+
+    Returns:
+        torch.Tensor: A boolean tensor where True indicates that the query at
+        position i can attend to the key at position j, respecting i >= j.
+    """
+    return q_idx >= kv_idx
+
+
+def generate_sliding_window_mask(window_size: int, causal: bool = True) -> _mask_mod_signature:
+    """
+    Creates a sliding window mask function.
+
+    If `causal=True`, each query token at position i can attend only to tokens j
+    in [i - window_size, i].
+    If `causal=False`, each query token i can attend to any token j in
+    [i - window_size, i + window_size], i.e. a symmetric window of size `window_size`.
+
+    Args:
+        window_size (int): The maximum distance from i that i can attend to.
+        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.
+
+    Returns:
+        _mask_mod_signature: A callable mask function that takes
+        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
+        indicating allowed attention connections.
+    """
+    def sliding_window_mask(
+        batch_size: int,
+        num_heads: int,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        If causal is True:
+            within_window = (q_idx - kv_idx) <= window_size, and q_idx >= kv_idx.
+        If causal is False:
+            within_window = abs(q_idx - kv_idx) <= window_size.
+        """
+        if causal:
+            # standard "look back" window
+            distance = q_idx - kv_idx
+            within_window = (distance >= 0) & (distance <= window_size)
+        else:
+            # symmetrical window around i
+            distance = (q_idx - kv_idx).abs()
+            within_window = distance <= window_size
+
+        return within_window
+
+    name_ext = "causal" if causal else "noncausal"
+    sliding_window_mask.__name__ = f"sliding_window_{window_size}_{name_ext}"
+    return sliding_window_mask
+
+
+def generate_dilated_sliding_window_mask(
+    window_size: int,
+    dilation: int = 2,
+    causal: bool = True
+) -> _mask_mod_signature:
+    """
+    Creates a dilated sliding window mask function.
+
+    If `causal=True`, each query token i can attend tokens j in [i - window_size, i]
+    such that (i - j) % dilation == 0.
+    If `causal=False`, each query token i can attend tokens j in [i - window_size,
+    i + window_size] for which |i - j| % dilation == 0.
+
+    Args:
+        window_size (int): The maximum distance from i to j (backwards if causal=True,
+            otherwise symmetric around i).
+        dilation (int): The stride for skipping positions.
+        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.
+
+    Returns:
+        _mask_mod_signature: A callable mask function that takes
+        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
+        indicating allowed attention connections.
+    """
+    def dilated_sliding_window_mask(
+        batch_size: int,
+        num_heads: int,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        If causal is True:
+            distance = q_idx - kv_idx
+            0 <= distance <= window_size and distance % dilation == 0.
+        If causal is False:
+            distance = (q_idx - kv_idx).abs()
+            distance <= window_size and distance % dilation == 0.
+        """
+        if causal:
+            distance = q_idx - kv_idx
+            within_window = (distance >= 0) & (distance <= window_size)
+        else:
+            distance = (q_idx - kv_idx).abs()
+            within_window = distance <= window_size
+
+        meets_dilation = (distance % dilation) == 0
+        return within_window & meets_dilation
+
+    mode_str = "causal" if causal else "noncausal"
+    dilated_sliding_window_mask.__name__ = (
+        f"dilated_sliding_window_{window_size}_dilation_{dilation}_{mode_str}"
+    )
+    return dilated_sliding_window_mask
+
+
+def main():
+    """
+    Demonstrates usage of each mask by printing attention grids. We include a few
+    basic checks to ensure the masks behave as expected. We show both the causal
+    and non-causal versions for the sliding window and dilated masks.
+    """
+    B, H = 1, 1
+    Q_LEN, KV_LEN = 8, 8
+
+    # coordinate grids
+    q_idx = torch.arange(Q_LEN).unsqueeze(-1).expand(Q_LEN, KV_LEN)
+    kv_idx = torch.arange(KV_LEN).unsqueeze(0).expand(Q_LEN, KV_LEN)
+
+    print("= Causal Mask =")
+    c_mask = causal_mask(B, H, q_idx, kv_idx)
+    print(c_mask.int(), "\n")
+
+    print("= Sliding Window (window_size=2, causal=True) =")
+    sw_causal_fn = generate_sliding_window_mask(window_size=2, causal=True)
+    sw_causal = sw_causal_fn(B, H, q_idx, kv_idx)
+    print(sw_causal.int(), "\n")
+
+    print("= Sliding Window (window_size=2, causal=False) =")
+    sw_noncausal_fn = generate_sliding_window_mask(window_size=2, causal=False)
+    sw_noncausal = sw_noncausal_fn(B, H, q_idx, kv_idx)
+    print(sw_noncausal.int(), "\n")
+
+    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=True) =")
+    ds_causal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=True)
+    ds_causal = ds_causal_fn(B, H, q_idx, kv_idx)
+    print(ds_causal.int(), "\n")
+
+    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=False) =")
+    ds_noncausal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=False)
+    ds_noncausal = ds_noncausal_fn(B, H, q_idx, kv_idx)
+    print(ds_noncausal.int(), "\n")
+
+    # Quick checks:
+    # (1) Causal means no i < j
+    assert torch.all(c_mask == (q_idx >= kv_idx)), "Causal mask mismatch!"
+    # (2) For windowed masks with causal=True, check a random row
+    i = 5
+    row_sw = sw_causal[i]
+    allowed_js = torch.where(row_sw)[0]
+    if len(allowed_js) > 0:
+        # difference i-j <= 2
+        assert (i - allowed_js.min()) <= 2, "Window mismatch for sliding_window_mask(causal=True)."
+
+    # (3) Dilated mask with causal=True should skip every other position if dilation=2
+    i = 6
+    row_ds = ds_causal[i]
+    allowed_js = torch.where(row_ds)[0]
+    for j in allowed_js:
+        diff = i - j
+        assert diff % 2 == 0, f"Dilation mismatch: got diff={diff}."
+
+    print("All checks passed.")
+
+if __name__ == "__main__":
+    main()
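
Each factory here returns a mask_mod that can be handed to create_block_mask, which is how FlexAttention consumes it. A hedged usage sketch follows; the top-level `attn_masks` import path is an assumption (inside the package it would be a relative import):

```python
# Usage sketch (assumed import path; not part of the commit).
from torch.nn.attention.flex_attention import create_block_mask
from attn_masks import generate_sliding_window_mask  # hypothetical top-level import

sw_mask = generate_sliding_window_mask(window_size=128, causal=True)
block_mask = create_block_mask(sw_mask, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
print(block_mask)  # BlockMask summary showing which tiles of the attention matrix are kept
```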
attn_mods.py
ADDED
@@ -0,0 +1,127 @@
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import _score_mod_signature
+from torch._inductor.lowering import make_pointwise, register_lowering
+
+# Some internal torch.compile details
+from torch._inductor.virtualized import ops
+from functools import partial
+
+
+@torch.library.custom_op("approx::tanh", mutates_args=())
+def _tanh_approx(inp: Tensor) -> Tensor:
+    return torch.tanh(inp)
+
+
+@_tanh_approx.register_fake
+def _(inp: torch.Tensor) -> torch.Tensor:
+    return torch.tanh(inp)
+
+
+def _tanh_approx_lowering(inp):
+    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
+    return make_pointwise(fn)(inp)
+
+
+register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)
+
+
+class _TanhApprox(torch.autograd.Function):
+    @staticmethod
+    def forward(x):
+        return torch.ops.approx.tanh(x)
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        (x,) = inputs
+        result = output
+        ctx.save_for_backward(result)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (result,) = ctx.saved_tensors
+        return grad_output * (1 - result * result)
+
+    @staticmethod
+    def vmap(info, in_dims, x):
+        return torch.tanh(x), 0
+
+
+_tanh_approx = _TanhApprox.apply
+
+
+def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
+    """Returns a tanh softcapping score_mod given the soft cap value.
+
+    Args:
+        soft_cap: The soft cap value to use for normalizing logits
+        approx: Whether to use the `tanh.approx` PTX instruction
+
+    Returns:
+        tanh_softcap: score_mod
+    """
+    tanh = _tanh_approx if approx else torch.tanh
+
+    def tanh_softcap(score, b, h, q_idx, kv_idx):
+        return soft_cap * tanh(score / soft_cap)
+
+    prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
+    tanh_softcap.__name__ = f"{prefix}_{soft_cap}"
+
+    return tanh_softcap
+
+
+def generate_alibi_bias(H: int) -> _score_mod_signature:
+    """Returns an ALiBi bias score_mod given the number of heads H.
+
+    Args:
+        H: number of heads
+
+    Returns:
+        alibi_bias: alibi bias score_mod
+    """
+
+    def alibi_mod(score, b, h, q_idx, kv_idx):
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (kv_idx - q_idx) * scale
+        return score + bias
+
+    return alibi_mod
+
+
+def generate_tanh_softcap_alibi(H: int, soft_cap: float, approx: bool = False) -> _score_mod_signature:
+    """Returns a combined ALiBi and tanh softcapping score_mod.
+
+    Args:
+        H (int): number of heads for ALiBi scaling
+        soft_cap (float): the soft cap value for normalizing/logit clipping
+        approx (bool): Whether to use the 'tanh.approx' PTX-based approximation
+
+    Returns:
+        A combined score_mod function that first applies ALiBi,
+        then performs softcap + tanh (optionally approximate).
+    """
+    tanh_func = _tanh_approx if approx else torch.tanh
+
+    def alibi_tanh_softcap(score, b, h, q_idx, kv_idx):
+        # Compute ALiBi bias
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (kv_idx - q_idx) * scale
+        score = score + bias
+
+        # Apply softcap
+        score = score / soft_cap
+
+        # Apply tanh
+        score = tanh_func(score)
+
+        # Rescale by soft_cap
+        score = score * soft_cap
+        return score
+
+    # Give the score_mod a unique name:
+    if approx:
+        alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_approx_{soft_cap}"
+    else:
+        alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_{soft_cap}"
+
+    return alibi_tanh_softcap
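
These helpers plug into flex_attention's score_mod argument; the soft cap value mirrors "softcap": 50.0 in config.json. A rough sketch follows (the top-level import path is an assumption, and the approx=True path additionally relies on the inductor lowering above, i.e. a CUDA + torch.compile setup):

```python
# Usage sketch (assumed import path; not part of the commit).
import torch
from torch.nn.attention.flex_attention import flex_attention
from attn_mods import generate_tanh_softcap, generate_alibi_bias  # hypothetical top-level import

B, H, L, D = 1, 8, 128, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

softcap = generate_tanh_softcap(soft_cap=50, approx=False)  # exact torch.tanh path
alibi = generate_alibi_bias(H)

out_capped = flex_attention(q, k, v, score_mod=softcap)  # logits squashed into (-50, 50)
out_alibi = flex_attention(q, k, v, score_mod=alibi)     # per-head linear distance bias
print(out_capped.shape, out_alibi.shape)
```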
config.json
CHANGED
@@ -2,17 +2,15 @@
   "model_type": "minitransformer",
   "_name_or_path": "Transformer_500M",
   "architectures": ["MiniTransformer"],
-  "…
-  "…
-  "…
+  "dim": 768,
+  "num_heads": 24,
+  "num_layers": 27,
   "seq_len": 8192,
   "window_size": 8192,
   "vocab_size": 200064,
   "mlp_scale": 4,
   "bias": false,
   "dropout": 0.0,
-  "num_eigh": 24,
-  "use_hankel_L": false,
   "num_epochs": 1,
   "global_bsz": 524288,
   "bsz": 1,
@@ -27,7 +25,7 @@
   "ddp": true,
   "mixed_precision": true,
   "torch_dtype": "bfloat16",
-  "…
+  "cpu_offload": false,
   "sharding_strategy": "full_shard",
   "state_dict_type": "full",
   "auto_wrap_policy": "partial",
@@ -42,12 +40,10 @@
     "buffer": "bfloat16"
   },
   "fsdp_modules": [
-    "…
+    "AttentionLayer"
   ],
   "use_activation_checkpointing": true,
-  "use_flash_fft": true,
-  "use_approx": true,
-  "use_attn": true,
   "softcap": 50.0,
-  "…
+  "theta": 10000.0,
+  "torch_compile": true
 }
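
Because the new keys line up with MiniTransformerConfig's __init__ (next file), the JSON can be loaded straight into the config class; keys the constructor does not name fall through **kwargs to PretrainedConfig. A sketch, assuming the file sits next to the code:

```python
# Sketch only: load config.json into the config class defined in this commit.
import json
from configuration_minitransformer import MiniTransformerConfig  # assumed import path

with open("config.json") as f:
    config = MiniTransformerConfig(**json.load(f))

print(config.dim, config.num_heads, config.num_layers)  # 768 24 27
print(config.dim // config.num_heads)                   # head_dim = 32
```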
configuration_minitransformer.py
CHANGED
@@ -7,33 +7,38 @@ class MiniTransformerConfig(PretrainedConfig):
     def __init__(
         self,
         bsz: int = 1,
-        …
-        …
-        …
+        dim: int = 1536,
+        num_heads: int = 8,
+        num_layers: int = 26,
         seq_len: int = 8192,
-        window_size: int = …
+        window_size: int = 1024,
         vocab_size: int = 200064,
-        mlp_scale: int = …
+        mlp_scale: int = 12,
         bias: bool = False,
         dropout: float = 0.0,
         softcap: float = 50.0,
-        …
-        …
+        theta: float = 10_000.0,
+        use_alibi: bool = False,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.bsz = bsz
-        self.…
-        self.…
-        self.…
+        self.dim = dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
         self.seq_len = seq_len
         self.window_size = window_size
         self.vocab_size = vocab_size
-        self.hidden_size = …
-        self.…
-        self.…
+        self.hidden_size = dim
+        self.mlp_scale = mlp_scale
+        self.intermediate_size = self.dim * self.mlp_scale
         self.bias = bias
         self.dropout = dropout
         self.softcap = softcap
+        self.theta = theta
+        self.use_alibi = use_alibi
         self.torch_dtype = torch_dtype
         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')  # Store as string
+
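
Two derived fields are worth noting: hidden_size mirrors dim (presumably for Hugging Face-style tooling) and intermediate_size is dim * mlp_scale. A small check using the class defaults, as a sketch:

```python
# Sketch: defaults and derived attributes of MiniTransformerConfig.
from configuration_minitransformer import MiniTransformerConfig  # assumed import path

config = MiniTransformerConfig()                 # dim=1536, num_heads=8, mlp_scale=12, ...
assert config.hidden_size == config.dim          # hidden_size is kept in sync with dim
assert config.intermediate_size == config.dim * config.mlp_scale  # 1536 * 12 = 18432
print(config.device)                             # "cuda" if available, else "cpu"
```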
layers.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from .…
+from .attn import FlexAttention
 from .modules import MLP
 from .modules import Attention
 try:
@@ -23,80 +23,19 @@ except ImportError as e:
     from torch.nn import RMSNorm
     triton_norm = False
 
-
-class STULayer(nn.Module):
-    def __init__(self, config, phi, n):
-        super(STULayer, self).__init__()
-        if isinstance(config.torch_dtype, str):
-            torch_dtype = getattr(torch, config.torch_dtype)
-        else:
-            torch_dtype = config.torch_dtype
-        self.stu_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
-        )
-        self.stu = STU(config, phi, n)
-        self.stu = self.stu.to(dtype=torch_dtype)
-        self.mlp_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
-        )
-        self.mlp = (
-            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
-        )
-
-        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
-        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
-        self.mlp = self.mlp.to(dtype=torch_dtype)
-        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # Debug dtype
-
-        # Normalize and apply STU
-        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)  # Match dtype for STU
-        x_stu = self.stu(x_normed).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
-        x = x + x_stu
-
-        # Normalize and apply MLP
-        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)  # Match dtype for MLP
-        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
-        x = x + x_mlp
-
-        return x
-
 class AttentionLayer(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config, mask_mod, score_mod=None) -> None:
         super(AttentionLayer, self).__init__()
-
-        …
-        …
-        …
-        …
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
-        )
-        self.attn = Attention(config)
-        self.attn = self.attn.to(dtype=torch_dtype)
-        self.mlp_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
+        self.attn_norm = nn.RMSNorm(config.dim)
+        self.attn = FlexAttention(
+            config=config,
+            mask_mod=mask_mod,
+            score_mod=score_mod,
         )
-        self.…
-        …
-        )
-        self.mlp = self.mlp.to(dtype=torch_dtype)
-
-        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
-        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
-        self.mlp = self.mlp.to(dtype=torch_dtype)
-        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)
+        self.mlp_norm = nn.RMSNorm(config.dim)
+        self.mlp = MLP(config)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.attn_norm(x))
+    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
+        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
         x = x + self.mlp(self.mlp_norm(x))
         return x
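
Each AttentionLayer is now a pre-norm residual block: RMSNorm → FlexAttention → residual, then RMSNorm → MLP → residual, with rotary frequencies passed in from the model. A forward-pass sketch follows; MLP and precompute_freqs_cis come from modules not shown in this commit, so the import paths and signatures here are assumptions:

```python
# Sketch only: shapes through one AttentionLayer (helper signatures assumed).
import torch
from layers import AttentionLayer                               # assumed import paths
from attn_masks import causal_mask
from configuration_minitransformer import MiniTransformerConfig
from rotary_emb import precompute_freqs_cis                     # not shown in this diff

config = MiniTransformerConfig(dim=256, num_heads=8, seq_len=128, device="cpu")
layer = AttentionLayer(config, mask_mod=causal_mask)            # score_mod is optional

freqs_cis = precompute_freqs_cis(head_dim=config.dim // config.num_heads,
                                 max_seq_len=config.seq_len, theta=config.theta)
x = torch.randn(2, config.seq_len, config.dim)
y = layer(x, freqs_cis=freqs_cis)   # residual attention + MLP; same shape as x
print(y.shape)                      # torch.Size([2, 128, 256])
```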
modeling_minitransformer.py
CHANGED
@@ -10,6 +10,10 @@ from .utils import nearest_power_of_two
 from .layers import AttentionLayer
 from .configuration_minitransformer import MiniTransformerConfig
 
+from .attn_masks import causal_mask
+from .attn_mods import generate_tanh_softcap
+from .rotary_emb import precompute_freqs_cis
+
 try:
     from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
     triton_norm = True
@@ -33,39 +37,31 @@ class MiniTransformer(PreTrainedModel):
 
     def __init__(self, config) -> None:
         super(MiniTransformer, self).__init__(config)
-        self.…
-        …
-        )
+        self.num_layers = config.num_layers
+        assert config.dim % config.num_heads == 0, f"dim ({config.dim}) must be divisible by num_heads ({config.num_heads})"
+        self.head_dim = config.dim // config.num_heads
+        logit_softcap = generate_tanh_softcap(soft_cap=config.softcap)
+
+        # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
+        self.register_buffer("freqs_cis", precompute_freqs_cis(
+            head_dim=self.head_dim,
+            max_seq_len=config.seq_len,
+            theta=config.theta,
+        ), persistent=True)
+
+        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
         self.dropout = nn.Dropout(config.dropout)
 
         self.layers = nn.ModuleList()
-        for _ in range(self.…
-            …
+        for _ in range(self.num_layers):
+            layer = AttentionLayer(config, mask_mod=causal_mask, score_mod=logit_softcap)
+            self.layers.append(layer)
 
-        self.norm = (
-            …
-            …
-            else RMSNorm(config.n_embd, dtype=config.torch_dtype)
-        )
-        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for RMS Norm
-        self.norm = self.norm.to(dtype=config.torch_dtype)
-        self.lm_head = nn.Linear(
-            config.n_embd, config.vocab_size, bias=config.bias, dtype=config.torch_dtype
-        )
-        self.tok_emb.weight = self.lm_head.weight
+        self.norm = nn.RMSNorm(config.dim)
+        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)
+        # self.tok_emb.weight = self.lm_head.weight
 
-        self.std = (config.…
+        self.std = (config.dim) ** -0.5
         self.apply(self._init_weights)
         print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
 
@@ -77,15 +73,13 @@
     ) -> CausalLMOutput:
         # Compute embeddings
         tok_emb = self.tok_emb(input_ids)
-        x = self.dropout(tok_emb)
 
-        # Pass through layers
         for layer in self.layers:
-            …
+            tok_emb = layer(tok_emb, self.freqs_cis)
 
         # Normalize and project to vocabulary
-        …
-        logits = self.lm_head(…
+        tok_emb = self.norm(tok_emb)
+        logits = self.lm_head(tok_emb)
 
         loss = None
         if labels is not None:
@@ -107,26 +101,20 @@
         n_params = sum(p.numel() for p in self.parameters())
         if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
-        if self.tok_emb.weight is …
+        if self.tok_emb.weight is self.lm_head.weight:
            n_params -= self.tok_emb.weight.numel()
         return n_params
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             if hasattr(module, "SCALE_INIT"):
-                self.std *= (2 * self.…
+                self.std *= (2 * self.num_layers) ** -0.5
             torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
-        …
-            torch.nn.init.xavier_normal_(module.attn.weight)
-            torch.nn.init.xavier_normal_(module.o_proj.weight)
-            if module.attn.bias is not None:
-                torch.nn.init.zeros_(module.attn.bias)
-            if module.o_proj.bias is not None:
-                torch.nn.init.zeros_(module.o_proj.bias)
+
     @staticmethod
     def top_k_top_p_filtering(
         logits: torch.Tensor,
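
Put together, the model now embeds tokens, runs them through num_layers causal FlexAttention blocks with a tanh softcap of 50, applies a final RMSNorm, and projects to the vocabulary (weight tying is commented out, so _get_num_params subtracts the embedding only when the weights are actually shared). A rough end-to-end sketch with shrunken sizes; whether it runs as-is depends on the MLP and rotary modules not shown in this commit, and the forward signature is assumed from the visible hunks:

```python
# End-to-end sketch only (reduced sizes; real config: dim=768, heads=24, layers=27, seq_len=8192).
import torch
from configuration_minitransformer import MiniTransformerConfig  # assumed import paths
from modeling_minitransformer import MiniTransformer

config = MiniTransformerConfig(dim=128, num_heads=4, num_layers=2,
                               seq_len=128, vocab_size=1000, device="cpu")
model = MiniTransformer(config)

input_ids = torch.randint(0, config.vocab_size, (1, config.seq_len))
out = model(input_ids=input_ids, labels=input_ids)  # assumes forward(input_ids, labels, ...)
print(out.logits.shape)  # (1, 128, 1000)
print(out.loss)          # language-modeling loss, computed only when labels are given
```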