yagizdevre committed
Commit a2fbb2f · 1 Parent(s): 44d4c6b

transformer new

Files changed (7)
  1. attn.py +66 -0
  2. attn_masks.py +188 -0
  3. attn_mods.py +127 -0
  4. config.json +7 -11
  5. configuration_minitransformer.py +18 -13
  6. layers.py +11 -72
  7. modeling_minitransformer.py +30 -42
attn.py CHANGED
@@ -123,3 +123,69 @@ class AttentionSDPA(nn.Module):
 
         y = self.resid_dropout(self.o_proj(y))
         return y
+
+
+class FlexAttention(nn.Module):
+    """
+    Generalized multi-head attention that supports various attention masks
+    and rotary positional embeddings.
+    """
+    def __init__(self, config, mask_mod, score_mod=None):
+        """
+        Initializes the FlexAttention module.
+
+        Args:
+            config: Model config providing dim, num_heads, seq_len, and device.
+            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
+            score_mod (Callable, optional): Score modifier, e.g. tanh softcapping.
+        """
+        super().__init__()
+        self.dim, self.num_heads = config.dim, config.num_heads
+        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
+        self.head_dim = config.dim // config.num_heads
+
+        self.wq = nn.Linear(config.dim, config.dim)
+        self.wk = nn.Linear(config.dim, config.dim)
+        self.wv = nn.Linear(config.dim, config.dim)
+
+        self.mask_mod = mask_mod
+        self.score_mod = score_mod
+        self.block_mask = create_block_mask(
+            mask_mod=self.mask_mod,
+            B=None,  # Broadcast over batch
+            H=None,  # Broadcast over heads
+            Q_LEN=config.seq_len,
+            KV_LEN=config.seq_len,
+            device=config.device,
+        )
+
+        self.o_proj = nn.Linear(config.dim, config.dim)
+        self.o_proj.SCALE_INIT = 1
+
+    def forward(
+        self,
+        x: torch.Tensor = None,
+        q: torch.Tensor = None,
+        k: torch.Tensor = None,
+        v: torch.Tensor = None,
+        freqs_cis: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if x is not None:
+            q = k = v = x
+        if any(t is None for t in [q, k, v]):
+            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")
+
+        bsz, q_len, _ = q.shape
+        _, k_len, _ = k.shape
+        _, v_len, _ = v.shape
+
+        Q = self.wq(q).reshape(bsz, self.num_heads, q_len, self.head_dim)
+        K = self.wk(k).reshape(bsz, self.num_heads, k_len, self.head_dim)
+        V = self.wv(v).reshape(bsz, self.num_heads, v_len, self.head_dim)
+
+        Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)
+
+        output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)
+        output = output.reshape(bsz, q_len, self.dim)
+        output = self.o_proj(output)
+        return output
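
Usage sketch (editor's note, not part of the commit): FlexAttention above wraps PyTorch's flex_attention API, precomputing a BlockMask once and reusing it on every forward pass. A minimal, self-contained illustration of that underlying API, assuming PyTorch 2.5+ and a CUDA device:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # True where query position q_idx may attend to key position kv_idx
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 256, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

# Build the block-sparse mask once; B=None and H=None broadcast over batch and heads.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

out = flex_attention(q, k, v, block_mask=block_mask)  # shape (B, H, S, D)
print(out.shape)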
attn_masks.py ADDED
@@ -0,0 +1,188 @@
+import torch
+from torch.nn.attention.flex_attention import _mask_mod_signature
+
+
+def causal_mask(
+    batch_size: int,
+    num_heads: int,
+    q_idx: torch.Tensor,
+    kv_idx: torch.Tensor
+) -> torch.Tensor:
+    """
+    Returns a boolean tensor indicating which positions in the attention matrix
+    are valid for causal (autoregressive) attention. By default, it's True for
+    positions (i, j) where i >= j.
+
+    Args:
+        batch_size (int): Batch size (unused here).
+        num_heads (int): Number of heads (unused here).
+        q_idx (torch.Tensor): Tensor indexing the query positions.
+        kv_idx (torch.Tensor): Tensor indexing the key/value positions.
+
+    Returns:
+        torch.Tensor: A boolean tensor where True indicates that the query at
+        position i can attend to the key at position j, respecting i >= j.
+    """
+    return q_idx >= kv_idx
+
+
+def generate_sliding_window_mask(window_size: int, causal: bool = True) -> _mask_mod_signature:
+    """
+    Creates a sliding window mask function.
+
+    If `causal=True`, each query token at position i can attend only to tokens j
+    in [i - window_size, i].
+    If `causal=False`, each query token i can attend to any token j in
+    [i - window_size, i + window_size], i.e. a symmetric window of size `window_size`.
+
+    Args:
+        window_size (int): The maximum distance from i that i can attend to.
+        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.
+
+    Returns:
+        _mask_mod_signature: A callable mask function that takes
+        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
+        indicating allowed attention connections.
+    """
+    def sliding_window_mask(
+        batch_size: int,
+        num_heads: int,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        If causal is True:
+            within_window = (q_idx - kv_idx) <= window_size, and q_idx >= kv_idx.
+        If causal is False:
+            within_window = abs(q_idx - kv_idx) <= window_size.
+        """
+        if causal:
+            # standard "look back" window
+            distance = q_idx - kv_idx
+            within_window = (distance >= 0) & (distance <= window_size)
+        else:
+            # symmetric window around i
+            distance = (q_idx - kv_idx).abs()
+            within_window = distance <= window_size
+
+        return within_window
+
+    name_ext = "causal" if causal else "noncausal"
+    sliding_window_mask.__name__ = f"sliding_window_{window_size}_{name_ext}"
+    return sliding_window_mask
+
+
+def generate_dilated_sliding_window_mask(
+    window_size: int,
+    dilation: int = 2,
+    causal: bool = True
+) -> _mask_mod_signature:
+    """
+    Creates a dilated sliding window mask function.
+
+    If `causal=True`, each query token i can attend to tokens j in [i - window_size, i]
+    such that (i - j) % dilation == 0.
+    If `causal=False`, each query token i can attend to tokens j in [i - window_size,
+    i + window_size] for which |i - j| % dilation == 0.
+
+    Args:
+        window_size (int): The maximum distance from i to j (backwards if causal=True,
+            otherwise symmetric around i).
+        dilation (int): The stride for skipping positions.
+        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.
+
+    Returns:
+        _mask_mod_signature: A callable mask function that takes
+        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
+        indicating allowed attention connections.
+    """
+    def dilated_sliding_window_mask(
+        batch_size: int,
+        num_heads: int,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        If causal is True:
+            distance = q_idx - kv_idx
+            0 <= distance <= window_size and distance % dilation == 0.
+        If causal is False:
+            distance = (q_idx - kv_idx).abs()
+            distance <= window_size and distance % dilation == 0.
+        """
+        if causal:
+            distance = q_idx - kv_idx
+            within_window = (distance >= 0) & (distance <= window_size)
+        else:
+            distance = (q_idx - kv_idx).abs()
+            within_window = distance <= window_size
+
+        meets_dilation = (distance % dilation) == 0
+        return within_window & meets_dilation
+
+    mode_str = "causal" if causal else "noncausal"
+    dilated_sliding_window_mask.__name__ = (
+        f"dilated_sliding_window_{window_size}_dilation_{dilation}_{mode_str}"
+    )
+    return dilated_sliding_window_mask
+
+
+def main():
+    """
+    Demonstrates usage of each mask by printing attention grids. We include a few
+    basic checks to ensure the masks behave as expected. We show both the causal
+    and non-causal versions for the sliding window and dilated masks.
+    """
+    B, H = 1, 1
+    Q_LEN, KV_LEN = 8, 8
+
+    # coordinate grids
+    q_idx = torch.arange(Q_LEN).unsqueeze(-1).expand(Q_LEN, KV_LEN)
+    kv_idx = torch.arange(KV_LEN).unsqueeze(0).expand(Q_LEN, KV_LEN)
+
+    print("= Causal Mask =")
+    c_mask = causal_mask(B, H, q_idx, kv_idx)
+    print(c_mask.int(), "\n")
+
+    print("= Sliding Window (window_size=2, causal=True) =")
+    sw_causal_fn = generate_sliding_window_mask(window_size=2, causal=True)
+    sw_causal = sw_causal_fn(B, H, q_idx, kv_idx)
+    print(sw_causal.int(), "\n")
+
+    print("= Sliding Window (window_size=2, causal=False) =")
+    sw_noncausal_fn = generate_sliding_window_mask(window_size=2, causal=False)
+    sw_noncausal = sw_noncausal_fn(B, H, q_idx, kv_idx)
+    print(sw_noncausal.int(), "\n")
+
+    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=True) =")
+    ds_causal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=True)
+    ds_causal = ds_causal_fn(B, H, q_idx, kv_idx)
+    print(ds_causal.int(), "\n")
+
+    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=False) =")
+    ds_noncausal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=False)
+    ds_noncausal = ds_noncausal_fn(B, H, q_idx, kv_idx)
+    print(ds_noncausal.int(), "\n")
+
+    # Quick checks:
+    # (1) Causal means no i < j
+    assert torch.all(c_mask == (q_idx >= kv_idx)), "Causal mask mismatch!"
+    # (2) For windowed masks with causal=True, check a random row
+    i = 5
+    row_sw = sw_causal[i]
+    allowed_js = torch.where(row_sw)[0]
+    if len(allowed_js) > 0:
+        # difference i - j <= 2
+        assert (i - allowed_js.min()) <= 2, "Window mismatch for sliding_window_mask(causal=True)."
+
+    # (3) Dilated mask with causal=True should skip every other position if dilation=2
+    i = 6
+    row_ds = ds_causal[i]
+    allowed_js = torch.where(row_ds)[0]
+    for j in allowed_js:
+        diff = i - j
+        assert diff % 2 == 0, f"Dilation mismatch: got diff={diff}."
+
+    print("All checks passed.")
+
+
+if __name__ == "__main__":
+    main()
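
Editor's note: a quick way to inspect any of these mask_mods outside of main() is to materialize them densely with create_mask from the same PyTorch module. A sketch, assuming a recent PyTorch release with the FlexAttention utilities and that attn_masks.py is importable from the working directory:

import torch
from torch.nn.attention.flex_attention import create_mask
from attn_masks import generate_dilated_sliding_window_mask

mask_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=True)
# create_mask evaluates the mask_mod over a dense (B, H, Q_LEN, KV_LEN) grid of indices.
dense = create_mask(mask_fn, B=1, H=1, Q_LEN=8, KV_LEN=8, device="cpu")
print(dense[0, 0].int())  # 1 where attention is allowed, 0 where it is masked out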
attn_mods.py ADDED
@@ -0,0 +1,127 @@
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import _score_mod_signature
+from torch._inductor.lowering import make_pointwise, register_lowering
+
+# Some internal torch.compile details
+from torch._inductor.virtualized import ops
+from functools import partial
+
+
+@torch.library.custom_op("approx::tanh", mutates_args=())
+def _tanh_approx(inp: Tensor) -> Tensor:
+    return torch.tanh(inp)
+
+
+@_tanh_approx.register_fake
+def _(inp: torch.Tensor) -> torch.Tensor:
+    return torch.tanh(inp)
+
+
+def _tanh_approx_lowering(inp):
+    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
+    return make_pointwise(fn)(inp)
+
+
+register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)
+
+
+class _TanhApprox(torch.autograd.Function):
+    @staticmethod
+    def forward(x):
+        return torch.ops.approx.tanh(x)
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        (x,) = inputs
+        result = output
+        ctx.save_for_backward(result)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (result,) = ctx.saved_tensors
+        return grad_output * (1 - result * result)
+
+    @staticmethod
+    def vmap(info, in_dims, x):
+        return torch.tanh(x), 0
+
+
+_tanh_approx = _TanhApprox.apply
+
+
+def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
+    """Returns a tanh softcapping score_mod.
+
+    Args:
+        soft_cap: The soft cap value used to normalize the logits.
+        approx: Whether to use the `tanh.approx` PTX instruction.
+
+    Returns:
+        tanh_softcap: score_mod
+    """
+    tanh = _tanh_approx if approx else torch.tanh
+
+    def tanh_softcap(score, b, h, q_idx, kv_idx):
+        return soft_cap * tanh(score / soft_cap)
+
+    prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
+    tanh_softcap.__name__ = f"{prefix}_{soft_cap}"
+
+    return tanh_softcap
+
+
+def generate_alibi_bias(H: int) -> _score_mod_signature:
+    """Returns an ALiBi bias score_mod given the number of heads H.
+
+    Args:
+        H: number of heads
+
+    Returns:
+        alibi_bias: ALiBi bias score_mod
+    """
+
+    def alibi_mod(score, b, h, q_idx, kv_idx):
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (kv_idx - q_idx) * scale
+        return score + bias
+
+    return alibi_mod
+
+
+def generate_tanh_softcap_alibi(H: int, soft_cap: float, approx: bool = False) -> _score_mod_signature:
+    """Returns a combined ALiBi and tanh softcapping score_mod.
+
+    Args:
+        H (int): number of heads for ALiBi scaling
+        soft_cap (float): the soft cap value for normalizing/clipping the logits
+        approx (bool): Whether to use the 'tanh.approx' PTX-based approximation
+
+    Returns:
+        A combined score_mod function that first applies ALiBi,
+        then performs softcap + tanh (optionally approximate).
+    """
+    tanh_func = _tanh_approx if approx else torch.tanh
+
+    def alibi_tanh_softcap(score, b, h, q_idx, kv_idx):
+        # Compute ALiBi bias
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (kv_idx - q_idx) * scale
+        score = score + bias
+
+        # Apply softcap
+        score = score / soft_cap
+
+        # Apply tanh
+        score = tanh_func(score)
+
+        # Rescale by soft_cap
+        score = score * soft_cap
+        return score
+
+    # Give the score_mod a unique name:
+    if approx:
+        alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_approx_{soft_cap}"
+    else:
+        alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_{soft_cap}"
+
+    return alibi_tanh_softcap
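
Editor's note: a score_mod is just a callable (score, b, h, q_idx, kv_idx) -> score that flex_attention applies to every attention logit before the softmax. A minimal sketch of the softcapping idea wired directly into flex_attention, independent of this file's PTX-approximation helpers (assumes PyTorch 2.5+ and a CUDA device):

import torch
from torch.nn.attention.flex_attention import flex_attention

SOFT_CAP = 50.0  # matches the "softcap" value shipped in config.json

def tanh_softcap(score, b, h, q_idx, kv_idx):
    # Squash each logit into (-SOFT_CAP, SOFT_CAP) before the softmax.
    return SOFT_CAP * torch.tanh(score / SOFT_CAP)

B, H, S, D = 1, 8, 256, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, score_mod=tanh_softcap)  # (B, H, S, D)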
config.json CHANGED
@@ -2,17 +2,15 @@
   "model_type": "minitransformer",
   "_name_or_path": "Transformer_500M",
   "architectures": ["MiniTransformer"],
-  "n_embd": 768,
-  "n_heads": 12,
-  "n_layers": 27,
+  "dim": 768,
+  "num_heads": 24,
+  "num_layers": 27,
   "seq_len": 8192,
   "window_size": 8192,
   "vocab_size": 200064,
   "mlp_scale": 4,
   "bias": false,
   "dropout": 0.0,
-  "num_eigh": 24,
-  "use_hankel_L": false,
   "num_epochs": 1,
   "global_bsz": 524288,
   "bsz": 1,
@@ -27,7 +25,7 @@
   "ddp": true,
   "mixed_precision": true,
   "torch_dtype": "bfloat16",
-  "use_cpu_offload": false,
+  "cpu_offload": false,
   "sharding_strategy": "full_shard",
   "state_dict_type": "full",
   "auto_wrap_policy": "partial",
@@ -42,12 +40,10 @@
     "buffer": "bfloat16"
   },
   "fsdp_modules": [
-    "Attention"
+    "AttentionLayer"
   ],
   "use_activation_checkpointing": true,
-  "use_flash_fft": true,
-  "use_approx": true,
-  "use_attn": true,
   "softcap": 50.0,
-  "torch_compile": false
+  "theta": 10000.0,
+  "torch_compile": true
 }
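
Editor's note: the renamed fields keep the divisibility constraint the model asserts at construction time (dim % num_heads == 0). A small sanity check over the updated file (a sketch; the path is assumed to be the repo's config.json):

import json

with open("config.json") as f:
    cfg = json.load(f)

assert cfg["dim"] % cfg["num_heads"] == 0, "dim must be divisible by num_heads"
head_dim = cfg["dim"] // cfg["num_heads"]
print(cfg["num_layers"], head_dim)  # 27 layers, head_dim = 768 // 24 = 32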
configuration_minitransformer.py CHANGED
@@ -7,33 +7,38 @@ class MiniTransformerConfig(PretrainedConfig):
     def __init__(
         self,
         bsz: int = 1,
-        n_embd: int = 768,
-        n_heads: int = 12,
-        n_layers: int = 27,
+        dim: int = 1536,
+        num_heads: int = 8,
+        num_layers: int = 26,
         seq_len: int = 8192,
-        window_size: int = 8192,
+        window_size: int = 1024,
         vocab_size: int = 200064,
-        mlp_scale: int = 4,
+        mlp_scale: int = 12,
         bias: bool = False,
         dropout: float = 0.0,
         softcap: float = 50.0,
-        torch_dtype = torch.bfloat16,
-        device: str = None,
+        theta: float = 10_000.0,
+        use_alibi: bool = False,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.bsz = bsz
-        self.n_embd = n_embd
-        self.n_heads = n_heads
-        self.n_layers = n_layers
+        self.dim = dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
         self.seq_len = seq_len
         self.window_size = window_size
         self.vocab_size = vocab_size
-        self.hidden_size = n_embd
-        self.intermediate_size = n_embd * mlp_scale
-        self.hidden_act = "swish"
+        self.hidden_size = dim
+        self.mlp_scale = mlp_scale
+        self.intermediate_size = self.dim * self.mlp_scale
         self.bias = bias
         self.dropout = dropout
         self.softcap = softcap
+        self.theta = theta
+        self.use_alibi = use_alibi
         self.torch_dtype = torch_dtype
         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')  # Store as string
+
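
Editor's note: derived fields such as hidden_size and intermediate_size follow from dim and mlp_scale, so per-instance overrides propagate automatically. A sketch of instantiating the config with the values used in config.json (assuming torch and transformers are installed and the module is importable from the working directory):

from configuration_minitransformer import MiniTransformerConfig

config = MiniTransformerConfig(dim=768, num_heads=24, num_layers=27, mlp_scale=4)
print(config.hidden_size)        # 768, aliases dim
print(config.intermediate_size)  # 3072 = dim * mlp_scale
print(config.device)             # "cuda" if available, else "cpu"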
layers.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from .modules import STU
+from .attn import FlexAttention
 from .modules import MLP
 from .modules import Attention
 try:
@@ -23,80 +23,19 @@ except ImportError as e:
     from torch.nn import RMSNorm
     triton_norm = False
 
-
-class STULayer(nn.Module):
-    def __init__(self, config, phi, n):
-        super(STULayer, self).__init__()
-        if isinstance(config.torch_dtype, str):
-            torch_dtype = getattr(torch, config.torch_dtype)
-        else:
-            torch_dtype = config.torch_dtype
-        self.stu_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
-        )
-        self.stu = STU(config, phi, n)
-        self.stu = self.stu.to(dtype=torch_dtype)
-        self.mlp_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
-        )
-        self.mlp = (
-            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
-        )
-
-        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
-        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
-        self.mlp = self.mlp.to(dtype=torch_dtype)
-        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # Debug dtype
-
-        # Normalize and apply STU
-        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)  # Match dtype for STU
-        x_stu = self.stu(x_normed).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
-        x = x + x_stu
-
-        # Normalize and apply MLP
-        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)  # Match dtype for MLP
-        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
-        x = x + x_mlp
-
-        return x
-
 class AttentionLayer(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config, mask_mod, score_mod=None) -> None:
         super(AttentionLayer, self).__init__()
-        if isinstance(config.torch_dtype, str):
-            torch_dtype = getattr(torch, config.torch_dtype)
-        else:
-            torch_dtype = config.torch_dtype
-        self.attn_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
-        )
-        self.attn = Attention(config)
-        self.attn = self.attn.to(dtype=torch_dtype)
-        self.mlp_norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=torch_dtype)
+        self.attn_norm = nn.RMSNorm(config.dim)
+        self.attn = FlexAttention(
+            config=config,
+            mask_mod=mask_mod,
+            score_mod=score_mod,
         )
-        self.mlp = (
-            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
-        )
-        self.mlp = self.mlp.to(dtype=torch_dtype)
-
-        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
-        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
-        self.mlp = self.mlp.to(dtype=torch_dtype)
-        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)
+        self.mlp_norm = nn.RMSNorm(config.dim)
+        self.mlp = MLP(config)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.attn_norm(x))
+    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
+        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
         x = x + self.mlp(self.mlp_norm(x))
         return x
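
Editor's note: the rewritten AttentionLayer is the standard pre-norm residual block (normalize, transform, add back) built on plain nn.RMSNorm. A self-contained sketch of that pattern, assuming PyTorch 2.4+ for nn.RMSNorm; nn.MultiheadAttention and a small feed-forward stack stand in for the repo's FlexAttention and MLP:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn_norm = nn.RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp_norm = nn.RMSNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.attn_norm(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # residual around attention
        x = x + self.mlp(self.mlp_norm(x))                 # residual around the MLP
        return x

x = torch.randn(2, 16, 512)
print(PreNormBlock(512)(x).shape)  # torch.Size([2, 16, 512])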
modeling_minitransformer.py CHANGED
@@ -10,6 +10,10 @@ from .utils import nearest_power_of_two
 from .layers import AttentionLayer
 from .configuration_minitransformer import MiniTransformerConfig
 
+from .attn_masks import causal_mask
+from .attn_mods import generate_tanh_softcap
+from .rotary_emb import precompute_freqs_cis
+
 try:
     from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
     triton_norm = True
@@ -33,39 +37,31 @@ class MiniTransformer(PreTrainedModel):
 
     def __init__(self, config) -> None:
         super(MiniTransformer, self).__init__(config)
-        self.n_layers = config.n_layers
-        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
-
-        if isinstance(config.torch_dtype, torch.dtype):
-            torch_dtype = config.torch_dtype
-        else:
-            torch_dtype = getattr(torch, config.torch_dtype)
-
-        device = torch.device(config.device)
-
-        # TODO: Add support for Liger-Kernel Embedding once no longer experimental
-        self.tok_emb = nn.Embedding(
-            config.vocab_size, config.n_embd, dtype=config.torch_dtype
-        )
+        self.num_layers = config.num_layers
+        assert config.dim % config.num_heads == 0, f"dim ({config.dim}) must be divisible by num_heads ({config.num_heads})"
+        self.head_dim = config.dim // config.num_heads
+        logit_softcap = generate_tanh_softcap(soft_cap=config.softcap)
+
+        # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
+        self.register_buffer("freqs_cis", precompute_freqs_cis(
+            head_dim=self.head_dim,
+            max_seq_len=config.seq_len,
+            theta=config.theta,
+        ), persistent=True)
+
+        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
         self.dropout = nn.Dropout(config.dropout)
 
         self.layers = nn.ModuleList()
-        for _ in range(self.n_layers):
-            self.layers.append(AttentionLayer(config))
+        for _ in range(self.num_layers):
+            layer = AttentionLayer(config, mask_mod=causal_mask, score_mod=logit_softcap)
+            self.layers.append(layer)
 
-        self.norm = (
-            TritonNorm(config.n_embd)
-            if triton_norm
-            else RMSNorm(config.n_embd, dtype=config.torch_dtype)
-        )
-        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for RMS Norm
-        self.norm = self.norm.to(dtype=config.torch_dtype)
-        self.lm_head = nn.Linear(
-            config.n_embd, config.vocab_size, bias=config.bias, dtype=config.torch_dtype
-        )
-        self.tok_emb.weight = self.lm_head.weight
+        self.norm = nn.RMSNorm(config.dim)
+        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)
+        # self.tok_emb.weight = self.lm_head.weight
 
-        self.std = (config.n_embd) ** -0.5
+        self.std = (config.dim) ** -0.5
         self.apply(self._init_weights)
         print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
 
@@ -77,15 +73,13 @@
     ) -> CausalLMOutput:
         # Compute embeddings
         tok_emb = self.tok_emb(input_ids)
-        x = self.dropout(tok_emb)
 
-        # Pass through layers
         for layer in self.layers:
-            x = layer(x)
+            tok_emb = layer(tok_emb, self.freqs_cis)
 
         # Normalize and project to vocabulary
-        x = self.norm(x)
-        logits = self.lm_head(x)
+        tok_emb = self.norm(tok_emb)
+        logits = self.lm_head(tok_emb)
 
         loss = None
         if labels is not None:
@@ -107,26 +101,20 @@
         n_params = sum(p.numel() for p in self.parameters())
         if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
-        if self.tok_emb.weight is not self.lm_head.weight:
+        if self.tok_emb.weight is self.lm_head.weight:
            n_params -= self.tok_emb.weight.numel()
        return n_params
 
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
-                self.std *= (2 * self.n_layers) ** -0.5
+                self.std *= (2 * self.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
-        elif isinstance(module, Attention):
-            torch.nn.init.xavier_normal_(module.attn.weight)
-            torch.nn.init.xavier_normal_(module.o_proj.weight)
-            if module.attn.bias is not None:
-                torch.nn.init.zeros_(module.attn.bias)
-            if module.o_proj.bias is not None:
-                torch.nn.init.zeros_(module.o_proj.bias)
+
    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
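
Editor's note: the init path keeps the GPT-2-style depth-scaled initialization, with projections marked SCALE_INIT drawn from a smaller standard deviation. The numbers for the shipped config (dim=768, 27 layers), following the same formula as _init_weights:

dim, num_layers = 768, 27
std = dim ** -0.5                        # base init std ≈ 0.0361
scaled = std * (2 * num_layers) ** -0.5  # std for SCALE_INIT projections ≈ 0.0049
print(round(std, 4), round(scaled, 4))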