hyunwoo3235 commited on
Commit
b04a333
·
1 Parent(s): 5a57f8e

load flax hubert model via remote code

Browse files
Files changed (2) hide show
  1. config.json +3 -0
  2. modeling_flax_hubert.py +966 -0
config.json CHANGED
@@ -4,6 +4,9 @@
4
  "architectures": [
5
  "HubertModel"
6
  ],
 
 
 
7
  "attention_dropout": 0.1,
8
  "bos_token_id": 1,
9
  "conv_bias": false,
 
4
  "architectures": [
5
  "HubertModel"
6
  ],
7
+ "auto_map": {
8
+ "FlaxAutoModel": "modeling_flax_hubert.FlaxHubertModel"
9
+ },
10
  "attention_dropout": 0.1,
11
  "bos_token_id": 1,
12
  "conv_bias": false,
modeling_flax_hubert.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Hubert model."""
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
25
+ from flax.linen.attention import dot_product_attention_weights
26
+ from flax.traverse_util import flatten_dict, unflatten_dict
27
+ from jax import lax
28
+ from transformers import HubertConfig
29
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
30
+ from transformers.modeling_flax_utils import (
31
+ ACT2FN,
32
+ FlaxPreTrainedModel,
33
+ )
34
+ from transformers.utils import ModelOutput, logging
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ @flax.struct.dataclass
40
+ class FlaxHubertOutput(ModelOutput):
41
+ last_hidden_state: jnp.ndarray = None
42
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
43
+ attentions: Optional[Tuple[jnp.ndarray]] = None
44
+ extract_features: jnp.ndarray = None
45
+
46
+
47
+ class FlaxConvWithWeightNorm(nn.Module):
48
+ config: HubertConfig
49
+ dtype: jnp.dtype = jnp.float32
50
+
51
+ def setup(self):
52
+ self.conv = nn.Conv(
53
+ features=self.config.hidden_size,
54
+ kernel_size=(self.config.num_conv_pos_embeddings,),
55
+ kernel_init=jax.nn.initializers.he_normal(),
56
+ padding="VALID",
57
+ feature_group_count=self.config.num_conv_pos_embedding_groups,
58
+ dtype=self.dtype,
59
+ )
60
+ weight_shape = (
61
+ self.conv.features,
62
+ self.conv.features // self.conv.feature_group_count,
63
+ self.conv.kernel_size[0],
64
+ )
65
+ self.weight_v = self.param(
66
+ "weight_v", jax.nn.initializers.he_normal(), weight_shape
67
+ )
68
+ self.weight_g = self.param(
69
+ "weight_g",
70
+ lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :],
71
+ )
72
+ self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
73
+ self.prev_padding = self.conv.kernel_size[0] // 2
74
+
75
+ def _get_normed_weights(self):
76
+ weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
77
+ normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
78
+ normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
79
+ return normed_kernel
80
+
81
+ def __call__(self, hidden_states):
82
+ kernel = self._get_normed_weights()
83
+ hidden_states = jnp.pad(
84
+ hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0))
85
+ )
86
+ hidden_states = self.conv.apply(
87
+ {"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states
88
+ )
89
+ return hidden_states
90
+
91
+
92
+ class FlaxHubertNoLayerNormConvLayer(nn.Module):
93
+ config: HubertConfig
94
+ layer_id: int = 0
95
+ dtype: jnp.dtype = jnp.float32
96
+
97
+ def setup(self):
98
+ self.in_conv_dim = (
99
+ self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
100
+ )
101
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
102
+
103
+ self.conv = nn.Conv(
104
+ features=self.config.conv_dim[self.layer_id],
105
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
106
+ strides=(self.config.conv_stride[self.layer_id],),
107
+ use_bias=self.config.conv_bias,
108
+ kernel_init=jax.nn.initializers.he_normal(),
109
+ padding="VALID",
110
+ dtype=self.dtype,
111
+ )
112
+ self.activation = ACT2FN[self.config.feat_extract_activation]
113
+
114
+ def __call__(self, hidden_states):
115
+ hidden_states = self.conv(hidden_states)
116
+ hidden_states = self.activation(hidden_states)
117
+ return hidden_states
118
+
119
+
120
+ class FlaxHubertLayerNormConvLayer(nn.Module):
121
+ config: HubertConfig
122
+ layer_id: int = 0
123
+ dtype: jnp.dtype = jnp.float32
124
+
125
+ def setup(self):
126
+ self.in_conv_dim = (
127
+ self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
128
+ )
129
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
130
+
131
+ self.conv = nn.Conv(
132
+ features=self.config.conv_dim[self.layer_id],
133
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
134
+ strides=(self.config.conv_stride[self.layer_id],),
135
+ use_bias=self.config.conv_bias,
136
+ kernel_init=jax.nn.initializers.he_normal(),
137
+ padding="VALID",
138
+ dtype=self.dtype,
139
+ )
140
+ self.layer_norm = nn.LayerNorm(
141
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
142
+ )
143
+ self.activation = ACT2FN[self.config.feat_extract_activation]
144
+
145
+ def __call__(self, hidden_states):
146
+ hidden_states = self.conv(hidden_states)
147
+ hidden_states = self.layer_norm(hidden_states)
148
+ hidden_states = self.activation(hidden_states)
149
+ return hidden_states
150
+
151
+
152
+ class FlaxHubertGroupNormConvLayer(nn.Module):
153
+ config: HubertConfig
154
+ layer_id: int = 0
155
+ dtype: jnp.dtype = jnp.float32
156
+
157
+ def setup(self):
158
+ self.in_conv_dim = (
159
+ self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
160
+ )
161
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
162
+
163
+ self.conv = nn.Conv(
164
+ features=self.config.conv_dim[self.layer_id],
165
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
166
+ strides=(self.config.conv_stride[self.layer_id],),
167
+ use_bias=self.config.conv_bias,
168
+ kernel_init=jax.nn.initializers.he_normal(),
169
+ padding="VALID",
170
+ dtype=self.dtype,
171
+ )
172
+ self.activation = ACT2FN[self.config.feat_extract_activation]
173
+
174
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, dtype=self.dtype)
175
+
176
+ def __call__(self, hidden_states):
177
+ hidden_states = self.conv(hidden_states)
178
+ hidden_states = self.layer_norm(hidden_states)
179
+ hidden_states = self.activation(hidden_states)
180
+ return hidden_states
181
+
182
+
183
+ class FlaxHubertPositionalConvEmbedding(nn.Module):
184
+ config: HubertConfig
185
+ dtype: jnp.dtype = jnp.float32
186
+
187
+ def setup(self):
188
+ self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
189
+ self.activation = ACT2FN[self.config.feat_extract_activation]
190
+ self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
191
+
192
+ def __call__(self, hidden_states):
193
+ hidden_states = hidden_states.transpose((0, 1, 2))
194
+
195
+ hidden_states = self.conv(hidden_states)
196
+
197
+ if self.num_pad_remove > 0:
198
+ hidden_states = hidden_states[:, : -self.num_pad_remove, :]
199
+ hidden_states = self.activation(hidden_states)
200
+
201
+ hidden_states = hidden_states.transpose((0, 1, 2))
202
+ return hidden_states
203
+
204
+
205
+ class FlaxConvLayersCollection(nn.Module):
206
+ config: HubertConfig
207
+ dtype: jnp.dtype = jnp.float32
208
+
209
+ def setup(self):
210
+ if self.config.feat_extract_norm == "layer":
211
+ self.layers = [
212
+ FlaxHubertLayerNormConvLayer(
213
+ self.config, layer_id=i, name=str(i), dtype=self.dtype
214
+ )
215
+ for i in range(self.config.num_feat_extract_layers)
216
+ ]
217
+ elif self.config.feat_extract_norm == "group":
218
+ self.layers = [
219
+ FlaxHubertGroupNormConvLayer(
220
+ self.config, layer_id=0, name=str(0), dtype=self.dtype
221
+ )
222
+ ] + [
223
+ FlaxHubertNoLayerNormConvLayer(
224
+ self.config, layer_id=i, name=str(i), dtype=self.dtype
225
+ )
226
+ for i in range(1, self.config.num_feat_extract_layers)
227
+ ]
228
+ else:
229
+ raise ValueError(
230
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
231
+ " 'layer']"
232
+ )
233
+
234
+ def __call__(self, hidden_states):
235
+ for i, conv_layer in enumerate(self.layers):
236
+ hidden_states = conv_layer(hidden_states)
237
+ return hidden_states
238
+
239
+
240
+ class FlaxHubertFeatureEncoder(nn.Module):
241
+ config: HubertConfig
242
+ dtype: jnp.dtype = jnp.float32
243
+
244
+ def setup(self):
245
+ self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
246
+
247
+ def __call__(self, input_values, freeze_feature_encoder=False):
248
+ hidden_states = input_values[:, :, None]
249
+ hidden_states = self.conv_layers(hidden_states)
250
+ if freeze_feature_encoder:
251
+ hidden_states = jax.lax.stop_gradient(hidden_states)
252
+ return hidden_states
253
+
254
+
255
+ class FlaxHubertFeatureProjection(nn.Module):
256
+ config: HubertConfig
257
+ dtype: jnp.dtype = jnp.float32
258
+
259
+ def setup(self):
260
+ self.feat_proj_layer_norm = self.config.feat_proj_layer_norm
261
+ if self.feat_proj_layer_norm:
262
+ self.layer_norm = nn.LayerNorm(
263
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
264
+ )
265
+ self.projection = nn.Dense(
266
+ self.config.hidden_size,
267
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
268
+ dtype=self.dtype,
269
+ )
270
+ self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
271
+
272
+ def __call__(self, hidden_states, deterministic=True):
273
+ if self.feat_proj_layer_norm:
274
+ hidden_states = self.layer_norm(hidden_states)
275
+ hidden_states = self.projection(hidden_states)
276
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
277
+ return hidden_states
278
+
279
+
280
+ class FlaxHubertAttention(nn.Module):
281
+ config: HubertConfig
282
+ embed_dim: int
283
+ num_heads: int
284
+ dropout: float = 0.0
285
+ bias: bool = True
286
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
287
+
288
+ def setup(self) -> None:
289
+ self.head_dim = self.embed_dim // self.num_heads
290
+ if self.head_dim * self.num_heads != self.embed_dim:
291
+ raise ValueError(
292
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
293
+ f" and `num_heads`: {self.num_heads})."
294
+ )
295
+ self.scaling = self.head_dim**-0.5
296
+
297
+ dense = partial(
298
+ nn.Dense,
299
+ self.embed_dim,
300
+ use_bias=self.bias,
301
+ dtype=self.dtype,
302
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
303
+ )
304
+
305
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
306
+ self.out_proj = dense()
307
+
308
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
309
+
310
+ def _split_heads(self, hidden_states):
311
+ return hidden_states.reshape(
312
+ hidden_states.shape[:2] + (self.num_heads, self.head_dim)
313
+ )
314
+
315
+ def _merge_heads(self, hidden_states):
316
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
317
+
318
+ def __call__(
319
+ self,
320
+ hidden_states: jnp.ndarray,
321
+ attention_mask: Optional[jnp.ndarray] = None,
322
+ output_attentions: bool = False,
323
+ deterministic: bool = True,
324
+ ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
325
+ """Input shape: Batch x Time x Channel"""
326
+
327
+ # get query, key, value proj for self_attention
328
+ query_states = self.q_proj(hidden_states)
329
+ key_states = self.k_proj(hidden_states)
330
+ value_states = self.v_proj(hidden_states)
331
+
332
+ query_states = self._split_heads(query_states)
333
+ key_states = self._split_heads(key_states)
334
+ value_states = self._split_heads(value_states)
335
+
336
+ if attention_mask is not None:
337
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
338
+ attention_bias = lax.select(
339
+ attention_mask > 0,
340
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
341
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(
342
+ self.dtype
343
+ ),
344
+ )
345
+ else:
346
+ attention_bias = None
347
+
348
+ dropout_rng = None
349
+ if not deterministic and self.dropout > 0.0:
350
+ dropout_rng = self.make_rng("dropout")
351
+
352
+ attn_weights = dot_product_attention_weights(
353
+ query_states,
354
+ key_states,
355
+ bias=attention_bias,
356
+ dropout_rng=dropout_rng,
357
+ dropout_rate=self.dropout,
358
+ broadcast_dropout=True,
359
+ deterministic=deterministic,
360
+ dtype=self.dtype,
361
+ precision=None,
362
+ )
363
+
364
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
365
+ attn_output = self._merge_heads(attn_output)
366
+ attn_output = self.out_proj(attn_output)
367
+
368
+ return attn_output, attn_weights
369
+
370
+
371
+ class FlaxHubertFeedForward(nn.Module):
372
+ config: HubertConfig
373
+ dtype: jnp.dtype = jnp.float32
374
+
375
+ def setup(self):
376
+ self.intermediate_dropout = nn.Dropout(self.config.activation_dropout)
377
+
378
+ self.intermediate_dense = nn.Dense(
379
+ self.config.intermediate_size, dtype=self.dtype
380
+ )
381
+ if isinstance(self.config.hidden_act, str):
382
+ self.intermediate_activation = ACT2FN[self.config.hidden_act]
383
+ else:
384
+ self.intermediate_activation = self.config.hidden_act
385
+
386
+ self.output_dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
387
+ self.output_dropout = nn.Dropout(self.config.activation_dropout)
388
+
389
+ def __call__(self, hidden_states, deterministic=True):
390
+ hidden_states = self.intermediate_dense(hidden_states)
391
+ hidden_states = self.intermediate_activation(hidden_states)
392
+ hidden_states = self.intermediate_dropout(
393
+ hidden_states, deterministic=deterministic
394
+ )
395
+
396
+ hidden_states = self.output_dense(hidden_states)
397
+ hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
398
+
399
+ return hidden_states
400
+
401
+
402
+ class FlaxHubertEncoderLayer(nn.Module):
403
+ config: HubertConfig
404
+ dtype: jnp.dtype = jnp.float32
405
+
406
+ def setup(self):
407
+ self.attention = FlaxHubertAttention(
408
+ config=self.config,
409
+ embed_dim=self.config.hidden_size,
410
+ num_heads=self.config.num_attention_heads,
411
+ dropout=self.config.attention_dropout,
412
+ dtype=self.dtype,
413
+ )
414
+ self.dropout = nn.Dropout(self.config.hidden_dropout)
415
+ self.layer_norm = nn.LayerNorm(
416
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
417
+ )
418
+ self.feed_forward = FlaxHubertFeedForward(self.config, dtype=self.dtype)
419
+ self.final_layer_norm = nn.LayerNorm(
420
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
421
+ )
422
+
423
+ def __call__(
424
+ self,
425
+ hidden_states,
426
+ attention_mask: Optional[jnp.ndarray] = None,
427
+ output_attentions: bool = False,
428
+ deterministic=True,
429
+ ):
430
+ attn_residual = hidden_states
431
+ hidden_states, attn_weights = self.attention(
432
+ hidden_states=hidden_states,
433
+ attention_mask=attention_mask,
434
+ output_attentions=output_attentions,
435
+ deterministic=deterministic,
436
+ )
437
+
438
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
439
+ hidden_states = attn_residual + hidden_states
440
+
441
+ hidden_states = self.layer_norm(hidden_states)
442
+ hidden_states = hidden_states + self.feed_forward(
443
+ hidden_states, deterministic=deterministic
444
+ )
445
+ hidden_states = self.final_layer_norm(hidden_states)
446
+
447
+ outputs = (hidden_states,)
448
+
449
+ if output_attentions:
450
+ outputs += (attn_weights,)
451
+
452
+ return outputs
453
+
454
+
455
+ class FlaxHubertEncoderLayerStableLayerNorm(nn.Module):
456
+ config: HubertConfig
457
+ dtype: jnp.dtype = jnp.float32
458
+
459
+ def setup(self):
460
+ self.attention = FlaxHubertAttention(
461
+ config=self.config,
462
+ embed_dim=self.config.hidden_size,
463
+ num_heads=self.config.num_attention_heads,
464
+ dropout=self.config.attention_dropout,
465
+ dtype=self.dtype,
466
+ )
467
+ self.dropout = nn.Dropout(self.config.hidden_dropout)
468
+ self.layer_norm = nn.LayerNorm(
469
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
470
+ )
471
+ self.feed_forward = FlaxHubertFeedForward(self.config, dtype=self.dtype)
472
+ self.final_layer_norm = nn.LayerNorm(
473
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
474
+ )
475
+
476
+ def __call__(
477
+ self,
478
+ hidden_states,
479
+ attention_mask: Optional[jnp.ndarray] = None,
480
+ output_attentions: bool = False,
481
+ deterministic=True,
482
+ ):
483
+ attn_residual = hidden_states
484
+ hidden_states = self.layer_norm(hidden_states)
485
+ hidden_states, attn_weights = self.attention(
486
+ hidden_states=hidden_states,
487
+ attention_mask=attention_mask,
488
+ output_attentions=output_attentions,
489
+ deterministic=deterministic,
490
+ )
491
+
492
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
493
+ hidden_states = attn_residual + hidden_states
494
+
495
+ hidden_states = hidden_states + self.feed_forward(
496
+ self.final_layer_norm(hidden_states), deterministic=deterministic
497
+ )
498
+
499
+ outputs = (hidden_states,)
500
+
501
+ if output_attentions:
502
+ outputs += (attn_weights,)
503
+
504
+ return outputs
505
+
506
+
507
+ class FlaxHubertLayerCollection(nn.Module):
508
+ config: HubertConfig
509
+ dtype: jnp.dtype = jnp.float32
510
+
511
+ def setup(self):
512
+ self.layers = [
513
+ FlaxHubertEncoderLayer(self.config, name=str(i), dtype=self.dtype)
514
+ for i in range(self.config.num_hidden_layers)
515
+ ]
516
+
517
+ def __call__(
518
+ self,
519
+ hidden_states,
520
+ attention_mask=None,
521
+ deterministic: bool = True,
522
+ output_attentions: bool = False,
523
+ output_hidden_states: bool = False,
524
+ return_dict: bool = True,
525
+ ):
526
+ all_attentions = () if output_attentions else None
527
+ all_hidden_states = () if output_hidden_states else None
528
+
529
+ for i, layer in enumerate(self.layers):
530
+ if output_hidden_states:
531
+ all_hidden_states += (hidden_states,)
532
+
533
+ layer_outputs = layer(
534
+ hidden_states,
535
+ attention_mask,
536
+ deterministic=deterministic,
537
+ output_attentions=output_attentions,
538
+ )
539
+
540
+ hidden_states = layer_outputs[0]
541
+
542
+ if output_attentions:
543
+ all_attentions += (layer_outputs[1],)
544
+
545
+ if output_hidden_states:
546
+ all_hidden_states += (hidden_states,)
547
+
548
+ outputs = (hidden_states, all_hidden_states, all_attentions)
549
+
550
+ if not return_dict:
551
+ return tuple(v for v in outputs if v is not None)
552
+
553
+ return FlaxBaseModelOutput(
554
+ last_hidden_state=hidden_states,
555
+ hidden_states=all_hidden_states,
556
+ attentions=all_attentions,
557
+ )
558
+
559
+
560
+ class FlaxHubertEncoder(nn.Module):
561
+ config: HubertConfig
562
+ dtype: jnp.dtype = jnp.float32
563
+
564
+ def setup(self):
565
+ self.pos_conv_embed = FlaxHubertPositionalConvEmbedding(
566
+ self.config, dtype=self.dtype
567
+ )
568
+ self.layer_norm = nn.LayerNorm(
569
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
570
+ )
571
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
572
+ self.layers = FlaxHubertLayerCollection(self.config, dtype=self.dtype)
573
+
574
+ def __call__(
575
+ self,
576
+ hidden_states,
577
+ attention_mask: Optional[jnp.ndarray] = None,
578
+ output_attentions: bool = False,
579
+ output_hidden_states: bool = False,
580
+ return_dict: bool = True,
581
+ deterministic: bool = True,
582
+ ):
583
+ if attention_mask is not None:
584
+ # make sure padded tokens are not attended to
585
+ hidden_states = jnp.where(
586
+ jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape),
587
+ hidden_states,
588
+ 0,
589
+ )
590
+
591
+ position_embeddings = self.pos_conv_embed(hidden_states)
592
+
593
+ hidden_states = hidden_states + position_embeddings
594
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
595
+
596
+ outputs = self.layers(
597
+ hidden_states,
598
+ attention_mask,
599
+ deterministic=deterministic,
600
+ output_attentions=output_attentions,
601
+ output_hidden_states=output_hidden_states,
602
+ return_dict=return_dict,
603
+ )
604
+
605
+ last_hidden_state = self.layer_norm(outputs[0])
606
+
607
+ hidden_states = None
608
+ if output_hidden_states:
609
+ hidden_states = outputs[1]
610
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
611
+
612
+ if not return_dict:
613
+ outputs = (last_hidden_state, hidden_states) + (
614
+ outputs[2:] if output_hidden_states else outputs[1:]
615
+ )
616
+ return tuple(v for v in outputs if v is not None)
617
+
618
+ return FlaxBaseModelOutput(
619
+ last_hidden_state=last_hidden_state,
620
+ hidden_states=hidden_states,
621
+ attentions=outputs.attentions,
622
+ )
623
+
624
+
625
+ class FlaxHubertLayerStableLayerNormCollection(nn.Module):
626
+ config: HubertConfig
627
+ dtype: jnp.dtype = jnp.float32
628
+
629
+ def setup(self):
630
+ self.layers = [
631
+ FlaxHubertEncoderLayerStableLayerNorm(
632
+ self.config, name=str(i), dtype=self.dtype
633
+ )
634
+ for i in range(self.config.num_hidden_layers)
635
+ ]
636
+
637
+ def __call__(
638
+ self,
639
+ hidden_states,
640
+ attention_mask=None,
641
+ deterministic: bool = True,
642
+ output_attentions: bool = False,
643
+ output_hidden_states: bool = False,
644
+ return_dict: bool = True,
645
+ ):
646
+ all_attentions = () if output_attentions else None
647
+ all_hidden_states = () if output_hidden_states else None
648
+
649
+ for i, layer in enumerate(self.layers):
650
+ if output_hidden_states:
651
+ all_hidden_states += (hidden_states,)
652
+
653
+ layer_outputs = layer(
654
+ hidden_states,
655
+ attention_mask,
656
+ deterministic=deterministic,
657
+ output_attentions=output_attentions,
658
+ )
659
+
660
+ hidden_states = layer_outputs[0]
661
+
662
+ if output_attentions:
663
+ all_attentions += (layer_outputs[1],)
664
+
665
+ if output_hidden_states:
666
+ all_hidden_states += (hidden_states,)
667
+
668
+ outputs = (hidden_states, all_hidden_states, all_attentions)
669
+
670
+ if not return_dict:
671
+ return tuple(v for v in outputs if v is not None)
672
+
673
+ return FlaxBaseModelOutput(
674
+ last_hidden_state=hidden_states,
675
+ hidden_states=all_hidden_states,
676
+ attentions=all_attentions,
677
+ )
678
+
679
+
680
+ class FlaxHubertEncoderStableLayerNorm(nn.Module):
681
+ config: HubertConfig
682
+ dtype: jnp.dtype = jnp.float32
683
+
684
+ def setup(self):
685
+ self.pos_conv_embed = FlaxHubertPositionalConvEmbedding(
686
+ self.config, dtype=self.dtype
687
+ )
688
+ self.layer_norm = nn.LayerNorm(
689
+ epsilon=self.config.layer_norm_eps, dtype=self.dtype
690
+ )
691
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
692
+ self.layers = FlaxHubertLayerStableLayerNormCollection(
693
+ self.config, dtype=self.dtype
694
+ )
695
+
696
+ def __call__(
697
+ self,
698
+ hidden_states,
699
+ attention_mask: Optional[jnp.ndarray] = None,
700
+ output_attentions: bool = False,
701
+ output_hidden_states: bool = False,
702
+ return_dict: bool = True,
703
+ deterministic: bool = True,
704
+ ):
705
+ if attention_mask is not None:
706
+ hidden_states = jnp.where(
707
+ jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape),
708
+ hidden_states,
709
+ 0,
710
+ )
711
+
712
+ position_embeddings = self.pos_conv_embed(hidden_states)
713
+
714
+ hidden_states = hidden_states + position_embeddings
715
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
716
+
717
+ outputs = self.layers(
718
+ hidden_states,
719
+ attention_mask,
720
+ deterministic=deterministic,
721
+ output_attentions=output_attentions,
722
+ output_hidden_states=output_hidden_states,
723
+ return_dict=return_dict,
724
+ )
725
+
726
+ last_hidden_state = self.layer_norm(outputs[0])
727
+
728
+ hidden_states = None
729
+ if output_hidden_states:
730
+ hidden_states = outputs[1]
731
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
732
+
733
+ if not return_dict:
734
+ outputs = (last_hidden_state, hidden_states) + (
735
+ outputs[2:] if output_hidden_states else outputs[1:]
736
+ )
737
+ return tuple(v for v in outputs if v is not None)
738
+
739
+ return FlaxBaseModelOutput(
740
+ last_hidden_state=last_hidden_state,
741
+ hidden_states=hidden_states,
742
+ attentions=outputs.attentions,
743
+ )
744
+
745
+
746
+ class FlaxHubertPreTrainedModel(FlaxPreTrainedModel):
747
+ config_class = HubertConfig
748
+ base_model_prefix = "hubert"
749
+ main_input_name = "input_values"
750
+ module_class: nn.Module = None
751
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
752
+
753
+ def __init__(
754
+ self,
755
+ config: HubertConfig,
756
+ input_shape: Tuple = (1, 1024),
757
+ seed: int = 0,
758
+ dtype: jnp.dtype = jnp.float32,
759
+ _do_init: bool = True,
760
+ **kwargs,
761
+ ):
762
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
763
+ super().__init__(
764
+ config,
765
+ module,
766
+ input_shape=input_shape,
767
+ seed=seed,
768
+ dtype=dtype,
769
+ _do_init=_do_init,
770
+ )
771
+
772
+ def init_weights(
773
+ self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
774
+ ) -> FrozenDict:
775
+ input_values = jnp.zeros(input_shape, dtype="i4")
776
+ attention_mask = jnp.ones_like(input_values)
777
+ params_rng, dropout_rng = jax.random.split(rng, 2)
778
+ rngs = {"params": params_rng, "dropout": dropout_rng}
779
+
780
+ random_params = self.module.init(
781
+ rngs, input_values, attention_mask, return_dict=False
782
+ )["params"]
783
+
784
+ if params is not None:
785
+ random_params = flatten_dict(unfreeze(random_params))
786
+ params = flatten_dict(unfreeze(params))
787
+ for missing_key in self._missing_keys:
788
+ params[missing_key] = random_params[missing_key]
789
+ self._missing_keys = set()
790
+ return freeze(unflatten_dict(params))
791
+ else:
792
+ return random_params
793
+
794
+ def __call__(
795
+ self,
796
+ input_values,
797
+ attention_mask=None,
798
+ mask_time_indices=None,
799
+ params: dict = None,
800
+ dropout_rng: jax.random.PRNGKey = None,
801
+ train: bool = False,
802
+ output_attentions: Optional[bool] = None,
803
+ output_hidden_states: Optional[bool] = None,
804
+ freeze_feature_encoder: bool = False,
805
+ return_dict: Optional[bool] = None,
806
+ ):
807
+ output_attentions = (
808
+ output_attentions
809
+ if output_attentions is not None
810
+ else self.config.output_attentions
811
+ )
812
+ output_hidden_states = (
813
+ output_hidden_states
814
+ if output_hidden_states is not None
815
+ else self.config.output_hidden_states
816
+ )
817
+ return_dict = (
818
+ return_dict if return_dict is not None else self.config.return_dict
819
+ )
820
+
821
+ batch_size, sequence_length = input_values.shape
822
+
823
+ if attention_mask is None:
824
+ attention_mask = jnp.ones((batch_size, sequence_length))
825
+
826
+ rngs = {}
827
+ if dropout_rng is not None:
828
+ rngs["dropout"] = dropout_rng
829
+
830
+ inputs = {"params": params or self.params}
831
+
832
+ return self.module.apply(
833
+ inputs,
834
+ jnp.array(input_values, dtype="f4"),
835
+ jnp.array(attention_mask, dtype="i4"),
836
+ mask_time_indices,
837
+ not train,
838
+ output_attentions,
839
+ output_hidden_states,
840
+ freeze_feature_encoder,
841
+ return_dict,
842
+ rngs=rngs,
843
+ )
844
+
845
+
846
+ class FlaxHubertModule(nn.Module):
847
+ config: HubertConfig
848
+ dtype: jnp.dtype = jnp.float32
849
+
850
+ def setup(self):
851
+ self.feature_extractor = FlaxHubertFeatureEncoder(self.config, dtype=self.dtype)
852
+ self.feature_projection = FlaxHubertFeatureProjection(
853
+ self.config, dtype=self.dtype
854
+ )
855
+
856
+ if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
857
+ self.masked_spec_embed = self.param(
858
+ "masked_spec_embed",
859
+ nn.initializers.uniform(dtype=self.dtype),
860
+ (self.config.hidden_size,),
861
+ )
862
+
863
+ if self.config.do_stable_layer_norm:
864
+ self.encoder = FlaxHubertEncoderStableLayerNorm(self.config)
865
+ else:
866
+ self.encoder = FlaxHubertEncoder(self.config)
867
+
868
+ def __call__(
869
+ self,
870
+ input_values: Optional[jnp.ndarray],
871
+ attention_mask: Optional[jnp.ndarray] = None,
872
+ mask_time_indices: Optional[jnp.ndarray] = None,
873
+ deterministic: bool = True,
874
+ output_attentions: Optional[bool] = None,
875
+ output_hidden_states: Optional[bool] = None,
876
+ freeze_feature_encoder: bool = False,
877
+ return_dict: Optional[bool] = None,
878
+ ) -> Union[Tuple, FlaxHubertOutput]:
879
+ output_attentions = (
880
+ output_attentions
881
+ if output_attentions is not None
882
+ else self.config.output_attentions
883
+ )
884
+ output_hidden_states = (
885
+ output_hidden_states
886
+ if output_hidden_states is not None
887
+ else self.config.output_hidden_states
888
+ )
889
+ return_dict = (
890
+ return_dict if return_dict is not None else self.config.use_return_dict
891
+ )
892
+
893
+ extract_features = self.feature_extractor(input_values, freeze_feature_encoder)
894
+
895
+ if attention_mask is not None:
896
+ attention_mask = self._get_feature_vector_attention_mask(
897
+ extract_features.shape[1], attention_mask
898
+ )
899
+
900
+ hidden_states = self.feature_projection(
901
+ extract_features, deterministic=deterministic
902
+ )
903
+ if mask_time_indices is not None:
904
+ hidden_states = jnp.where(
905
+ jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
906
+ jnp.broadcast_to(
907
+ self.masked_spec_embed[None, None, :], hidden_states.shape
908
+ ),
909
+ hidden_states,
910
+ )
911
+
912
+ encoder_outputs = self.encoder(
913
+ hidden_states,
914
+ attention_mask=attention_mask,
915
+ deterministic=deterministic,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ return_dict=return_dict,
919
+ )
920
+
921
+ hidden_states = encoder_outputs[0]
922
+
923
+ if not return_dict:
924
+ return (hidden_states,) + encoder_outputs[1:]
925
+
926
+ return FlaxHubertOutput(
927
+ last_hidden_state=hidden_states,
928
+ hidden_states=encoder_outputs.hidden_states,
929
+ attentions=encoder_outputs.attentions,
930
+ extract_features=extract_features,
931
+ )
932
+
933
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
934
+ def _conv_out_length(input_length, kernel_size, stride):
935
+ return (input_length - kernel_size) // stride + 1
936
+
937
+ for kernel_size, stride in zip(
938
+ self.config.conv_kernel, self.config.conv_stride
939
+ ):
940
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
941
+
942
+ return input_lengths
943
+
944
+ def _get_feature_vector_attention_mask(
945
+ self, feature_vector_length: int, attention_mask: jnp.ndarray
946
+ ):
947
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
948
+
949
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths)
950
+
951
+ batch_size = attention_mask.shape[0]
952
+
953
+ attention_mask = jnp.zeros(
954
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype
955
+ )
956
+ attention_mask = attention_mask.at[
957
+ jnp.arange(attention_mask.shape[0]), output_lengths - 1
958
+ ].set(1)
959
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype(
960
+ "bool"
961
+ )
962
+ return attention_mask
963
+
964
+
965
+ class FlaxHubertModel(FlaxHubertPreTrainedModel):
966
+ module_class = FlaxHubertModule