Rocketknight1 HF staff commited on
Commit
2db3e9f
·
1 Parent(s): e3c4338

Upload HyenaDNAForCausalLM

Browse files
Files changed (4) hide show
  1. config.json +35 -0
  2. configuration_hyena.py +88 -0
  3. model.safetensors +3 -0
  4. modeling_hyena.py +569 -0
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "hyenadna-medium-450k-seqlen-hf",
3
+ "activation_freq": 10,
4
+ "architectures": [
5
+ "HyenaDNAForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_hyena.HyenaConfig",
9
+ "AutoModel": "modeling_hyena.HyenaDNAModel",
10
+ "AutoModelForCausalLM": "modeling_hyena.HyenaDNAForCausalLM",
11
+ "AutoModelForSequenceClassification": "modeling_hyena.HyenaDNAForSequenceClassification"
12
+ },
13
+ "d_inner": 1024,
14
+ "d_model": 256,
15
+ "emb_dim": 5,
16
+ "embed_dropout": 0.1,
17
+ "filter_order": 64,
18
+ "hyena_dropout": 0.0,
19
+ "hyena_filter_dropout": 0.0,
20
+ "hyena_order": 2,
21
+ "initializer_range": 0.02,
22
+ "layer_norm_epsilon": 1e-05,
23
+ "max_seq_len": 450002,
24
+ "model_type": "hyenadna",
25
+ "n_layer": 8,
26
+ "num_inner_mlps": 2,
27
+ "pad_vocab_size_multiple": 8,
28
+ "short_filter_order": 3,
29
+ "tie_word_embeddings": false,
30
+ "torch_dtype": "float32",
31
+ "train_freq": true,
32
+ "transformers_version": "4.35.0.dev0",
33
+ "use_bias": true,
34
+ "vocab_size": 12
35
+ }
configuration_hyena.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
+ class HyenaConfig(PretrainedConfig):
6
+ model_type = "hyenadna"
7
+ def __init__(
8
+ self,
9
+ vocab_size=12,
10
+ d_model=256,
11
+ d_inner=None,
12
+ use_bias=True,
13
+ train_freq=True,
14
+ max_seq_len=1024,
15
+ emb_dim=3,
16
+ n_layer=12,
17
+ num_inner_mlps=2,
18
+ hyena_order=2,
19
+ short_filter_order=3,
20
+ filter_order=64,
21
+ activation_freq=1,
22
+ embed_dropout=0.1,
23
+ hyena_dropout=0.0,
24
+ hyena_filter_dropout=0.0,
25
+ layer_norm_epsilon=1e-5,
26
+ initializer_range=0.02,
27
+ pad_vocab_size_multiple=8,
28
+ **kwargs,
29
+ ):
30
+ self.vocab_size = vocab_size
31
+ self.d_model = d_model
32
+ if d_inner is None:
33
+ self.d_inner = 4 * d_model
34
+ else:
35
+ self.d_inner = d_inner
36
+ self.use_bias = use_bias
37
+ self.train_freq = train_freq
38
+ self.max_seq_len = max_seq_len
39
+ self.emb_dim = emb_dim
40
+ self.n_layer = n_layer
41
+ self.hyena_order = hyena_order
42
+ self.filter_order = filter_order
43
+ self.short_filter_order = short_filter_order
44
+ self.activation_freq = activation_freq
45
+ self.num_inner_mlps = num_inner_mlps
46
+ self.embed_dropout = embed_dropout
47
+ self.hyena_dropout = hyena_dropout
48
+ self.hyena_filter_dropout = hyena_filter_dropout
49
+ self.layer_norm_epsilon = layer_norm_epsilon
50
+ self.initializer_range = initializer_range
51
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
52
+ super().__init__(**kwargs)
53
+
54
+ @classmethod
55
+ def from_original_config(cls, config_path, **kwargs):
56
+ with open(config_path, "r") as f:
57
+ config = json.load(f)
58
+
59
+ vocab_size = config["vocab_size"]
60
+ d_model = config["d_model"]
61
+ d_inner = config["d_inner"]
62
+ max_seq_len = config["layer"]["l_max"]
63
+ emb_dim = config["layer"]["emb_dim"]
64
+ filter_order = config["layer"]["filter_order"]
65
+ if "local_order" in config["layer"]:
66
+ short_filter_order = config["layer"]["local_order"]
67
+ elif "short_filter_order" in config["layer"]:
68
+ short_filter_order = config["layer"]["short_filter_order"]
69
+ else:
70
+ short_filter_order = 3
71
+ n_layer = config["n_layer"]
72
+ activation_freq = config["layer"]["w"]
73
+ embed_dropout = config["embed_dropout"]
74
+ pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
75
+ return cls(vocab_size=vocab_size,
76
+ d_model=d_model,
77
+ d_inner=d_inner,
78
+ max_seq_len=max_seq_len,
79
+ emb_dim=emb_dim,
80
+ filter_order=filter_order,
81
+ short_filter_order=short_filter_order,
82
+ n_layer=n_layer,
83
+ activation_freq=activation_freq,
84
+ embed_dropout=embed_dropout,
85
+ pad_vocab_size_multiple=pad_vocab_size_multiple,
86
+ tie_word_embeddings=False,
87
+ **kwargs
88
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a4e6386b1e469130e7fd2e9aab2afa3d71ba5bfa58a19b284d13e4317cc25cb
3
+ size 112652080
modeling_hyena.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """HyenaDNA custom code port to Hugging Face Hub"""
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from .configuration_hyena import HyenaConfig
9
+ from transformers import PreTrainedModel
10
+ from typing import Optional, Tuple, Union
11
+ from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention
12
+
13
+
14
+ def fftconv(u, k, D):
15
+ """
16
+ We apply a convolution through the fourier domain (from the Convolution Theorem)
17
+
18
+ """
19
+ seqlen = u.shape[-1]
20
+ fft_size = 2 * seqlen
21
+
22
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
23
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
24
+
25
+ if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
27
+
28
+ out = y + u * D.unsqueeze(-1)
29
+ return out.to(dtype=u.dtype)
30
+
31
+
32
+ @torch.jit.script
33
+ def mul_sum(q, y):
34
+ return (q * y).sum(dim=1)
35
+
36
+
37
+ class HyenaSin(nn.Module):
38
+ """The Sin activation function for the Hyena Filter function."""
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ self.freq = nn.Parameter(config.activation_freq * torch.ones(1, config.filter_order)) if config.train_freq else config.activation_freq * torch.ones(1, config.filter_order)
42
+
43
+ def forward(self, x):
44
+ return torch.sin(self.freq * x)
45
+
46
+
47
+ class HyenaPositionalEmbedding(nn.Module):
48
+ def __init__(self, config):
49
+ """Complex exponential positional embeddings for Hyena filters."""
50
+ super().__init__()
51
+
52
+ self.seq_len = config.max_seq_len
53
+ # The time embedding fed to the filteres is normalized so that t_f = 1
54
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
55
+
56
+ if config.emb_dim > 1:
57
+ bands = (config.emb_dim - 1) // 2
58
+ # To compute the right embeddings we use the "proper" linspace
59
+ t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
60
+ w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
+
62
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
+ # Matt: This is just Euler's formula, so if complex64 is a problem it can be replaced
64
+ # by separate sin() and cos() calls.
65
+ z = torch.exp(-1j * f * w)
66
+ z = torch.cat([t, z.real, z.imag], dim=-1)
67
+ # TODO Set z's LR to lr_pos_emb
68
+ self.z = nn.Parameter(z, requires_grad=True)
69
+ self.register_buffer("t", t)
70
+
71
+ def forward(self, L):
72
+ return self.z[:, :L], self.t[:, :L]
73
+
74
+
75
+ class HyenaExponentialModulation(nn.Module):
76
+ """The window function applied to the output of the (MLP) filter function."""
77
+ def __init__(
78
+ self,
79
+ d_model,
80
+ fast_decay_pct=0.3,
81
+ slow_decay_pct=1.5,
82
+ target=1e-2,
83
+ modulate: bool=True,
84
+ shift: float = 0.05,
85
+ **kwargs
86
+ ):
87
+ super().__init__()
88
+ self.modulate = modulate
89
+ self.shift = shift
90
+ max_decay = math.log(target) / fast_decay_pct
91
+ min_decay = math.log(target) / slow_decay_pct
92
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
93
+ self.register_buffer("deltas", deltas)
94
+
95
+ def forward(self, t, x):
96
+ if self.modulate:
97
+ decay = torch.exp(-t * self.deltas.abs())
98
+ x = x * (decay + self.shift)
99
+ return x
100
+
101
+
102
+ class HyenaFilter(nn.Module):
103
+ def __init__(
104
+ self,
105
+ config,
106
+ **kwargs
107
+ ):
108
+ """
109
+ Implicit long filter with modulation.
110
+
111
+ Args:
112
+ d_model: number of channels in the input
113
+ emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
114
+ order: width of the FFN
115
+ num_inner_mlps: number of inner linear layers inside filter MLP
116
+
117
+ Note:
118
+ filter_dropout is not implemented
119
+ """
120
+ super().__init__()
121
+
122
+ self.d_model = config.d_model * (config.hyena_order - 1)
123
+ self.use_bias = config.use_bias
124
+ self.bias = nn.Parameter(torch.randn(self.d_model))
125
+ self.dropout = nn.Dropout(config.hyena_filter_dropout)
126
+
127
+ act = HyenaSin(config)
128
+ self.emb_dim = config.emb_dim
129
+ assert self.emb_dim % 2 != 0 and self.emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
130
+ self.seq_len = config.max_seq_len
131
+
132
+ self.pos_emb = HyenaPositionalEmbedding(config)
133
+
134
+ self.implicit_filter = nn.Sequential(
135
+ nn.Linear(self.emb_dim, config.filter_order),
136
+ act,
137
+ )
138
+ for i in range(config.num_inner_mlps):
139
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
140
+ self.implicit_filter.append(act)
141
+
142
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
143
+
144
+ self.modulation = HyenaExponentialModulation(config.d_model)
145
+
146
+ self.normalized = False
147
+
148
+ def filter(self, L, *args, **kwargs):
149
+ z, t = self.pos_emb(L)
150
+ h = self.implicit_filter(z)
151
+ h = self.modulation(t, h)
152
+ return h
153
+
154
+ def forward(self, x, L, k=None, bias=None, *args, **kwargs):
155
+ if k is None: k = self.filter(L)
156
+
157
+ # Ensure compatibility with filters that return a tuple
158
+ k = k[0] if type(k) is tuple else k
159
+
160
+ y = fftconv(x, k, bias)
161
+ return y
162
+
163
+
164
+ class HyenaOperator(nn.Module):
165
+ def __init__(
166
+ self,
167
+ config,
168
+ **filter_args,
169
+ ):
170
+ r"""
171
+ Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf
172
+
173
+ Args:
174
+ d_model (int): Dimension of the input and output embeddings (width of the layer)
175
+ l_max: (int): Maximum input sequence length. Defaults to None
176
+ order: (int): Depth of the Hyena recurrence. Defaults to 2
177
+ dropout: (float): Dropout probability. Defaults to 0.0
178
+ filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
179
+ """
180
+ super().__init__()
181
+
182
+ self.d_model = config.d_model
183
+ self.l_max = config.max_seq_len
184
+ self.order = config.hyena_order
185
+ inner_width = config.d_model * (self.order + 1)
186
+ self.dropout = nn.Dropout(config.hyena_dropout)
187
+ self.in_proj = nn.Linear(self.d_model, inner_width)
188
+ self.out_proj = nn.Linear(self.d_model, self.d_model)
189
+
190
+ self.short_filter = nn.Conv1d(
191
+ inner_width,
192
+ inner_width,
193
+ config.short_filter_order,
194
+ padding=2,
195
+ groups=inner_width
196
+ )
197
+ self.filter_fn = HyenaFilter(config)
198
+
199
+ def forward(self, u):
200
+ l = u.size(-2)
201
+ l_filter = min(l, self.l_max)
202
+ u = self.in_proj(u).transpose(1, 2)
203
+
204
+ uc = self.short_filter(u)[...,:l_filter]
205
+ *x, v = uc.split(self.d_model, dim=1)
206
+
207
+ k = self.filter_fn.filter(l_filter)[0]
208
+ k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
209
+ bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)
210
+
211
+ for o, x_i in enumerate(reversed(x[1:])):
212
+ v = self.dropout(v * x_i)
213
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
214
+
215
+ y = (v * x[0]).transpose(1, 2)
216
+
217
+ y = self.out_proj(y)
218
+ return y
219
+
220
+ class HyenaMlp(nn.Module):
221
+
222
+ def __init__(self, config):
223
+ """
224
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
225
+ """
226
+ super().__init__()
227
+ in_features = config.d_model
228
+ hidden_features = config.d_inner
229
+ self.fc1 = nn.Linear(in_features, hidden_features)
230
+ self.fc2 = nn.Linear(hidden_features, config.d_model)
231
+
232
+ def forward(self, x):
233
+ y = self.fc1(x)
234
+ y = F.gelu(y, approximate="tanh")
235
+ y = self.fc2(y)
236
+ return y
237
+
238
+ class HyenaBlock(nn.Module):
239
+
240
+ def __init__(self, config):
241
+ """
242
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
243
+ For prenorm=True, this Block has a slightly different structure compared to a regular
244
+ prenorm Transformer block.
245
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
246
+ [Ref: https://arxiv.org/abs/2002.04745]
247
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
248
+ the hidden_states (output of the MLP) and the residual.
249
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
250
+ The residual needs to be provided (except for the very first block).
251
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
252
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
253
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
254
+ This is for performance reason: for post-norm architecture, returning the input allows us
255
+ to fuse the backward of nn.Linear with the residual connection.
256
+ """
257
+ super().__init__()
258
+ self.mixer = HyenaOperator(config)
259
+ self.norm1 = nn.LayerNorm(config.d_model)
260
+ self.mlp = HyenaMlp(config)
261
+ self.norm2 = nn.LayerNorm(config.d_model)
262
+
263
+ def forward(self, hidden_states):
264
+ r"""Pass the input through the encoder layer.
265
+ Args:
266
+ hidden_states: the sequence to the encoder layer (required).
267
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
268
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
269
+ before applying the query projection. Useful for e.g., ViT where we only care
270
+ about the CLS token in the last layer.
271
+ """
272
+ residual = hidden_states
273
+ residual = residual.to(torch.float32)
274
+ hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
275
+ hidden_states = self.mixer(hyena_normed)
276
+ # Tested above here and all is equivalent. That means the mixer is fine!!!
277
+ residual = hidden_states + residual
278
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
279
+ residual = residual.to(torch.float32)
280
+
281
+ hidden_states = self.mlp(hidden_states)
282
+ return hidden_states + residual
283
+
284
+
285
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
286
+
287
+
288
+ class HyenaEmbeddings(nn.Module):
289
+
290
+ def __init__(self, config, padding_idx=None):
291
+ """
292
+ If max_position_embeddings <= 0, there's no position embeddings
293
+ If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
294
+ the project up to embed_dim
295
+ """
296
+ super().__init__()
297
+ vocab_size = config.vocab_size
298
+ if vocab_size % config.pad_vocab_size_multiple != 0:
299
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
300
+ self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
301
+
302
+ def forward(self, input_ids):
303
+ """
304
+ input_ids: (batch, seqlen)
305
+ """
306
+ embeddings = self.word_embeddings(input_ids)
307
+ return embeddings
308
+
309
+ class HyenaLMBackbone(nn.Module):
310
+
311
+ def __init__(self, config) -> None:
312
+ super().__init__()
313
+ # note max_position_embeddings is 0 for Hyena, and therefore isn't used
314
+ self.embeddings = HyenaEmbeddings(config)
315
+ self.dropout = nn.Dropout(config.embed_dropout)
316
+
317
+ self.layers = nn.ModuleList([HyenaBlock(config) for i in range(config.n_layer)])
318
+
319
+ self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
320
+ self.gradient_checkpointing = False
321
+
322
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
323
+ all_hidden_states = []
324
+ if inputs_embeds is not None:
325
+ hidden_states = inputs_embeds
326
+ else:
327
+ hidden_states = self.embeddings(input_ids)
328
+ if output_hidden_states:
329
+ all_hidden_states.append(hidden_states)
330
+
331
+ for layer in self.layers:
332
+ if self.gradient_checkpointing and self.training:
333
+ hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
334
+ else:
335
+ hidden_states = layer(hidden_states)
336
+ if output_hidden_states:
337
+ all_hidden_states.append(hidden_states)
338
+
339
+ hidden_states = self.ln_f(hidden_states.to(dtype=self.ln_f.weight.dtype))
340
+ if output_hidden_states:
341
+ all_hidden_states.append(hidden_states)
342
+
343
+ return hidden_states, all_hidden_states
344
+
345
+
346
+ class HyenaDNAPreTrainedModel(PreTrainedModel):
347
+ config_class = HyenaConfig
348
+ base_model_prefix = "hyena"
349
+ supports_gradient_checkpointing = True
350
+ _no_split_modules = ["HyenaBlock"]
351
+ _skip_keys_device_placement = "past_key_values"
352
+
353
+ def _init_weights(self, initializer_range=0.02):
354
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
355
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
356
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
357
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
358
+ #
359
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
360
+ for name, p in self.named_parameters():
361
+ if name in ["out_proj.weight", "fc2.weight"]:
362
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
363
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
364
+ # If using GLU activation for now, we scale the std by 2
365
+ elif name in ["output_linear.0.weight"]:
366
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
367
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
368
+
369
+
370
+ class HyenaDNAModel(HyenaDNAPreTrainedModel):
371
+ def __init__(self, config) -> None:
372
+ super().__init__(config)
373
+
374
+ self.backbone = HyenaLMBackbone(config)
375
+ self.config = config
376
+
377
+ # Initialize weights and apply final processing
378
+ self.post_init()
379
+
380
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
381
+ output_hidden_states = (
382
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
383
+ )
384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
385
+
386
+ hidden_states, all_hidden_states = self.backbone(input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states)
387
+ if return_dict:
388
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states,
389
+ hidden_states=all_hidden_states if output_hidden_states else None)
390
+ elif output_hidden_states:
391
+ return hidden_states, all_hidden_states
392
+ else:
393
+ return hidden_states
394
+
395
+
396
+ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
397
+
398
+ def __init__(self, config):
399
+ super().__init__(config)
400
+ self.hyena = HyenaDNAModel(config)
401
+ vocab_size = config.vocab_size
402
+ if vocab_size % config.pad_vocab_size_multiple != 0:
403
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
404
+ self.vocab_size = vocab_size
405
+ self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)
406
+
407
+ # Initialize weights and apply final processing
408
+ self.post_init()
409
+
410
+ def get_input_embeddings(self):
411
+ return self.hyena.backbone.embeddings.word_embeddings
412
+
413
+ def set_input_embeddings(self, value):
414
+ self.hyena.backbone.embeddings.word_embeddings = value
415
+
416
+ def get_output_embeddings(self):
417
+ return self.lm_head
418
+
419
+ def set_output_embeddings(self, new_embeddings):
420
+ self.lm_head = new_embeddings
421
+
422
+ def set_decoder(self, decoder):
423
+ self.hyena = decoder
424
+
425
+ def get_decoder(self):
426
+ return self.hyena
427
+
428
+ def forward(
429
+ self,
430
+ input_ids: torch.LongTensor = None,
431
+ inputs_embeds: Optional[torch.FloatTensor] = None,
432
+ labels: Optional[torch.LongTensor] = None,
433
+ output_hidden_states: Optional[bool] = None,
434
+ return_dict: Optional[bool] = None,
435
+ ) -> Union[Tuple, CausalLMOutput]:
436
+
437
+ output_hidden_states = (
438
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
439
+ )
440
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
441
+
442
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
443
+ outputs = self.hyena(
444
+ input_ids=input_ids,
445
+ inputs_embeds=inputs_embeds,
446
+ output_hidden_states=output_hidden_states,
447
+ return_dict=return_dict,
448
+ )
449
+
450
+ hidden_states = outputs[0]
451
+ logits = self.lm_head(hidden_states)
452
+ logits = logits.float()
453
+
454
+ loss = None
455
+ if labels is not None:
456
+ # Shift so that tokens < n predict n
457
+ shift_logits = logits[..., :-1, :].contiguous()
458
+ shift_labels = labels[..., 1:].contiguous()
459
+ # Flatten the tokens
460
+ loss_fct = nn.CrossEntropyLoss()
461
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
462
+ shift_labels = shift_labels.view(-1)
463
+ # Enable model parallelism
464
+ shift_labels = shift_labels.to(shift_logits.device)
465
+ loss = loss_fct(shift_logits, shift_labels)
466
+
467
+ if not return_dict:
468
+ output = (logits,) + outputs[1:]
469
+ return (loss,) + output if loss is not None else output
470
+
471
+ return CausalLMOutput(
472
+ loss=loss,
473
+ logits=logits,
474
+ hidden_states=outputs.hidden_states,
475
+ )
476
+
477
+
478
+ class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
479
+ def __init__(self, config):
480
+ super().__init__(config)
481
+ self.num_labels = config.num_labels
482
+ self.hyena = HyenaDNAModel(config)
483
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
484
+
485
+ # Initialize weights and apply final processing
486
+ self.post_init()
487
+
488
+ def get_input_embeddings(self):
489
+ return self.hyena.backbone.embeddings.word_embeddings
490
+
491
+ def set_input_embeddings(self, value):
492
+ self.hyena.backbone.embeddings.word_embeddings = value
493
+
494
+ def forward(
495
+ self,
496
+ input_ids: torch.LongTensor = None,
497
+ inputs_embeds: Optional[torch.FloatTensor] = None,
498
+ labels: Optional[torch.LongTensor] = None,
499
+ output_hidden_states: Optional[bool] = None,
500
+ return_dict: Optional[bool] = None,
501
+ ) -> Union[Tuple, SequenceClassifierOutput]:
502
+ r"""
503
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
504
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
505
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
506
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
507
+ """
508
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
509
+
510
+ transformer_outputs = self.hyena(
511
+ input_ids,
512
+ inputs_embeds=inputs_embeds,
513
+ output_hidden_states=output_hidden_states,
514
+ return_dict=return_dict,
515
+ )
516
+ hidden_states = transformer_outputs[0]
517
+ logits = self.score(hidden_states)
518
+
519
+ if input_ids is not None:
520
+ batch_size = input_ids.shape[0]
521
+ else:
522
+ batch_size = inputs_embeds.shape[0]
523
+
524
+ if self.config.pad_token_id is None and batch_size != 1:
525
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
526
+ if self.config.pad_token_id is None:
527
+ sequence_lengths = -1
528
+ else:
529
+ if input_ids is not None:
530
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
531
+ logits.device
532
+ )
533
+ else:
534
+ sequence_lengths = -1
535
+
536
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
537
+
538
+ loss = None
539
+ if labels is not None:
540
+ labels = labels.to(logits.device)
541
+ if self.config.problem_type is None:
542
+ if self.num_labels == 1:
543
+ self.config.problem_type = "regression"
544
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
545
+ self.config.problem_type = "single_label_classification"
546
+ else:
547
+ self.config.problem_type = "multi_label_classification"
548
+
549
+ if self.config.problem_type == "regression":
550
+ loss_fct = nn.MSELoss()
551
+ if self.num_labels == 1:
552
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
553
+ else:
554
+ loss = loss_fct(pooled_logits, labels)
555
+ elif self.config.problem_type == "single_label_classification":
556
+ loss_fct = nn.CrossEntropyLoss()
557
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
558
+ elif self.config.problem_type == "multi_label_classification":
559
+ loss_fct = nn.BCEWithLogitsLoss()
560
+ loss = loss_fct(pooled_logits, labels)
561
+ if not return_dict:
562
+ output = (pooled_logits,) + transformer_outputs[1:]
563
+ return ((loss,) + output) if loss is not None else output
564
+
565
+ return SequenceClassifierOutput(
566
+ loss=loss,
567
+ logits=pooled_logits,
568
+ hidden_states=transformer_outputs.hidden_states,
569
+ )