Aman K
commited on
Commit
·
a1c1315
1
Parent(s):
5ae3a9f
prepared alignment model to be loaded using flaxautomodel
Browse files- config.json +4 -1
- flax_model.msgpack +3 -0
- flax_modeling_alignment.py +181 -0
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "amankhandelia/
|
3 |
"activation_dropout": 0.1,
|
4 |
"adapter_attn_dim": null,
|
5 |
"adapter_kernel_size": 3,
|
@@ -9,6 +9,9 @@
|
|
9 |
"architectures": [
|
10 |
"Wav2Vec2ForAudioFrameClassification"
|
11 |
],
|
|
|
|
|
|
|
12 |
"attention_dropout": 0.0,
|
13 |
"bos_token_id": 1,
|
14 |
"classifier_proj_size": 256,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "amankhandelia/flax_mms_alignment_model",
|
3 |
"activation_dropout": 0.1,
|
4 |
"adapter_attn_dim": null,
|
5 |
"adapter_kernel_size": 3,
|
|
|
9 |
"architectures": [
|
10 |
"Wav2Vec2ForAudioFrameClassification"
|
11 |
],
|
12 |
+
"auto_map": {
|
13 |
+
"FlaxAutoModel": "flax_modeling_alignment.FlaxWav2Vec2ForAudioFrameClassification"
|
14 |
+
},
|
15 |
"attention_dropout": 0.0,
|
16 |
"bos_token_id": 1,
|
17 |
"classifier_proj_size": 256,
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a9a569d76919565b0879dec40c3f545a00a03cd839820a248058dc021e862a6
|
3 |
+
size 1261893241
|
flax_modeling_alignment.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import flax.linen as nn
|
6 |
+
|
7 |
+
from transformers.modeling_flax_outputs import FlaxCausalLMOutput
|
8 |
+
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
|
9 |
+
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
|
10 |
+
FlaxWav2Vec2FeatureEncoder,
|
11 |
+
FlaxWav2Vec2FeatureProjection,
|
12 |
+
FlaxWav2Vec2StableLayerNormEncoder,
|
13 |
+
FlaxWav2Vec2Adapter,
|
14 |
+
FlaxWav2Vec2PreTrainedModel,
|
15 |
+
FlaxWav2Vec2BaseModelOutput,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class FlaxWav2Vec2Module(nn.Module):
|
20 |
+
config: Wav2Vec2Config
|
21 |
+
dtype: jnp.dtype = jnp.float32
|
22 |
+
|
23 |
+
def setup(self):
|
24 |
+
self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
|
25 |
+
self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
|
26 |
+
if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
|
27 |
+
self.masked_spec_embed = self.param(
|
28 |
+
"masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
|
29 |
+
)
|
30 |
+
|
31 |
+
if self.config.do_stable_layer_norm:
|
32 |
+
self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
|
33 |
+
else:
|
34 |
+
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
|
35 |
+
|
36 |
+
self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
|
37 |
+
|
38 |
+
def __call__(
|
39 |
+
self,
|
40 |
+
input_values,
|
41 |
+
attention_mask=None,
|
42 |
+
mask_time_indices=None,
|
43 |
+
deterministic=True,
|
44 |
+
output_attentions=None,
|
45 |
+
output_hidden_states=None,
|
46 |
+
freeze_feature_encoder=False,
|
47 |
+
return_dict=None,
|
48 |
+
):
|
49 |
+
extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
|
50 |
+
|
51 |
+
# make sure that no loss is computed on padded inputs
|
52 |
+
if attention_mask is not None:
|
53 |
+
# compute reduced attention_mask corresponding to feature vectors
|
54 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
55 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
56 |
+
)
|
57 |
+
|
58 |
+
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
59 |
+
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
|
60 |
+
hidden_states = jnp.where(
|
61 |
+
jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
|
62 |
+
jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
|
63 |
+
hidden_states,
|
64 |
+
)
|
65 |
+
|
66 |
+
encoder_outputs = self.encoder(
|
67 |
+
hidden_states,
|
68 |
+
attention_mask=attention_mask,
|
69 |
+
deterministic=deterministic,
|
70 |
+
output_attentions=output_attentions,
|
71 |
+
output_hidden_states=output_hidden_states,
|
72 |
+
return_dict=return_dict,
|
73 |
+
)
|
74 |
+
|
75 |
+
hidden_states = encoder_outputs[0]
|
76 |
+
|
77 |
+
if self.adapter is not None:
|
78 |
+
hidden_states = self.adapter(hidden_states)
|
79 |
+
|
80 |
+
if not return_dict:
|
81 |
+
return (hidden_states, extract_features) + encoder_outputs[1:]
|
82 |
+
|
83 |
+
return FlaxWav2Vec2BaseModelOutput(
|
84 |
+
last_hidden_state=hidden_states,
|
85 |
+
extract_features=extract_features,
|
86 |
+
hidden_states=encoder_outputs.hidden_states,
|
87 |
+
attentions=encoder_outputs.attentions,
|
88 |
+
)
|
89 |
+
|
90 |
+
def _get_feat_extract_output_lengths(
|
91 |
+
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
Computes the output length of the convolutional layers
|
95 |
+
"""
|
96 |
+
|
97 |
+
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
98 |
+
|
99 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
100 |
+
# 1D convolutional layer output length formula taken
|
101 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
102 |
+
return (input_length - kernel_size) // stride + 1
|
103 |
+
|
104 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
105 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
106 |
+
|
107 |
+
if add_adapter:
|
108 |
+
for _ in range(self.config.num_adapter_layers):
|
109 |
+
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
110 |
+
|
111 |
+
return input_lengths
|
112 |
+
|
113 |
+
def _get_feature_vector_attention_mask(
|
114 |
+
self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
|
115 |
+
):
|
116 |
+
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
117 |
+
# on inference mode.
|
118 |
+
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
|
119 |
+
|
120 |
+
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
121 |
+
|
122 |
+
batch_size = attention_mask.shape[0]
|
123 |
+
|
124 |
+
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
|
125 |
+
# these two operations makes sure that all values
|
126 |
+
# before the output lengths indices are attended to
|
127 |
+
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
|
128 |
+
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
129 |
+
return attention_mask
|
130 |
+
|
131 |
+
|
132 |
+
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
|
133 |
+
module_class = FlaxWav2Vec2Module
|
134 |
+
|
135 |
+
|
136 |
+
class FlaxWav2Vec2ForAudioFrameClassificationModule(nn.Module):
|
137 |
+
config: Wav2Vec2Config
|
138 |
+
dtype: jnp.dtype = jnp.float32
|
139 |
+
|
140 |
+
def setup(self):
|
141 |
+
self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
|
142 |
+
self.classifier = nn.Dense(
|
143 |
+
self.config.num_labels,
|
144 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
145 |
+
dtype=self.dtype,
|
146 |
+
)
|
147 |
+
|
148 |
+
def __call__(
|
149 |
+
self,
|
150 |
+
input_values,
|
151 |
+
attention_mask=None,
|
152 |
+
mask_time_indices=None,
|
153 |
+
deterministic=True,
|
154 |
+
output_attentions=None,
|
155 |
+
output_hidden_states=None,
|
156 |
+
freeze_feature_encoder=False,
|
157 |
+
return_dict=None,
|
158 |
+
):
|
159 |
+
outputs = self.wav2vec2(
|
160 |
+
input_values,
|
161 |
+
attention_mask=attention_mask,
|
162 |
+
mask_time_indices=mask_time_indices,
|
163 |
+
deterministic=deterministic,
|
164 |
+
output_attentions=output_attentions,
|
165 |
+
output_hidden_states=output_hidden_states,
|
166 |
+
freeze_feature_encoder=freeze_feature_encoder,
|
167 |
+
return_dict=return_dict,
|
168 |
+
)
|
169 |
+
|
170 |
+
hidden_states = outputs[0]
|
171 |
+
|
172 |
+
logits = self.classifier(hidden_states)
|
173 |
+
|
174 |
+
if not return_dict:
|
175 |
+
return (logits,) + outputs[2:]
|
176 |
+
|
177 |
+
return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
178 |
+
|
179 |
+
|
180 |
+
class FlaxWav2Vec2ForAudioFrameClassification(FlaxWav2Vec2PreTrainedModel):
|
181 |
+
module_class = FlaxWav2Vec2ForAudioFrameClassificationModule
|