Zymrael committed on
Commit
27140ac
·
1 Parent(s): aef565b
Files changed (11) hide show
  1. cache.py +44 -0
  2. config.json +4 -4
  3. configuration_hyena.py +92 -0
  4. engine.py +389 -0
  5. layers.py +155 -0
  6. model.py +472 -0
  7. modeling_hyena.py +145 -0
  8. positional_embeddings.py +113 -0
  9. streamer.py +106 -0
  10. tokenizer.py +116 -0
  11. utils.py +96 -0
cache.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ from torch import Tensor
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+
9
+
10
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
11
@dataclass
class InferenceParams:
    """Bookkeeping handed to the main model at inference time so that the
    context can be computed and stored efficiently."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Install new size limits and rewind the sequence offset.

        `batch_size_offset` and `key_value_memory_dict` are left untouched;
        `lengths_per_sample` is zeroed in place when it is set.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0
        lengths = self.lengths_per_sample
        if lengths is None:
            return
        lengths.zero_()
29
+
30
+
31
@dataclass
class RecurrentInferenceParams:
    """State container handed to blocks that run in recurrent mode."""

    fir_filter_length: int = 3
    state_dim: int = 16
    seqlen_offset: int = 0
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        """Put the three scalar fields back to their default values.

        NOTE(review): `fir_state_dict` and `state_dict` are not cleared here;
        confirm that callers expect cached states to survive a reset.
        """
        self.fir_filter_length, self.state_dim, self.seqlen_offset = 3, 16, 0
config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
- "_commit_hash": "1cc23830f62c268082475776fb449af8428eb703",
3
- "_name_or_path": "LongSafari/Evo-1",
4
  "architectures": [
5
  "StripedHyenaModelForCausalLM"
6
  ],
@@ -10,8 +10,8 @@
10
  24
11
  ],
12
  "auto_map": {
13
- "AutoConfig": "LongSafari/Evo-1--configuration_hyena.StripedHyenaConfig",
14
- "AutoModelForCausalLM": "LongSafari/Evo-1--modeling_hyena.StripedHyenaModelForCausalLM"
15
  },
16
  "column_split": false,
17
  "column_split_hyena": true,
 
1
  {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "togethercomputer/evo-1-phase-2",
4
  "architectures": [
5
  "StripedHyenaModelForCausalLM"
6
  ],
 
10
  24
11
  ],
12
  "auto_map": {
13
+ "AutoConfig": "configuration_hyena.StripedHyenaConfig",
14
+ "AutoModelForCausalLM": "modeling_hyena.StripedHyenaModelForCausalLM"
15
  },
16
  "column_split": false,
17
  "column_split_hyena": true,
configuration_hyena.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
class StripedHyenaConfig(PretrainedConfig):
    """Hugging Face configuration object for StripedHyena models.

    Every constructor argument is stored verbatim as an attribute of the same
    name; unrecognised keyword arguments are forwarded to `PretrainedConfig`.
    """

    model_type = "stripedhyena"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_filters=4096,
        inner_mlp_size=14336,
        attn_layer_idxs=None,
        hyena_layer_idxs=None,
        num_layers=32,
        tie_embeddings=False,
        short_filter_length=3,
        num_attention_heads=32,
        proj_groups=4,
        hyena_filter_groups=1,
        split_k0=True,
        column_split_hyena=True,
        column_split=False,
        model_parallel_size=1,
        pipe_parallel_size=1,
        short_filter_bias=True,
        mha_out_proj_bias=False,
        qkv_proj_bias=False,
        final_norm=True,
        use_cache=True,
        use_flash_attention_2=True,
        use_flash_rmsnorm=True,
        use_flash_depthwise=False,
        use_flashfft=False,
        inference_mode=False,
        prefill_style="fft",
        max_seqlen=32768,
        eps=1e-5,
        state_size=2,
        rotary_emb_base=500000,
        smeared_gqa=False,
        make_vocab_size_divisible_by=8,
        log_intermediate_values=False,
        **kwargs,
    ):
        """Record the model hyper-parameters.

        `attn_layer_idxs` and `hyena_layer_idxs` default to a fresh empty list
        per instance (passing `None` is equivalent to omitting them), so two
        configs never share the same list object.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.inner_mlp_size = inner_mlp_size
        # A literal `[]` default would be one list shared by every instance.
        self.attn_layer_idxs = [] if attn_layer_idxs is None else attn_layer_idxs
        self.hyena_layer_idxs = [] if hyena_layer_idxs is None else hyena_layer_idxs
        self.num_layers = num_layers
        self.tie_embeddings = tie_embeddings
        self.short_filter_length = short_filter_length
        self.num_attention_heads = num_attention_heads
        self.proj_groups = proj_groups
        self.hyena_filter_groups = hyena_filter_groups
        self.split_k0 = split_k0
        self.column_split_hyena = column_split_hyena
        self.column_split = column_split
        self.model_parallel_size = model_parallel_size
        self.pipe_parallel_size = pipe_parallel_size
        self.short_filter_bias = short_filter_bias
        self.mha_out_proj_bias = mha_out_proj_bias
        self.qkv_proj_bias = qkv_proj_bias
        self.final_norm = final_norm
        self.use_cache = use_cache
        self.use_flash_attention_2 = use_flash_attention_2
        self.use_flash_rmsnorm = use_flash_rmsnorm
        self.use_flash_depthwise = use_flash_depthwise
        self.use_flashfft = use_flashfft
        self.inference_mode = inference_mode
        self.prefill_style = prefill_style
        self.max_seqlen = max_seqlen
        self.eps = eps
        self.state_size = state_size
        self.rotary_emb_base = rotary_emb_base
        self.smeared_gqa = smeared_gqa
        self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
        self.log_intermediate_values = log_intermediate_values
        super().__init__(**kwargs)

    def to_dict(self):
        """Return a shallow dict of every instance attribute (values are not copied)."""
        return {attr: getattr(self, attr) for attr in self.__dict__}

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Build a config from a JSON file, with `kwargs` supplying extra fields.

        A key present both in the file and in `kwargs` raises `TypeError`
        (duplicate keyword argument), as before.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        return cls(**config, **kwargs)
engine.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ import gc
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ try:
12
+ import conv1d_cpp
13
+ except:
14
+ pass
15
+ from .utils import column_split
16
+
17
# Names accepted for `iir_prefill_style` by `HyenaInferenceEngine.__init__`,
# which asserts membership in this list.
IIR_PREFILL_MODES = [
    "recurrence",
    "modal-fft",
    "hybrid-modal-recurrence",
    "modal-scan",
    "canonical-fft",
    "iir-fir-caching",
]
25
+
26
+
27
def canonicalize_modal_system(poles, residues):
    """Placeholder for converting a modal system to canonical form.

    Args:
        poles (Tensor): Poles of the system.
        residues (Tensor): Residues of the system.

    Returns:
        Tuple[Tensor, Tensor]: Canonicalized poles and residues (once implemented).

    Raises:
        NotImplementedError: Always; the conversion has not been written yet.
    """
    raise NotImplementedError()
38
+
39
+
40
def list_tensors(idx):
    """Debug helper: print and log every tensor tracked by the garbage collector.

    For each live tensor the type and size are printed, and a line with the
    type, size and first element is appended to `tensors_{idx}.txt` in the
    current working directory. 0-dim and empty tensors, which cannot be
    indexed with `[0]`, are logged with the tensor itself in place of the
    first element.

    Args:
        idx: Tag interpolated into the log file name.
    """
    log_path = f"tensors_{idx}.txt"
    for obj in gc.get_objects():
        # Best effort: arbitrary gc-tracked objects may raise on inspection and
        # some tensors cannot be formatted; skip those rather than abort the dump.
        try:
            if not torch.is_tensor(obj):
                continue
            # dump to log
            print(type(obj), obj.size())
            indexable = obj.dim() > 0 and obj.shape[0] > 0
            el = obj[0] if indexable else obj
            with open(log_path, "a") as f:
                f.write(f"{type(obj)} {obj.size()} {el}\n")
        except Exception:
            continue
51
+
52
+
53
class HyenaInferenceEngine:
    """Prefill and single-step routines for the Hyena short (FIR) and long (IIR) filters."""

    def __init__(
        self,
        fir_fn=None,
        iir_prefill_style="modal-fft",
        layer_idx=None,
    ) -> None:
        """Store the FIR kernel, the IIR prefill style and the owning layer index.

        Raises:
            AssertionError: If `iir_prefill_style` is not in `IIR_PREFILL_MODES`.
        """
        self.fir_fn = fir_fn
        assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
        self.iir_prefill_style = iir_prefill_style
        self.layer_idx = layer_idx
        self.low_mem_mode = False

    def parallel_fir(
        self,
        fir_fn,
        u,
        weight,
        bias,
        L,
        fir_length=3,
        inference_params=None,
        prefill_mode=None,
        padding_mask=None,
    ):
        """Compute the output of the short convolutional (FIR) filter over a whole sequence.

        Args:
            fir_fn: Either `torch.nn.functional.conv1d` or a custom kernel called as `fir_fn(u)`.
            u: Input sequence; the inline comments assume a (B, L, D) layout.
            weight: Depthwise filter weights (only used on the `conv1d` path).
            bias: Per-channel bias added after the `conv1d` call, or None for no bias.
            L: Number of output positions to keep.
            fir_length: Length of the short filter.
            inference_params: When not None, the trailing `fir_length - 1` inputs
                are also returned as the FIR state.
            prefill_mode: Unused.
            padding_mask: Optional tensor multiplied into the filter output.

        Returns:
            Tuple `(z_pre, fir_state)`; `fir_state` is None when `inference_params` is None.
        """
        # prepare input layout, dimensions and dispatch to fir kernel
        custom_kernel = fir_fn != torch.nn.functional.conv1d
        if custom_kernel:
            z_pre = fir_fn(u)[:, :L]  # B, L, D
            z_pre = z_pre.permute(0, 2, 1)
        else:
            u = u.permute(0, 2, 1)  # B, D, L
            z_pre = fir_fn(
                u,
                weight,
                bias=None,  # don't pass it here, add manually instead! source of small error
                stride=1,
                padding=fir_length - 1,
                groups=u.shape[1],
            )[..., :L]

            # add manually instead! source of small error
            if bias is not None:
                z_pre = z_pre + bias[None, :, None]

        # handle padding post fir, the only place with biases
        if isinstance(padding_mask, torch.Tensor):
            z_pre = z_pre * padding_mask[:, None]

        if inference_params is None:
            fir_state = None
        elif custom_kernel:
            # `u` is still sequence-major here: slice the sequence axis, then move it last
            fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
        else:
            # `u` was permuted above, so the sequence axis is already last
            fir_state = u[..., -fir_length + 1 :]

        return z_pre, fir_state

    def parallel_iir(
        self,
        z_pre,
        h,
        D,
        L,
        poles,
        residues,
        t,
        dims,
        layer_idx,
        inference_params=None,
        prefill_style="fft",
        fftconv_fn=None,
        padding_mask=None,
        use_flashfft=False,
        column_split_hyena=False,
        long_fir_threshold=None,
    ):
        """Compute the output of the long convolutional (IIR) filter over a whole sequence.

        `z_pre` is split into `x2`, `x1` and `v`; the long filter `h` is applied to
        `x1 * v`, the skip term `D * x1 * v` is added and the result is gated by `x2`.
        When `inference_params` is given, the recurrent state is also stored in it
        according to `prefill_style` ("fft" or "recurrence").

        Args:
            z_pre: Output of `parallel_fir`.
            h: Long filter used by the FFT / conv paths.
            D: Per-channel skip coefficient.
            L: Sequence length.
            poles, residues: Modal parameters of the filter (real/imag in the last axis).
            t: Time index tensor used by the modal-FFT prefill.
            dims: `(hidden_size, num_attention_heads, hidden_size_per_attention_head,
                state_size, hyena_filter_groups)`.
            layer_idx: Key under which the FFT prefill stores the state.
            inference_params: Optional state container with a `state_dict` mapping.
            prefill_style: "fft" or "recurrence".
            fftconv_fn: Fused FFT convolution, required when `use_flashfft` is True.
            padding_mask: Unused here.
            use_flashfft: Use `fftconv_fn` (only taken when `L` is even).
            column_split_hyena: Split `z_pre` per attention head instead of in thirds.
            long_fir_threshold: When not None, truncate `h` and use a direct conv1d.

        Returns:
            The gated filter output with its last two axes swapped.

        Raises:
            NotImplementedError: If `inference_params` is given with an unknown `prefill_style`.
        """
        fft_size = 2 * L
        hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
        # Compatibility with training infra that column splits the projections
        if column_split_hyena:
            z = z_pre.reshape(
                z_pre.shape[0],
                num_attention_heads,
                3 * hidden_size_per_attention_head,
                z_pre.shape[2],
            )
            x2, x1, v = (
                z[:, :, :hidden_size_per_attention_head],
                z[
                    :,
                    :,
                    hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
                ],
                z[:, :, 2 * hidden_size_per_attention_head :],
            )
            x2, x1, v = (
                x2.reshape(x2.shape[0], -1, x2.shape[-1]),
                x1.reshape(x1.shape[0], -1, x1.shape[-1]),
                v.reshape(v.shape[0], -1, v.shape[-1]),
            )
        else:
            x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)

        x1v = x1 * v

        # Only the plain FFT path produces the input spectrum; keep the name bound
        # on every path so the prefill below fails with a clear assertion instead
        # of an unbound-variable error.
        X_s = None

        if inference_params is not None and prefill_style == "recurrence":
            y = self.prefill_via_direct_recurrence(
                inference_params=inference_params,
                x1v=x1v,
                L=L,
                poles=poles,
                residues=residues,
            )

        else:
            if use_flashfft and (L % 2) == 0:  # only works with even L
                y = fftconv_fn(
                    x1v.to(dtype=torch.bfloat16).contiguous(),
                    h.to(dtype=torch.float32),
                )

            elif long_fir_threshold is None:
                H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
                X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
                X = X_s[..., : H.shape[-1]]
                if len(z_pre.shape) > 3:
                    H = H.unsqueeze(1)
                y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]

            else:
                assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
                h = h[0][:, None]  # rearrange to d, 1, l for depthwise conv1d
                h = h[..., :long_fir_threshold]
                y = F.conv1d(
                    x1v,
                    h.to(dtype=x1v.dtype),
                    stride=1,
                    groups=x1v.shape[1],
                    padding=h.shape[-1] - 1,
                )[..., :L]

        y = y.to(dtype=x1v.dtype)
        y = (y + x1v * D.unsqueeze(-1)) * x2

        if inference_params is not None:
            if prefill_style == "fft":
                self.prefill_via_modal_fft(
                    inference_params=inference_params,
                    x1v=x1v,
                    X_s=X_s,
                    L=L,
                    t=t,
                    poles=poles,
                    dims=dims,
                    layer_idx=layer_idx,
                    use_flashfft=use_flashfft,
                    fftconv_fn=fftconv_fn,
                )

            elif prefill_style == "recurrence":
                # recurrent prefill is done before
                pass
            else:
                raise NotImplementedError
            if self.low_mem_mode:
                # TODO: smarter gc
                del z_pre, x2, x1, v, x1v, h, poles, residues
                torch.cuda.empty_cache()

        return y.permute(0, 2, 1)

    def step_fir(self, u, fir_state, weight, bias=None):
        """Step the FIR filter by one token.

        Note:
        `fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_{L-2}, u_{L-1}, ...`
        We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).

        Args:
            u: Current input.
            fir_state: Cached previous inputs; rolled along axis 2 and updated.
            weight: Short filter weights.
            bias: Optional bias added to the output; None means no bias.

        Returns:
            Tuple `(y, fir_state)` with the filter output and the updated state.
        """
        h0, h = weight[..., 0, -1], weight[..., 0, :-1]
        h0, h = h0[None], h[None]
        y = h0 * u + torch.sum(fir_state * h, dim=-1)
        if bias is not None:
            y = y + bias

        # update
        fir_state = torch.roll(fir_state, -1, dims=2)
        fir_state[..., -1] = u
        return y, fir_state

    def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
        """Step the IIR filter by one token in modal form.

        Returns:
            Tuple `(y, iir_state)` with the gated output and the updated complex state.

        Raises:
            NotImplementedError: If `iir_groups > 1`.
        """
        if iir_groups > 1:
            raise NotImplementedError

        x1v = x1 * v

        residues, poles = (
            torch.view_as_complex(residues.to(torch.float32)),
            torch.view_as_complex(poles.to(torch.float32)),
        )
        # squeeze the dummy seqlen dimension
        # D, state_dim, 1 -> 1, D, state_dim
        residues, poles = residues[..., 0][None], poles[..., 0][None]
        iir_state = poles * iir_state + x1v[..., None]

        res_state = torch.sum(residues * iir_state, dim=-1).real

        y = x2 * (res_state + D * x1v)

        return y, iir_state

    def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
        """Turns the IIR filter into a FIR and uses a cache for decoding."""
        raise NotImplementedError(":)")

    def prefill_via_direct_recurrence(
        self, inference_params, x1v, L, residues, poles, *args, **kwargs
    ) -> torch.Tensor:
        """
        Compute the IIR state via explicit SSM recurrence (modal form)

        For every position the complex state is advanced as
        `state = poles * state + x1v[..., i]` and the output is
        `Re(sum_over_state_dim(residues * state))`, i.e. the same update that
        `step_iir` applies during decoding. The final state is stored in
        `inference_params.state_dict[self.layer_idx]` as a complex64 tensor.

        Note:
            The recurrence is carried out in float32 / complex64; the returned
            output has the dtype and shape of `x1v`.

        Args:
            inference_params: Container whose `state_dict` receives the final state.
            x1v: Filter input, indexed as (b, d, l).
            L: Number of positions to process.
            residues, poles: Modal parameters with real/imag in the last axis and a
                dummy sequence axis before it.

        Returns:
            Filter output for the first `L` positions.
        """
        # suppress dummy seqlen dimension and switch to true complex arithmetic so
        # that real and imaginary parts are updated from the same previous state
        poles_c = torch.view_as_complex(poles[:, :, 0].to(torch.float32).contiguous())[None]  # 1, d, sdim
        residues_c = torch.view_as_complex(residues[:, :, 0].to(torch.float32).contiguous())[None]  # 1, d, sdim

        x1v_f = x1v.to(torch.float32)
        state = torch.zeros(
            x1v.shape[0],
            x1v.shape[1],
            poles_c.shape[-1],
            dtype=torch.complex64,
            device=x1v.device,
        )  # b, d, sdim
        output = torch.zeros_like(x1v)  # b, d, l

        for i in range(L):
            state = poles_c * state + x1v_f[:, :, i, None]
            output[:, :, i] = torch.sum(residues_c * state, dim=-1).real

        inference_params.state_dict[self.layer_idx] = state

        return output

    def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
        """
        Compute the IIR state via hybrid recurrence-convolution over blocks
        """
        raise NotImplementedError(":)")

    def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
        """Not implemented."""
        raise NotImplementedError

    def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
        """
        Compute the IIR state via a single FFT with the denominator of the SSM in companion form.

        This is the most memory efficient "parallelized" prefilling method for Hyena.

        From: https://arxiv.org/abs/2310.18780
        """
        raise NotImplementedError(":)")

    def prefill_via_modal_fft(
        self,
        inference_params,
        x1v,
        L,
        poles,
        t,
        dims,
        layer_idx,
        X_s=None,
        use_flashfft=False,
        fftconv_fn=None,
        state_dtype=torch.complex64,
        *args,
        **kwargs,
    ):
        """
        Compute the IIR state via a single FFT, using the poles of the SSM in modal form.

        The state at position `L - 1` is written to `inference_params.state_dict`.
        Without flashfft the precomputed input spectrum `X_s` is required.

        Raises:
            AssertionError: If `use_flashfft` is False and `X_s` is None.
            NotImplementedError: If `use_flashfft` is True and `hyena_filter_groups > 1`.
        """
        # When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
        # we split the filter into poles and residues and reuse FFT computation on the input.
        # This optimization is currently not supported when using flashfftconv.
        hidden_size, _, _, state_size, hyena_filter_groups = dims

        if use_flashfft:
            # using real states
            poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]

            state_s = poles**t
            if hyena_filter_groups > 1:
                raise NotImplementedError

            x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
            x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
            state_s = state_s[None]

            state = fftconv_fn(
                x1v.contiguous(),
                state_s.to(dtype=torch.float32),
            )
            state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
            state = torch.view_as_complex(state.contiguous().to(dtype=torch.float32))
            inference_params.state_dict[self.layer_idx] = state
        else:
            assert X_s is not None
            bs = x1v.shape[0]
            fft_size = 2 * L
            poles = torch.view_as_complex(poles.to(torch.float32))
            state_s = poles**t
            state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1)  # B, D, state_dim, 2 * L
            if hyena_filter_groups > 1:
                state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
            state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
            inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)

    def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
        """
        Compute the IIR state given an input `u` and log_poles of the modal system.

        Both spectra are full (two-sided) FFTs of length `2 * L` so that they can
        be multiplied element-wise; the result is truncated to the first `L` steps.
        """
        bs = u.shape[0]
        fft_size = 2 * L
        # A one-sided rfft has L + 1 bins and cannot be multiplied with the
        # 2 * L-bin spectrum of the (complex) modal response below.
        U = torch.fft.fft(u.to(torch.float32), n=fft_size)
        x = (log_poles * t).exp()
        # [batch, hidden_size, state_dim, 2 * seqlen]
        X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
        state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
        return state
layers.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from .utils import grab_first_if_tuple
10
+
11
def grab_first_if_tuple(x):
    """Return `x[0]` when the class of `x` is named "tuple", otherwise `x` itself."""
    is_plain_tuple = x.__class__.__name__ == "tuple"
    return x[0] if is_plain_tuple else x
16
+
17
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    `config` must expose `eps` and `hidden_size` as attributes and a dict-style
    `get`; `use_flash_rmsnorm` selects the fused flash-attn kernel.
    """

    def __init__(self, config):
        """Create the scale parameter and, if requested, bind the fused kernel.

        Raises:
            ImportError: If `use_flash_rmsnorm` is set but the flash-attn
                rms_norm op cannot be imported.
        """
        super().__init__()
        self.eps, self.hidden_size = config.eps, config.hidden_size
        # Assigning an nn.Parameter attribute already registers it on the module.
        self.scale = torch.nn.Parameter(torch.ones(self.hidden_size))
        self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)

        if self.use_flash_rmsnorm:
            try:
                from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
            except Exception as err:
                # Keep the install hint, but chain the real cause and let
                # KeyboardInterrupt / SystemExit propagate untouched.
                raise ImportError(
                    "For `use_flash_rmsnorm`: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm`"
                ) from err

            self.rmsnorm_func = rmsnorm_func

    def forward(self, x):
        """Normalise `x` over its last axis and apply the learned scale."""
        if self.use_flash_rmsnorm:
            return self.rmsnorm_func(x, self.scale, self.eps)

        # ||x||_2 / sqrt(hidden_size) is the RMS of the last axis
        rms = x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2)
        y = x / (rms + self.eps)
        return self.scale * y
41
+
42
+
43
class ParallelGatedMLP(nn.Module):
    """Gated MLP: `l3(act(l1(z)) * l2(z))` built from three bias-free linear layers."""

    def __init__(
        self,
        config,
    ):
        """Choose the activation and inner width from `config`, then build the layers.

        The inner width is `8/3 * hidden_size` rounded up to a multiple of
        `inner_size_multiple_of * model_parallel_size`, unless `inner_mlp_size`
        is set in the config, in which case that value is used as is.

        Raises:
            NotImplementedError: If `mlp_activation` is neither "gelu" nor "silu".
        """
        super().__init__()

        base_multiple = config.get("inner_size_multiple_of", 64)
        self.act_type = config.get("mlp_activation", "silu")
        activations = {"gelu": F.gelu, "silu": F.silu}
        if self.act_type not in activations:
            raise NotImplementedError
        self.act = activations[self.act_type]

        self.multiple_of = base_multiple * config.model_parallel_size

        width = int(2 * config.hidden_size * 4 / 3)
        # round up to the next multiple of `self.multiple_of`
        n_chunks = (width + self.multiple_of - 1) // self.multiple_of
        width = self.multiple_of * n_chunks
        if config.get("inner_mlp_size", None) is not None:
            width = config.inner_mlp_size

        hidden = config.hidden_size
        self.l1 = nn.Linear(in_features=hidden, out_features=width, bias=False)
        self.l2 = nn.Linear(in_features=hidden, out_features=width, bias=False)
        self.l3 = nn.Linear(in_features=width, out_features=hidden, bias=False)

    def forward(self, z):
        """Apply the gated MLP; tuple-valued layer outputs are reduced to their first item."""
        gate = grab_first_if_tuple(self.l1(z))
        value = grab_first_if_tuple(self.l2(z))
        y = self.l3(self.act(gate) * value)
        return grab_first_if_tuple(y)
87
+
88
+
89
class Embedding(nn.Module):
    """Token embedding table that can also project hidden states back to vocabulary logits."""

    _train_dtype = "bf16"

    def __init__(self, config):
        """Create a `(vocab_size, hidden_size)` embedding table with padding index 0."""
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)

    def embed(self, input_ids, position_ids=None, tokentype_ids=None):
        """Look up the embeddings of `input_ids`; the other arguments are ignored."""
        embeddings = self.word_embeddings(input_ids)
        return embeddings

    def unembed(self, u):
        """Project hidden states `(..., hidden_size)` to logits `(..., vocab_size)`.

        The embedding weight is stored as `(vocab_size, hidden_size)`, so it has
        to be transposed for the product, as `VocabParallelEmbedding.unembed` does.
        """
        weight = self.word_embeddings.weight
        return torch.matmul(u, weight.T)
103
+
104
+
105
class VocabParallelEmbedding(nn.Embedding):
    "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"

    def __init__(self, config):
        """Build this rank's slice of the embedding table.

        With a `process_group` in the config the vocabulary is divided evenly
        across the group's ranks; without one the full table is created.

        Raises:
            ValueError: If `vocab_size` is not divisible by the world size.
            RuntimeError: If `padding_idx` is set while the world size exceeds 1.
        """
        vocab_size = config.vocab_size
        process_group = config.get("process_group", None)
        padding_idx = config.get("padding_idx", None)
        self.process_group = process_group

        world_size = 1
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if vocab_size % world_size != 0:
                raise ValueError(
                    f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")

        super().__init__(
            vocab_size // world_size,
            embedding_dim=config.hidden_size,
            padding_idx=padding_idx,
        )

    def embed(self, x: Tensor) -> Tensor:
        """Embed token ids, all-reducing the partial lookups when a process group is set."""
        group = self.process_group
        if group is None:
            return self.forward(x)

        rank = torch.distributed.get_rank(group)
        shard_size = self.num_embeddings
        first_id = rank * shard_size
        last_id = (rank + 1) * shard_size
        # True marks ids that fall outside this rank's slice and must be masked.
        outside = (x < first_id) | (x >= last_id)
        x = x - first_id
        x[outside] = 0
        embeddings = self.forward(x)
        embeddings[outside] = 0.0
        # Reduce to the global process group
        torch.distributed.all_reduce(embeddings, group=group)
        return embeddings

    def unembed(self, u: Tensor) -> Tensor:
        """Project hidden states to logits with the transposed embedding weight.

        Raises:
            NotImplementedError: If a process group is set.
        """
        if self.process_group is not None:
            raise NotImplementedError
        return u @ self.weight.T
model.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+ # Note: MP and PP utilities are removed for ease of use and editing.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .cache import InferenceParams, RecurrentInferenceParams
11
+ from .engine import HyenaInferenceEngine
12
+ from .layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
13
+ from .utils import column_split, print_rank_0
14
+
15
+ try:
16
+ from flash_attn.modules.mha import MHA
17
+ except ImportError:
18
+ "flash_attn not installed"
19
+
20
+ try:
21
+ from .positional_embeddings import swap_mha_rope
22
+ except ImportError:
23
+ "could not import swap_mha_rope from positional_embeddings.py"
24
+
25
+ # dummy import to force huggingface to bundle the tokenizer
26
+ from .tokenizer import ByteTokenizer
27
+
28
+
29
class AttentionBlock(nn.Module):
    """Pre-norm attention block: flash-attn MHA plus a gated MLP, each with a residual connection."""

    def __init__(self, config, layer_idx) -> None:
        """Build the norms, the MHA module and the MLP from `config`.

        `config` is read both through attributes and through a dict-style `get`.
        """
        super().__init__()
        self.config = config
        self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
        self.layer_idx = layer_idx
        self.proj_groups = config.get("proj_groups", 1)
        dtype = config.get("attn_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads

        self.counter = 0
        # Key/value heads are the query heads divided by `proj_groups`.
        self.inner_mha_cls = MHA(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_attention_heads // self.proj_groups,
            rotary_emb_dim=config.hidden_size // config.num_attention_heads,
            qkv_proj_bias=config.get("qkv_proj_bias", True),
            rotary_emb_base=config.get("rotary_emb_base", 10000),
            causal=True,
            layer_idx=layer_idx,
            out_proj_bias=config.get("mha_out_proj_bias", True),
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)

        # check if using interpolated rotary pos emb from config, and swap the rope emb
        if config.get("use_interpolated_rotary_pos_emb", False):
            swap_mha_rope(
                mha=self.inner_mha_cls,
                kwargs_new_rope={'scaling_factor': config.get("rotary_emb_scaling_factor", 1.)},
            )

        if self.config.get("smeared_gqa", False):
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        # NOTE(review): presumably re-registered so `inv_freq` becomes a persistent
        # buffer; also confirm this is meant to run unconditionally rather than
        # only under `smeared_gqa`.
        self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)

        self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        """Apply attention and the MLP with residual connections.

        When `padding_mask` is a tensor it is multiplied into `u` before and
        after the attention step.

        Returns:
            Tuple `(u, None)`.
        """
        if (
            type(padding_mask) == torch.Tensor
        ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
            # and attention scores will be also automatically zeroed.
            u = u * padding_mask[..., None]
        u = (
            self.inner_mha_cls(
                self.pre_norm(u),
                inference_params=inference_params,
            )
            + u
        )
        if type(padding_mask) == torch.Tensor:  # guard against bias
            u = u * padding_mask[..., None]
        u = self.mlp(self.post_norm(u)) + u
        return u, None
85
+
86
+
87
+ class ParallelHyenaFilter(nn.Module):
88
+ def __init__(self, config, layer_idx) -> None:
89
+ super().__init__()
90
+ self.config = config
91
+ self.layer_idx = layer_idx
92
+ self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)
93
+
94
+ self.use_flashfft = config.get("use_flashfft", False)
95
+ self.state_size = config.state_size
96
+ self.hidden_size = config.hidden_size
97
+ self.num_filters = config.num_filters
98
+ self.inference_mode = config.get("inference_mode", True)
99
+ self.counter = 0
100
+ self.column_split_hyena = config.get("column_split_hyena", True)
101
+
102
+ assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
103
+
104
+ self.D = nn.Parameter(torch.zeros(self.hidden_size))
105
+
106
+ # attention heads are not used except to split post short_filter
107
+ # projections in the same way as the checkpoint
108
+ self.num_attention_heads = config.num_attention_heads
109
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
110
+
111
+ # after preprocessing here we can save the new checkpoint
112
+ self.short_filter_length = config.short_filter_length
113
+ self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
114
+ self.short_filter_bias = (
115
+ nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
116
+ )
117
+
118
+ self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
119
+ self.use_flash_depthwise = config.get("use_flash_depthwise", False)
120
+ self.data_dtype = None
121
+
122
+ if self.use_flash_depthwise:
123
+ self.fir_fn = FlashDepthwiseConv1d(
124
+ channels=3 * self.hidden_size,
125
+ kernel_size=self.short_filter_length,
126
+ padding=self.short_filter_length - 1,
127
+ weights=self.short_filter_weight,
128
+ bias=self.short_filter_bias,
129
+ device=None,
130
+ dtype=self.config.get("depthwise_dtype", torch.bfloat16),
131
+ )
132
+ else:
133
+ self.fir_fn = F.conv1d
134
+
135
+ self.fftconv_fn = None
136
+ self.long_fir_threshold = config.get("long_fir_threshold", None)
137
+ if self.long_fir_threshold is not None:
138
+ assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"
139
+
140
+ self.num_systems = self.hidden_size // self.hyena_filter_groups
141
+
142
+ poles = torch.randn(self.num_systems, self.state_size, 1, 2)
143
+
144
+ # TODO: bring over init from internals
145
+ poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
146
+ poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)
147
+
148
+ self.poles = nn.Parameter(poles)
149
+
150
+ self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
151
+ self.h = None
152
+
153
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
154
+ if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
155
+ return self.sequential_forward(u, inference_params)
156
+
157
+ else:
158
+ return self.parallel_forward(u, inference_params, padding_mask)
159
+
160
def parallel_forward(self, u, inference_params=None, padding_mask=None):
    """Process a whole sequence: short FIR filter, then the long (IIR) filter.

    Both stages are delegated to ``self.engine``. When ``inference_params``
    is truthy, the FIR state returned by the engine is stored under this
    layer's index so later calls can take the single-step path.

    Args:
        u: input whose second dimension is used as the sequence length.
        inference_params: optional cache object exposing ``fir_state_dict``.
        padding_mask: forwarded untouched to both engine calls.

    Returns:
        ``(y, inference_params)`` where ``y`` is the engine's IIR output.
    """
    seq_len = u.shape[1]

    z_pre, fir_state = self.engine.parallel_fir(
        self.fir_fn,
        u,
        self.short_filter_weight,
        self.short_filter_bias,
        seq_len,
        fir_length=self.short_filter_length,
        inference_params=inference_params,
        padding_mask=padding_mask,
    )
    if inference_params:
        inference_params.fir_state_dict[self.layer_idx] = fir_state

    # Use the precomputed filter when one is cached, otherwise build it now;
    # only the first element of compute_filter's result is needed here.
    if self.h is not None:
        h = self.h
    else:
        h = self.compute_filter(seq_len, u.device)[0]

    if self.hyena_filter_groups > 1:
        h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)

    # If inference_params is not None we plan to perform generation:
    # prefilling is handled by the engine.
    dims = (
        self.hidden_size,
        self.num_attention_heads,
        self.hidden_size_per_attention_head,
        self.state_size,
        self.hyena_filter_groups,
    )
    iir_options = dict(
        t=self.t,
        poles=self.poles,
        residues=self.residues,
        dims=dims,
        inference_params=inference_params,
        layer_idx=self.layer_idx,
        prefill_style=self.config.get("prefill_style", "fft"),
        use_flashfft=self.use_flashfft,
        fftconv_fn=self.fftconv_fn,
        column_split_hyena=self.column_split_hyena,
        long_fir_threshold=self.long_fir_threshold,
        padding_mask=padding_mask,
    )
    y = self.engine.parallel_iir(z_pre, h, self.D, seq_len, **iir_options)

    return y, inference_params
213
+
214
def sequential_forward(self, u, inference_params):
    """Advance the filter by one step using the cached FIR and IIR states.

    The dtype of the first input seen is remembered in ``self.data_dtype``
    and the output is cast back to it. If ``u`` has more than two
    dimensions only its last position along dim 1 is used.

    Args:
        u: current input.
        inference_params: cache exposing ``fir_state_dict`` and
            ``state_dict``, both indexed by layer; updated in place.

    Returns:
        ``(y, inference_params)`` with a singleton dimension inserted at
        position 1 of ``y``.
    """
    if self.data_dtype is None:
        self.data_dtype = u.dtype
    if len(u.shape) > 2:
        u = u[:, -1]

    layer = self.layer_idx
    fir_state = inference_params.fir_state_dict[layer]
    iir_state = inference_params.state_dict[layer]

    z_pre, fir_state = self.engine.step_fir(
        u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
    )

    # Split the FIR output into the two gates and the value stream.
    if self.column_split_hyena:
        x2, x1, v = column_split(
            z_pre, self.num_attention_heads, self.hidden_size_per_attention_head
        )
    else:
        x2, x1, v = z_pre.split([self.hidden_size] * 3, dim=1)

    y, iir_state = self.engine.step_iir(
        x2,
        x1,
        v,
        self.D,
        self.residues,
        self.poles,
        iir_state,
        iir_groups=self.hyena_filter_groups,
    )

    inference_params.fir_state_dict[layer] = fir_state
    inference_params.state_dict[layer] = iir_state
    return y.to(dtype=self.data_dtype)[:, None], inference_params
249
+
250
def update_time(self, L, device):
    """Make ``self.t`` hold the time steps ``0 .. L-1`` with two leading unit dims.

    The vector is rebuilt on ``device`` when none is cached yet or when the
    cached one is shorter than ``L``; otherwise the cached vector is simply
    truncated to its first ``L`` entries (the cache itself shrinks too).
    """
    needs_rebuild = not hasattr(self, "t") or self.t.shape[-1] < L
    if needs_rebuild:
        self.t = torch.arange(L, device=device)[None, None]
        return
    self.t = self.t[..., :L]
262
+
263
def compute_filter(self, L, device):
    """Evaluate the long filter over ``L`` time steps from poles and residues.

    The stored real pairs are viewed as complex numbers in float32. Each
    mode is ``residue * exp(log(pole) * t)``; the real part is summed over
    dimension 1 and a leading unit dimension is added.

    Returns:
        ``(h, filter_dtype, log_poles, residues)``. Note that the third
        element is the complex *logarithm* of the poles, not the poles.
    """
    self.update_time(L, device)
    filter_dtype = torch.float32

    log_poles = torch.view_as_complex(self.poles.to(filter_dtype)).log()
    residues = torch.view_as_complex(self.residues.to(filter_dtype))

    modes = (log_poles * self.t).exp()
    h = (residues * modes).real.sum(1)[None]
    return h, filter_dtype, log_poles, residues
272
+
273
+
274
class ParallelGatedConvBlock(nn.Module):
    """Residual block: pre-norm + projection, Hyena filter, dense, then norm + MLP.

    The two norm/linear stages are exposed as ``proj_norm_fn`` and
    ``res_mlp_norm_fn`` so they can optionally be wrapped by
    ``torch.compile`` (config key ``"compile"``).
    """

    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.low_mem_mode = config.get("low_mem_mode", False)
        dtype = config.get("hyena_block_dtype", torch.float32)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(dtype=dtype)
        self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
        self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
        self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)

        self.proj_norm_fn = self.proj_norm
        self.res_mlp_norm_fn = self.res_mlp_norm

        if self.config.get("compile", False):
            self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
            self.res_mlp_norm_fn = torch.compile(
                self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
            )

    @staticmethod
    def _apply_padding_mask(x, padding_mask):
        """Zero out masked positions of ``x``; anything that is not a tensor is ignored."""
        # isinstance (not an exact type comparison) so tensor subclasses such
        # as nn.Parameter are honoured instead of silently skipped.
        if isinstance(padding_mask, torch.Tensor):
            return x * padding_mask[..., None]
        return x

    def proj_norm(self, x):
        """Normalize ``x`` and project it to three times the hidden size."""
        return self.projections(self.pre_norm(x))

    def res_mlp_norm(self, x):
        """Apply the post-norm MLP to ``x`` and add the residual."""
        return self.mlp(self.post_norm(x)) + x

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        """Run the block on ``u``.

        Args:
            u: block input; also used as the residual for the filter branch.
            inference_params: optional cache passed through to the filter.
            padding_mask: optional tensor broadcast over the last dimension;
                applied after the projection and again after the residual
                add, to guard against bias leaking into padded positions.

        Returns:
            ``(y, inference_params)`` where ``inference_params`` is whatever
            the filter returned.
        """
        z = self._apply_padding_mask(self.proj_norm_fn(u), padding_mask)

        z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)

        z_in = self._apply_padding_mask(self.out_filter_dense(z) + u, padding_mask)

        y = self.res_mlp_norm_fn(z_in)

        return y, inference_params
319
+
320
+
321
def get_block(config, layer_idx, flash_fft=None):
    """Build the block configured for ``layer_idx``.

    Args:
        config: object exposing ``attn_layer_idxs`` / ``hyena_layer_idxs``
            and a dict-style ``get``.
        layer_idx: index of the layer to build.
        flash_fft: optional FFT-convolution callable; attached to the Hyena
            filter only when ``use_flashfft`` is enabled in ``config``.

    Raises:
        NotImplementedError: if ``layer_idx`` is listed in neither index set.
    """
    if layer_idx in config.attn_layer_idxs:
        return AttentionBlock(config, layer_idx)

    if layer_idx in config.hyena_layer_idxs:
        block = ParallelGatedConvBlock(config, layer_idx)
        # The default must be the boolean False: the string "False" is truthy
        # and would enable this branch whenever the key is missing.
        if config.get("use_flashfft", False):
            block.filter.fftconv_fn = flash_fft
        return block

    raise NotImplementedError(
        f"layer {layer_idx} is listed in neither attn_layer_idxs nor hyena_layer_idxs"
    )
331
+
332
+
333
class StripedHyena(nn.Module):
    """Stack of attention and Hyena blocks between an embedding and an unembedding.

    Which kind of block sits at each index is decided by ``get_block`` from
    ``config.attn_layer_idxs`` / ``config.hyena_layer_idxs``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_layer = VocabParallelEmbedding(config)
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None
        self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)

        # Boolean default: the string "False" is truthy and would force the
        # optional flashfftconv import whenever the key is absent.
        if config.get("use_flashfft", False):
            from flashfftconv import FlashFFTConv

            self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
        else:
            self.flash_fft = None

        self.blocks = nn.ModuleList(
            get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
        )

    def forward(self, x, inference_params_dict=None, padding_mask=None):
        """Embed ``x``, run all blocks, apply the final norm (if any) and unembed.

        When ``inference_params_dict`` is given the cached (stateful) path is
        used and ``padding_mask`` is not forwarded to the blocks.

        Returns:
            ``(logits, inference_params_dict_out)``; the second element is
            ``None`` on the stateless path.
        """
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)

        # self.norm is None when the config disables the final norm.
        if self.norm is not None:
            x = self.norm(x)
        x = self.unembed.unembed(x)
        return x, inference_params_dict_out

    def stateful_forward(self, x, inference_params_dict=None):
        """Run all blocks, handing each the cache for its kind ("mha" or "hyena")."""
        for block_idx, block in enumerate(self.blocks):
            block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
            inference_params = inference_params_dict[block_name]
            x, _ = block(x, inference_params=inference_params)

        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        """Run all blocks without any cache, masking padded positions first."""
        if isinstance(padding_mask, torch.Tensor):
            x = x * padding_mask[..., None]

        for block in self.blocks:
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
        return x, None

    def initialize_inference_params(self):
        """Create fresh caches for the attention ("mha") and Hyena blocks."""
        print_rank_0("Initializing inference params...")
        inference_params_dict = {
            "mha": InferenceParams(
                max_seqlen=self.config.get("max_seqlen", 8192),
                max_batch_size=self.config.get("max_batch_size", 1),
                seqlen_offset=0,
            ),
            "hyena": RecurrentInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                state_dim=self.config.state_size,
                seqlen_offset=0,
            ),
        }
        return inference_params_dict

    def precompute_filters(self, L, device):
        """Cache the long filter ``h`` of every Hyena block for length ``L``.

        A block's ``long_fir_threshold``, when set, replaces ``L`` for that
        block.
        """
        for block in self.blocks:
            if not isinstance(block, ParallelGatedConvBlock):
                continue
            if not isinstance(block.filter, ParallelHyenaFilter):
                continue

            filter_len = block.filter.long_fir_threshold or L
            print_rank_0(f"Precomputing filters, L={filter_len}...")

            filter_dtype = torch.float16 if filter_len >= 2048 else torch.float32

            # update_time is the filter method that populates `t`; no
            # `_set_time` is defined among the filter methods in this file.
            block.filter.update_time(filter_len, device)
            residues, poles = (
                torch.view_as_complex(block.filter.residues.to(torch.float16)),
                torch.view_as_complex(block.filter.poles.to(torch.float16)),
            )

            block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
            block.filter.h = block.filter.h.to(dtype=filter_dtype)

    def load_poles_residues(self, path):
        """Load different poles and residues for each Hyena layer from ``path``."""
        for block_idx, block in enumerate(self.blocks):
            if not isinstance(block, ParallelGatedConvBlock):
                continue
            if not isinstance(block.filter, ParallelHyenaFilter):
                continue

            print(f"Loading poles and residues for block {block_idx}")
            poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
            poles = torch.view_as_real(poles)
            residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
            residues = torch.view_as_real(residues)
            poles = poles.permute(1, 0, 2).unsqueeze(-2)
            residues = residues.permute(1, 0, 2).unsqueeze(-2)

            block.filter.poles = nn.Parameter(poles)
            block.filter.residues = nn.Parameter(residues)

    def to_bfloat16_except_poles_residues(self):
        """Convert all parameters to bfloat16 except for the poles and residues.

        Particularly important for longer prompts.
        """
        for k, p in self.named_parameters():
            if "poles" not in k and "residues" not in k:
                p.data = p.data.to(torch.bfloat16)

    def load_from_split_converted_state_dict(self, path):
        """Load weights from per-layer files ``layer_XX.pt`` under ``path``.

        File 00 holds the embedding, files 1..N the blocks, N+1 the final
        norm and N+2 the untied unembedding.
        """
        print("Loading from split converted state dict")

        embedding_weight = torch.load(path + "/layer_00.pt")["word_embeddings.weight"]
        self.embedding_layer.weight = nn.Parameter(embedding_weight.to(self.embedding_layer.weight.dtype))

        print("Loading embedding weight ok")

        # Only load the final norm when the model actually has one; the old
        # `config.get(...) is not None` test was always true.
        if self.norm is not None:
            idx = len(self.blocks) + 1
            final_norm_scale = torch.load(path + f"/layer_{idx:02d}.pt")["norm.scale"]
            self.norm.scale = nn.Parameter(final_norm_scale.to(self.norm.scale.dtype))

            print("loading final norm ok")

        if not self.config.get("tie_embeddings", True):
            idx = len(self.blocks) + 2
            embedding_weight = torch.load(path + f"/layer_{idx:02d}.pt")["word_embeddings.weight"]
            self.unembed.weight = nn.Parameter(embedding_weight.to(self.unembed.weight.dtype))

            print("loading unembed weight ok")

        for block_idx, block in enumerate(self.blocks):
            print("loading block {}...".format(block_idx))
            # Some blocks (optionally) go through a round of conv distillation
            # on some parameters; it is still safer to be strict and account
            # for every layer.
            strict = True

            loaded_dict = torch.load(path + f"/layer_{block_idx + 1:02d}.pt")
            block.load_state_dict(loaded_dict, strict=strict)
modeling_hyena.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """StripedHyena custom code port for the Hugging Face Hub"""
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from .configuration_hyena import StripedHyenaConfig
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
9
+ from transformers.utils import logging
10
+ from typing import Optional, Tuple, Union
11
+ from .model import StripedHyena
12
+ from .utils import dotdict
13
+ from .cache import InferenceParams
14
+ from .engine import HyenaInferenceEngine
15
+ from .layers import RMSNorm
16
+ from .utils import dotdict, column_split
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
class StripedHyenaPreTrainedModel(PreTrainedModel):
    """Base class holding the Hugging Face integration settings for StripedHyena."""

    config_class = StripedHyenaConfig
    base_model_prefix = "sh"

    # Capability flags.
    supports_gradient_checkpointing = False
    _supports_flash_attn_2 = True

    # Device placement: these block classes must not be split across devices.
    _no_split_modules = ["AttentionBlock", "ParallelGatedConvBlock"]
    _skip_keys_device_placement = "past_key_values"

    # Checkpoint keys tolerated as missing / unexpected when loading.
    _keys_to_ignore_on_load_missing = [r"freq"]
    _keys_to_ignore_on_load_unexpected = [r"fftconv", r"twiddle_factors"]
30
+
31
+
32
class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
    """Causal-LM wrapper exposing the StripedHyena backbone through the HF API."""

    supports_gradient_checkpointing = True

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        model_config = dotdict(config.to_dict())
        self.backbone = StripedHyena(model_config)
        self.backbone.gradient_checkpointing = False
        self.config = config

        # Round the vocabulary size up to the configured multiple.
        vocab_size = config.vocab_size
        remainder = vocab_size % config.make_vocab_size_divisible_by
        if remainder != 0:
            vocab_size += config.make_vocab_size_divisible_by - remainder
        self.vocab_size = vocab_size

        self.post_init()
        self.force_dtype()

    def force_dtype(self):
        """Cast backbone parameters to bfloat16, keeping poles and residues as-is."""
        self.backbone.to_bfloat16_except_poles_residues()

    def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
        self.backbone.gradient_checkpointing = enable

    def get_input_embeddings(self):
        return self.backbone.embedding_layer

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values=None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits (and the shifted next-token loss when ``labels`` is given).

        With ``use_cache`` the backbone's inference-params dict is used as
        ``past_key_values``: it is created on the first call, and on later
        calls its sequence offsets are advanced and only the last input
        token is fed to the backbone. ``use_cache`` is switched off when
        gradient checkpointing is active during training or when labels
        are supplied.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is true,
            otherwise the bare logits tensor (the loss is not returned on
            that path).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if use_cache:
            if self.backbone.gradient_checkpointing and self.backbone.training:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
            elif labels is not None:
                logger.warning_once(
                    "`use_cache=True` is incompatible with loss calculation. Setting `use_cache=False`..."
                )
                use_cache = False

        inputs = input_ids
        if use_cache:
            if past_key_values is None:
                past_key_values = self.backbone.initialize_inference_params()

                batch_size = input_ids.shape[0]
                past_key_values["mha"].max_batch_size = batch_size
                past_key_values["hyena"].max_batch_size = batch_size
            else:
                seqlen_offset = past_key_values["mha"].seqlen_offset
                if seqlen_offset == 0:
                    # second loop through generate will have prompt_len + 1 as seqlen
                    seqlen_offset = input_ids.shape[-1] - 1
                    past_key_values["hyena"].seqlen_offset = seqlen_offset
                    past_key_values["mha"].seqlen_offset = seqlen_offset
                else:
                    past_key_values["mha"].seqlen_offset += 1
                    past_key_values["hyena"].seqlen_offset += 1

                # The cache already covers the prefix: feed the newest token only.
                inputs = input_ids[:, -1:]

        logits, past_key_values = self.backbone(
            inputs,
            padding_mask=attention_mask,
            inference_params_dict=past_key_values if use_cache else None,
        )

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten with the logits' own width rather than config.vocab_size:
            # __init__ computes a padded vocabulary size, and a width mismatch
            # would either fail to reshape or silently scramble rows.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = F.cross_entropy(shift_logits, shift_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                logits=logits,
                hidden_states=None,
                past_key_values=past_key_values if use_cache else None,
                loss=loss,
            )
        else:
            return logits

    @classmethod
    def can_generate(cls) -> bool:
        return True

    def prepare_inputs_for_generation(
        self, input_ids, attention_mask=None, past_key_values=None, **kwargs
    ):
        """Pass ids, mask and cache through unchanged; extra kwargs are dropped."""
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
positional_embeddings.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This software is distributed under the terms of the Apache License, Version 2.0
2
+ # Author: Armin Thomas, Eric Nguyen
3
+
4
+ import torch
5
+ import copy
6
+ from einops import rearrange
7
+ from flash_attn.layers.rotary import RotaryEmbedding
8
+ from flash_attn.modules.mha import MHA
9
+
10
+
11
+ # simple wrapper for flash-attn RoPE with linear scaling:
12
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    """flash-attn rotary embedding whose position indices are divided by a constant."""

    def __init__(
        self,
        dim: int,
        scaling_factor: float=1.,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        parent_kwargs = dict(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device,
        )
        super().__init__(**parent_kwargs)
        self._linear_scaling_factor = scaling_factor

    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cached cos/sin tables when they no longer fit the request."""
        # The tables are stale if the sequence grew, nothing is cached yet,
        # the device or dtype changed (e.g. due to tracing), or we switched
        # from inference mode to training.
        cache_is_stale = (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        )
        if not cache_is_stale:
            return

        self._seq_len_cached = seqlen

        if self.pos_idx_in_fp32:
            # Positions are built in fp32 rather than inv_freq's dtype: the
            # model may be loaded in bf16, and large arange values would
            # lose a lot of precision there.
            positions = torch.arange(seqlen, device=device, dtype=torch.float32)
            # inv_freq is multiplied with the positions, so it must be fp32
            # too; recompute it if it was not loaded in fp32.
            if self.inv_freq.dtype == torch.float32:
                inv_freq = self.inv_freq
            else:
                inv_freq = self._compute_inv_freq(device=device)
        else:
            # Compatibility option: keep everything in inv_freq's dtype.
            positions = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq

        # Linear scaling of the position indices.
        positions = positions / self._linear_scaling_factor

        # torch.outer instead of einsum: einsum converts fp32 to fp16 under AMP.
        freqs = torch.outer(positions, inv_freq)
        cos, sin = torch.cos(freqs), torch.sin(freqs)

        if self.scale is None:
            self._cos_cached = cos.to(dtype)
            self._sin_cached = sin.to(dtype)
            return

        centered = (
            torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
            - seqlen // 2
        )
        power = centered / self.scale_base
        scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
        # The multiplication by scale happens before the cast to `dtype`.
        self._cos_cached = (cos * scale).to(dtype)
        self._sin_cached = (sin * scale).to(dtype)
        self._cos_k_cached = (cos / scale).to(dtype)
        self._sin_k_cached = (sin / scale).to(dtype)
82
+
83
+ # swap out RoPE of existing mha:
84
def swap_mha_rope(
    mha,
    new_rope: torch.nn.Module=LinearlyScaledRotaryEmbedding,
    kwargs_new_rope: dict=None
):
    """Replace ``mha.rotary_emb`` with a ``new_rope`` built from the old settings.

    The old embedding's ``dim``, ``base``, ``interleaved``, ``scale_base``,
    ``pos_idx_in_fp32`` and the device of its ``inv_freq`` are carried over;
    ``kwargs_new_rope`` (default ``{'scaling_factor': 1.0}``) supplies the
    extra constructor arguments. The new module is cast to the dtype of the
    attention projection weights.

    Returns:
        The same ``mha`` object, modified in place.
    """
    # The projection that exists depends on whether this is cross-attention.
    projection = mha.Wq if mha.cross_attn else mha.Wqkv
    dtype = projection.weight.dtype

    # Capture the old RoPE settings before removing it.
    kwargs_old_rope = {
        "dim": mha.rotary_emb.dim,
        "base": mha.rotary_emb.base,
        "interleaved": mha.rotary_emb.interleaved,
        "scale_base": mha.rotary_emb.scale_base,
        "pos_idx_in_fp32": mha.rotary_emb.pos_idx_in_fp32,
        "device": mha.rotary_emb.inv_freq.device,
    }
    del mha.rotary_emb

    kwargs_new_rope = kwargs_new_rope or {'scaling_factor': 1.0}
    mha.rotary_emb = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)

    # Make sure the new RoPE is correctly registered.
    assert isinstance(mha.rotary_emb, new_rope)
    return mha
streamer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+
3
+
4
class BaseStreamer:
    """Interface that streamers handed to `.generate()` must implement."""

    def put(self, value):
        """Receive new tokens pushed by `.generate()`; subclasses must override."""
        raise NotImplementedError()

    def end(self):
        """Be notified by `.generate()` that generation is over; subclasses must override."""
        raise NotImplementedError()
16
+
17
+
18
class ByteStreamer(BaseStreamer):
    """
    Simple text streamer that prints decoded text to stdout as soon as it is available.

    All tokens received since the last flush are kept in a cache and decoded
    together on every `put`; only the part of the decoded text that has not
    been printed yet is emitted. The cache is flushed after a newline.

    Parameters:
        tokenizer:
            Object with a `decode` method, used to decode the cached tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Example (sketch; `tok` and `model` are a matching tokenizer/model pair):

        >>> streamer = ByteStreamer(tok)
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints the not-yet-printed text to stdout.

        Raises:
            ValueError: if `value` carries a batch dimension larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("ByteStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new tokens to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # Emit everything that is new. Printing a single character per call
        # made the output lag behind by the prompt length whenever several
        # tokens arrived at once (e.g. the prompt with skip_prompt=False).
        printable_text = text[self.print_len :]
        if text.endswith("\n"):
            # After the symbol for a new line, we flush the cache.
            self.token_cache = []
            self.print_len = 0
        else:
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)
tokenizer.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
2
+ from abc import ABC
3
+ import json
4
+ import pathlib
5
+
6
+ import torch
7
+ import tqdm
8
+ from tokenizers import Tokenizer
9
+ from abc import abstractmethod
10
+ from typing import Any, List, Union
11
+ import numpy as np
12
+
13
+
14
class HFAutoTokenizer:
    """Thin wrapper around a `tokenizers.Tokenizer` loaded from a vocab file."""

    def __init__(self, vocab_file):
        self.tokenizer = Tokenizer.from_file(vocab_file)
        self.eos = "</s>"
        self.bos = "<s>"
        self.eos_id = self.tokenize(self.eos)
        self.bos_id = self.tokenize(self.bos)
        # `eod` reads this attribute, which was never assigned before; the
        # end-of-sequence id is used as the end-of-document id.
        self.eod_id = self.eos_id
        self.vsize = 32000

    def encode_to_list(self, text):
        """Encode `text` without special tokens; returns the tokenizer's raw result."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def tokenize_file(self, input_file, output_file, verbose=False):
        """Tokenize a JSON-lines file, writing one ``{"tokens": [...]}`` line per record.

        Does nothing if `output_file` already exists.
        """
        if verbose:
            print(f"Tokenizing file: {input_file}")

        if pathlib.Path(output_file).exists():
            print(f"Output file {output_file} already exists, skipping")
            return
        with open(input_file, "r", encoding="utf-8") as fin, open(output_file, "w", encoding="utf-8") as fout:
            for line in tqdm.tqdm(fin):
                if verbose:
                    print(f"Tokenizing line: {line[-200:]}")
                data = json.loads(line.strip())
                # NOTE(review): processing stops at the first record without
                # a "text" key; confirm that skipping it is not what is wanted.
                if "text" not in data:
                    break
                # tokenize() returns a tensor, which json cannot serialize.
                tokenized_data = self.tokenize(data["text"]).tolist()
                fout.write(json.dumps({"tokens": tokenized_data}) + "\n")

    def tokenize(self, text: str, *args, **kwargs):
        """Encode `text` and return the ids as a tensor; extra arguments are ignored."""
        ids = self.tokenizer.encode(text)
        if isinstance(ids, list):
            return torch.tensor(ids)
        return torch.tensor(ids.ids)

    def tokenize_batch(self, text_batch):
        return self.tokenizer.encode_batch(text_batch)

    def detokenize(self, token_ids, skip_special_tokens=False):
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def detokenize_batch(self, token_ids_batch, skip_special_tokens=False):
        """Decode each sequence of scalar tensors in the batch to a string."""
        return [
            self.detokenize(
                [t.item() for t in token_ids],
                skip_special_tokens=skip_special_tokens,
            )
            for token_ids in token_ids_batch
        ]

    @property
    def eod(self):
        return self.eod_id

    @property
    def vocab_size(self):
        return 32000
74
+
75
+
76
class ByteTokenizer:
    """UTF-8 Encoder: one token per byte of the UTF-8 encoded text."""

    def __init__(self):
        self.vocab_size = 512
        self.eod_id = 0
        self.eos_id = 0
        self.eos_token = 0
        self.eos_token_id = 0
        self.pad_id = 1

    def clamp(self, n):
        """Clamp `n` into the range ``[32, vocab_size]``."""
        return max(32, min(n, self.vocab_size))

    def decode_token(self, token: int):
        """Return the character whose code point is the clamped token value."""
        return str(chr(self.clamp(token)))

    def __call__(self, text: str, *args, **kwargs):
        """Tokenize `text` into a ``{"input_ids": LongTensor}`` dict with a batch dim of 1."""
        ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
        return {"input_ids": ids}

    def tokenize(self, text: str):
        """Return the UTF-8 bytes of `text` as a list of ``np.uint8`` values."""
        # np.frombuffer over explicitly encoded bytes replaces the
        # deprecated binary mode of np.fromstring.
        return list(np.frombuffer(text.encode("utf-8"), dtype=np.uint8))

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Tokenize a list of strings element-wise, or a single string."""
        if isinstance(text_batch, list):
            return [self.tokenize(s) for s in text_batch]
        return self.tokenize(text_batch)

    def decode(self, token_ids):
        """Decode an iterable of token ids into a string, one character per token."""
        return "".join(map(self.decode_token, token_ids))

    def decode_batch(self, token_ids: Union[List[str], str]):
        """Decode a list or tensor of sequences element-wise; anything else as one sequence."""
        if isinstance(token_ids, list):
            return [self.decode(s) for s in token_ids]
        # Tensors are converted to nested lists first.
        if isinstance(token_ids, torch.Tensor):
            return [self.decode(s) for s in token_ids.tolist()]
        return self.decode(token_ids)
utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def grab_first_if_tuple(x):
    """Return ``x[0]`` when the class of ``x`` is named ``"tuple"``, else ``x``.

    The check compares the class *name*, so tuple subclasses that carry a
    different name (for example namedtuples) are returned unchanged.
    """
    class_name = x.__class__.__name__
    return x[0] if class_name == "tuple" else x
9
+
10
+
11
def column_split(x, num_heads, head_size):
    """Split a fused projection into three tensors, head by head.

    ``x`` is reshaped to ``(x.shape[0], num_heads, 3 * head_size)``; the last
    axis of every head is then cut into three consecutive ``head_size``-wide
    chunks, and each chunk is flattened back to ``(x.shape[0], -1)``. The
    number of projections is fixed to three.

    Returns:
        The tuple ``(x2, x1, v)`` holding the first, second and third chunk
        of every head respectively.
    """
    rows = x.shape[0]
    per_head = x.reshape(rows, num_heads, 3 * head_size)

    # (start, stop) bounds of the three chunks along the last axis.
    bounds = (
        (None, head_size),
        (head_size, 2 * head_size),
        (2 * head_size, None),
    )
    x2, x1, v = (per_head[:, :, lo:hi].reshape(rows, -1) for lo, hi in bounds)
    return x2, x1, v
37
+
38
+
39
def get_init_from_string(init_str):
    """Resolve the dotted name of a supported ``torch.nn.init`` function.

    Args:
        init_str: One of ``"torch.nn.init.zeros_"``,
            ``"torch.nn.init.xavier_uniform_"`` or
            ``"torch.nn.init.xavier_normal_"``.

    Returns:
        The matching initializer callable, or ``None`` when ``init_str`` is
        not a string.

    Raises:
        ValueError: If ``init_str`` is a string that is not a supported name.
    """
    # isinstance (rather than an exact type comparison) so that str
    # subclasses are resolved instead of silently yielding None.
    if not isinstance(init_str, str):
        return None

    # Explicit whitelist: names are never evaluated dynamically.
    supported = {
        "torch.nn.init.zeros_": torch.nn.init.zeros_,
        "torch.nn.init.xavier_uniform_": torch.nn.init.xavier_uniform_,
        "torch.nn.init.xavier_normal_": torch.nn.init.xavier_normal_,
    }
    if init_str not in supported:
        raise ValueError(f"Unrecognized init {init_str}")
    return supported[init_str]
49
+
50
+
51
def print_rank_0(message, debug=False, end="\n"):
    """Print ``message`` once: from rank 0 when distributed, otherwise always.

    Args:
        message: Object to print.
        debug: Unused; kept so existing call sites remain valid.
        end: Line terminator forwarded to ``print``.
    """
    dist = torch.distributed
    # ``is_initialized`` only exists on builds with distributed support, so
    # check availability first instead of failing with AttributeError.
    in_process_group = dist.is_available() and dist.is_initialized()
    if not in_process_group or dist.get_rank() == 0:
        print(message, flush=True, end=end)
58
+
59
+
60
class dotdict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` (it goes through
    ``dict.get``); deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
66
+
67
+
68
def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is an exact multiple of ``denominator``.

    Raises:
        AssertionError: With the message ``"<numerator> is not divisible by
            <denominator>"`` when the remainder is non-zero.
    """
    assert not numerator % denominator, "{} is not divisible by {}".format(
        numerator, denominator
    )
71
+
72
+
73
def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking exact divisibility.

    The check is delegated to ``ensure_divisibility``, which fails when the
    division would leave a remainder.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
78
+
79
+
80
class VocabUtility:
    """Map a rank to its slice of a vocabulary split into equal partitions.

    Both helpers return a ``(first, last)`` pair in which ``last`` equals
    ``first`` plus the partition size.
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        """Return the index pair for ``rank`` given the size of one partition.

        ``world_size`` is accepted but not used by this computation.
        """
        first = rank * per_partition_vocab_size
        return first, first + per_partition_vocab_size

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return the index pair for ``rank`` given the total vocabulary size.

        The partition size is ``divide(global_vocab_size, world_size)``.
        """
        partition_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_size, rank, world_size
        )