Commit
•
6e60fe4
1
Parent(s):
080ad7c
Add scripts and weights
Browse files- README.md +39 -0
- config.json +109 -0
- flax_model.msgpack +3 -0
- get_ctc_tokenizer.py +102 -0
- models/__init__.py +6 -0
- models/configuration_bart.py +183 -0
- models/configuration_speech_encoder_decoder.py +121 -0
- models/configuration_wav2vec2.py +344 -0
- models/modeling_flax_bart.py +816 -0
- models/modeling_flax_speech_encoder_decoder.py +1245 -0
- models/modeling_flax_wav2vec2.py +975 -0
- preprocessor_config.json +10 -0
- run_flax_speech_recognition_ctc.py +1398 -0
- run_tedlium.sh +27 -0
- special_tokens_map.json +6 -0
- tokenizer_config.json +12 -0
- vocab.json +34 -0
README.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
tags:
|
5 |
+
- esc
|
6 |
+
datasets:
|
7 |
+
- tedlium
|
8 |
+
---
|
9 |
+
|
10 |
+
To reproduce this run, first call `get_ctc_tokenizer.py` to train the CTC tokenizer and then execute the following command to train the CTC system:
|
11 |
+
```bash
|
12 |
+
#!/usr/bin/env bash
|
13 |
+
python run_flax_speech_recognition_ctc.py \
|
14 |
+
--model_name_or_path="esc/wav2vec2-pretrained" \
|
15 |
+
--tokenizer_name="wav2vec2-ctc-tedlium-tokenizer" \
|
16 |
+
--dataset_name="esc/esc-datasets" \
|
17 |
+
--dataset_config_name="tedlium" \
|
18 |
+
--output_dir="./" \
|
19 |
+
--wandb_project="wav2vec2-ctc" \
|
20 |
+
--wandb_name="wav2vec2-ctc-tedlium" \
|
21 |
+
--max_steps="50000" \
|
22 |
+
--save_steps="10000" \
|
23 |
+
--eval_steps="10000" \
|
24 |
+
--learning_rate="3e-4" \
|
25 |
+
--logging_steps="25" \
|
26 |
+
--warmup_steps="5000" \
|
27 |
+
--preprocessing_num_workers="1" \
|
28 |
+
--hidden_dropout="0.2" \
|
29 |
+
--activation_dropout="0.2" \
|
30 |
+
--feat_proj_dropout="0.2" \
|
31 |
+
--do_train \
|
32 |
+
--do_eval \
|
33 |
+
--do_predict \
|
34 |
+
--overwrite_output_dir \
|
35 |
+
--gradient_checkpointing \
|
36 |
+
--freeze_feature_encoder \
|
37 |
+
--push_to_hub \
|
38 |
+
--use_auth_token
|
39 |
+
```
|
config.json
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.2,
|
3 |
+
"adapter_kernel_size": 3,
|
4 |
+
"adapter_stride": 2,
|
5 |
+
"add_adapter": false,
|
6 |
+
"apply_spec_augment": true,
|
7 |
+
"architectures": [
|
8 |
+
"Wav2Vec2ForCTC"
|
9 |
+
],
|
10 |
+
"attention_dropout": 0.1,
|
11 |
+
"bos_token_id": 1,
|
12 |
+
"classifier_proj_size": 256,
|
13 |
+
"codevector_dim": 768,
|
14 |
+
"contrastive_logits_temperature": 0.1,
|
15 |
+
"conv_bias": true,
|
16 |
+
"conv_dim": [
|
17 |
+
512,
|
18 |
+
512,
|
19 |
+
512,
|
20 |
+
512,
|
21 |
+
512,
|
22 |
+
512,
|
23 |
+
512
|
24 |
+
],
|
25 |
+
"conv_kernel": [
|
26 |
+
10,
|
27 |
+
3,
|
28 |
+
3,
|
29 |
+
3,
|
30 |
+
3,
|
31 |
+
2,
|
32 |
+
2
|
33 |
+
],
|
34 |
+
"conv_stride": [
|
35 |
+
5,
|
36 |
+
2,
|
37 |
+
2,
|
38 |
+
2,
|
39 |
+
2,
|
40 |
+
2,
|
41 |
+
2
|
42 |
+
],
|
43 |
+
"ctc_loss_reduction": "sum",
|
44 |
+
"ctc_zero_infinity": false,
|
45 |
+
"diversity_loss_weight": 0.1,
|
46 |
+
"do_stable_layer_norm": true,
|
47 |
+
"eos_token_id": 2,
|
48 |
+
"feat_extract_activation": "gelu",
|
49 |
+
"feat_extract_dropout": 0.0,
|
50 |
+
"feat_extract_norm": "layer",
|
51 |
+
"feat_proj_dropout": 0.2,
|
52 |
+
"feat_quantizer_dropout": 0.0,
|
53 |
+
"final_dropout": 0.0,
|
54 |
+
"fuse_matmuls": false,
|
55 |
+
"gradient_checkpointing": true,
|
56 |
+
"hidden_act": "gelu",
|
57 |
+
"hidden_dropout": 0.2,
|
58 |
+
"hidden_dropout_prob": 0.1,
|
59 |
+
"hidden_size": 1024,
|
60 |
+
"initializer_range": 0.02,
|
61 |
+
"intermediate_size": 4096,
|
62 |
+
"layer_norm_eps": 1e-05,
|
63 |
+
"layerdrop": 0.0,
|
64 |
+
"mask_feature_length": 10,
|
65 |
+
"mask_feature_min_masks": 0,
|
66 |
+
"mask_feature_prob": 0.0,
|
67 |
+
"mask_time_length": 10,
|
68 |
+
"mask_time_min_masks": 2,
|
69 |
+
"mask_time_prob": 0.1,
|
70 |
+
"model_type": "wav2vec2",
|
71 |
+
"num_adapter_layers": 3,
|
72 |
+
"num_attention_heads": 16,
|
73 |
+
"num_codevector_groups": 2,
|
74 |
+
"num_codevectors_per_group": 320,
|
75 |
+
"num_conv_pos_embedding_groups": 16,
|
76 |
+
"num_conv_pos_embeddings": 128,
|
77 |
+
"num_feat_extract_layers": 7,
|
78 |
+
"num_hidden_layers": 24,
|
79 |
+
"num_negatives": 100,
|
80 |
+
"output_hidden_size": 1024,
|
81 |
+
"pad_token_id": 0,
|
82 |
+
"proj_codevector_dim": 768,
|
83 |
+
"tdnn_dilation": [
|
84 |
+
1,
|
85 |
+
2,
|
86 |
+
3,
|
87 |
+
1,
|
88 |
+
1
|
89 |
+
],
|
90 |
+
"tdnn_dim": [
|
91 |
+
512,
|
92 |
+
512,
|
93 |
+
512,
|
94 |
+
512,
|
95 |
+
1500
|
96 |
+
],
|
97 |
+
"tdnn_kernel": [
|
98 |
+
5,
|
99 |
+
3,
|
100 |
+
3,
|
101 |
+
1,
|
102 |
+
1
|
103 |
+
],
|
104 |
+
"transformers_version": "4.21.0.dev0",
|
105 |
+
"use_scan": true,
|
106 |
+
"use_weighted_layer_sum": false,
|
107 |
+
"vocab_size": 32,
|
108 |
+
"xvector_output_dim": 512
|
109 |
+
}
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fcfeb02537a755eff6e041ef612a85e67c80ccc473dfe55d781b92d643f6d724
|
3 |
+
size 1261888150
|
get_ctc_tokenizer.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
from datasets import load_dataset
|
3 |
+
from collections import Counter
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import tempfile
|
7 |
+
from transformers import Wav2Vec2CTCTokenizer
|
8 |
+
|
9 |
+
# ---------------------------------------------------------------------------
# Script configuration
# ---------------------------------------------------------------------------
# Configuration name passed to `load_dataset` for the "esc/esc-datasets" repo.
dataset_name = "tedlium"
# Only the training split may be used to build the tokenizer vocabulary.
split = "train"
# Forwarded to `load_dataset` in case the dataset requires access rights.
use_auth_token = True
# Name under which the tokenizer is pushed to the Hub.
tokenizer_name = f"wav2vec2-ctc-{dataset_name}-tokenizer"

# The cutoff is kept FIXED for all datasets (entirely dataset-agnostic). It is
# compared against each character's share of the corpus expressed in percent.
cutoff_freq = 0.01

dataset = load_dataset("esc/esc-datasets", dataset_name, split=split, use_auth_token=use_auth_token)

# Only the transcriptions are needed; drop every other column to save RAM.
dataset = dataset.remove_columns([name for name in dataset.column_names if name != "text"])
|
30 |
+
|
31 |
+
# define function to see stats about letters and to create vocab
|
32 |
+
def create_vocabulary_from_data(dataset, word_delimiter_token="|", cutoff_freq=0.0):
    """Build a character-level CTC vocabulary from the ``"text"`` column of ``dataset``.

    Every character of the transcriptions is counted, a frequency table is printed to
    stdout, and characters whose share of all characters (in percent) is below
    ``cutoff_freq`` are dropped. Token ids are assigned deterministically: the four
    special tokens first, then the surviving characters from most to least frequent
    (ties broken by the character itself).

    Args:
        dataset: Object offering ``column_names``, ``map(...)`` and column access by name
            that stores the transcriptions in a ``"text"`` column (the script passes the
            object returned by ``load_dataset``).
        word_delimiter_token: Token that takes over the id of the space character. Pass
            ``None`` to keep ``" "`` itself in the vocabulary.
        cutoff_freq: Minimum frequency in percent (``0.01`` means 0.01 %) a character
            must reach in order to be kept.

    Returns:
        dict: Mapping from token to integer id.
    """

    def extract_all_chars(batch):
        """Count the characters of one batch and return them sorted by frequency."""
        all_text = " ".join(batch["text"])
        # most frequent first; the character is the tie-breaker so the order is stable
        count_chars = sorted(Counter(all_text).items(), key=lambda item: (-item[1], item[0]))
        # list comprehensions (instead of `zip(*...)`) also work for empty text
        return {
            "vocab": [char for char, _ in count_chars],
            "freqs": [freq for _, freq in count_chars],
        }

    dataset = dataset.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        remove_columns=dataset.column_names,
    )

    # Merge the rows into one count per character. With a single batch this is a no-op,
    # but it keeps the result correct should the rows ever stem from several batches.
    char_totals = Counter()
    for char, freq in zip(dataset["vocab"], dataset["freqs"]):
        char_totals[char] += freq
    ranked_chars = sorted(char_totals.items(), key=lambda item: (-item[1], item[0]))
    total_num_chars = sum(char_totals.values())

    kept_chars = []

    print("Character Occurences")
    print(f"Total characters in dataset: {total_num_chars}")
    print(50 * "-")
    print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
    print(50 * "-")
    for char, freq in ranked_chars:
        freq_in_percent = freq / total_num_chars * 100
        print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")
        # Keeping the survivors in a list (rather than subtracting sets) preserves the
        # frequency order, so the ids below do not depend on string-hash iteration order.
        if freq_in_percent >= cutoff_freq:
            kept_chars.append(char)
    print(50 * "-")

    # Wav2Vec2CTC Tokenizers always have those as the first tokens (important for CTC)
    vocab = ["<pad>", "<s>", "</s>", "<unk>"] + kept_chars

    # create json dict
    vocab_dict = {token: index for index, token in enumerate(vocab)}

    # Replace white space with the delimiter token. The space can be missing (empty data,
    # single-word data, or removed by the cutoff); in that case there is no id to hand over.
    if word_delimiter_token is not None and " " in vocab_dict:
        vocab_dict[word_delimiter_token] = vocab_dict.pop(" ")

    return vocab_dict
|
81 |
+
|
82 |
+
# A note on the `cutoff_freq` argument of `create_vocabulary_from_data`:
#   This argument is very important! Many datasets contain "wrong" characters in the
#   training set, e.g. characters that occur only a couple of times. Without a cutoff,
#   the CTC vocab creation would add them to the vocab even though their occurrence is
#   negligible compared to the "super frequent" letters. Such characters can be seen as
#   "errors" or as irrelevant for the dataset, so they should be deleted from the vocab.
#   During training they are then simply classified as unknown <unk> tokens, which the
#   model can handle.
#   This script removes all characters whose frequency in % is below a threshold, and
#   that threshold is kept FIXED for all datasets (i.e. it is dataset-agnostic).

vocab_dict = create_vocabulary_from_data(dataset, cutoff_freq=cutoff_freq)

# The tokenizer is loaded from a directory, so the vocab is first written to a
# temporary one as "vocab.json".
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "vocab.json"), "w") as file:
        file.write(json.dumps(vocab_dict))

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp)

# Upload the tokenizer to the Hub.
tokenizer.push_to_hub(tokenizer_name)
|
models/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.configuration_bart import BartConfig
|
2 |
+
from models.configuration_wav2vec2 import Wav2Vec2Config
|
3 |
+
from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
|
4 |
+
from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
|
5 |
+
from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
|
6 |
+
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
|
models/configuration_bart.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" BART model configuration"""
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
25 |
+
"facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
|
26 |
+
# See all BART models at https://huggingface.co/models?filter=bart
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
class BartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the BART
    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by multiplying by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        num_labels (`int`, *optional*, defaults to 3):
            The number of labels to use in [`BartForSequenceClassification`].
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.
        use_scan (`bool`, *optional*, defaults to `False`):
            Whether or not to use nn.scan in the Flax Bart attention layers.
        fuse_matmuls (`bool`, *optional*, defaults to `False`):
            Stored on the configuration unchanged. Presumably tells the Flax Bart attention layers to fuse their
            projection matmuls -- confirm against `models/modeling_flax_bart.py`.

    Example:

    ```python
    >>> from transformers import BartModel, BartConfig

    >>> # Initializing a BART facebook/bart-large style configuration
    >>> configuration = BartConfig()

    >>> # Initializing a model from the facebook/bart-large style configuration
    >>> model = BartModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "bart"
    keys_to_ignore_at_inference = ["past_key_values"]
    # generic attribute names resolved to their BART-specific counterparts
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        use_scan=False,
        fuse_matmuls=False,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.use_scan = use_scan
        self.fuse_matmuls = fuse_matmuls
        # the generic `num_hidden_layers` attribute mirrors the number of encoder layers
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # token ids and the remaining generic options are handled by the base class
        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )
|
models/configuration_speech_encoder_decoder.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import copy
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
from models.configuration_wav2vec2 import Wav2Vec2Config
|
22 |
+
from models.configuration_bart import BartConfig
|
23 |
+
from transformers import AutoConfig
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class SpeechEncoderDecoderConfig(PretrainedConfig):
    r"""
    [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
    [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
    arguments, defining the encoder and decoder configs.

    In this implementation the encoder configuration is always built as a [`Wav2Vec2Config`] and the decoder
    configuration as a [`BartConfig`], whatever `model_type` the passed sub-configurations declare.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        kwargs (*optional*):
            Dictionary of keyword arguments. Notably:

                - **encoder** (`dict` or [`PretrainedConfig`]) -- The keyword arguments, or a configuration object,
                  that define the encoder config. Required.
                - **decoder** (`dict` or [`PretrainedConfig`]) -- The keyword arguments, or a configuration object,
                  that define the decoder config. Required.

    Raises:
        ValueError: If `encoder` or `decoder` is missing from `kwargs`.

    Examples:

    ```python
    >>> from models import BartConfig, SpeechEncoderDecoderConfig, Wav2Vec2Config

    >>> # Initializing a Wav2Vec2 & BART style configuration
    >>> config_encoder = Wav2Vec2Config()
    >>> config_decoder = BartConfig()

    >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

    >>> # Accessing the sub-configurations
    >>> config_encoder = config.encoder
    >>> config_decoder = config.decoder

    >>> # Saving the configuration and loading it back from the folder
    >>> config.save_pretrained("my-model")
    >>> config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
    ```"""
    model_type = "speech-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError(
                f"A configuration of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
            )

        encoder_config = kwargs.pop("encoder")
        decoder_config = kwargs.pop("decoder")

        # Configuration objects are accepted as well; they are flattened to the dict form
        # that `from_encoder_decoder_configs` passes in.
        if isinstance(encoder_config, PretrainedConfig):
            encoder_config = encoder_config.to_dict()
        if isinstance(decoder_config, PretrainedConfig):
            decoder_config = decoder_config.to_dict()

        # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
        self.encoder = Wav2Vec2Config(**encoder_config)
        self.decoder = BartConfig(**decoder_config)
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
        configuration and decoder model configuration.

        Note that `decoder_config` is modified in place: `is_decoder` and `add_cross_attention` are set to `True`
        on the passed object before it is serialized.

        Returns:
            [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
        """
        logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        # the nested configurations are serialized through their own `to_dict()`
        output["encoder"] = self.encoder.to_dict()
        output["decoder"] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
|
models/configuration_wav2vec2.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Wav2Vec2 model configuration"""
|
16 |
+
|
17 |
+
import functools
|
18 |
+
import operator
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
27 |
+
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
|
28 |
+
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class Wav2Vec2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate a
    Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Wav2Vec2
    [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32):
            Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        activation_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability stored as `config.activation_dropout`. Presumably applied to the activations inside
            the feed-forward layers; confirm against the modeling code.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for output of the feature encoder.
        feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for the output of the feature encoder that's used by the quantizer.
        final_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
        layerdrop (`float`, *optional*, defaults to 0.1):
            The LayerDrop probability, stored as `config.layerdrop`.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        feat_extract_norm (`str`, *optional*, defaults to `"group"`):
            The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
            normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
            convolutional layers.
        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
        conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
        conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
        conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
            *conv_dim*.
        conv_bias (`bool`, *optional*, defaults to `False`):
            Whether the 1D convolutional layers have a bias.
        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
            Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
            embeddings layer.
        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
            Number of groups of 1D convolutional positional embeddings layer.
        do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
            True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
            False` corresponds to applying layer norm after the attention layer.
        apply_spec_augment (`bool`, *optional*, defaults to `True`):
            Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
            [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
            Recognition](https://arxiv.org/abs/1904.08779).
        mask_time_prob (`float`, *optional*, defaults to 0.05):
            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
            procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
        mask_time_length (`int`, *optional*, defaults to 10):
            Length of vector span along the time axis.
        mask_time_min_masks (`int`, *optional*, defaults to 2):
            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
            irrespectively of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
            mask_time_min_masks''
        mask_feature_prob (`float`, *optional*, defaults to 0.0):
            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
            masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks
            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
            overlap may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment
            is True`.
        mask_feature_length (`int`, *optional*, defaults to 10):
            Length of vector span along the feature axis.
        mask_feature_min_masks (`int`, *optional*, defaults to 0):
            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
            step, irrespectively of `mask_feature_prob`. Only relevant if
            ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
        num_codevectors_per_group (`int`, *optional*, defaults to 320):
            Number of entries in each quantization codebook (group).
        num_codevector_groups (`int`, *optional*, defaults to 2):
            Number of codevector groups for product codevector quantization.
        contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
            The temperature *kappa* in the contrastive loss.
        num_negatives (`int`, *optional*, defaults to 100):
            Number of negative samples for the contrastive loss.
        codevector_dim (`int`, *optional*, defaults to 256):
            Dimensionality of the quantized feature vectors.
        proj_codevector_dim (`int`, *optional*, defaults to 256):
            Dimensionality of the final projection of both the quantized and the transformer features.
        diversity_loss_weight (`float`, *optional*, defaults to 0.1):
            The weight of the codebook diversity loss component.
        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
            instance of [`Wav2Vec2ForCTC`].
        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
            of [`Wav2Vec2ForCTC`].
        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
            instance of [`Wav2Vec2ForSequenceClassification`].
        classifier_proj_size (`int`, *optional*, defaults to 256):
            Dimensionality of the projection before token mean-pooling for classification.
        tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
            A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
            module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
        tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
            *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
        tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
            A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
            *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
        xvector_output_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the *XVector* embedding vectors.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        add_adapter (`bool`, *optional*, defaults to `False`):
            Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
            warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
        adapter_kernel_size (`int`, *optional*, defaults to 3):
            Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
        adapter_stride (`int`, *optional*, defaults to 2):
            Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
        num_adapter_layers (`int`, *optional*, defaults to 3):
            Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
            True`.
        output_hidden_size (`int`, *optional*):
            Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
            if `add_adapter is True`.
        use_scan (`bool`, *optional*, defaults to `False`):
            Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
        fuse_matmuls (`bool`, *optional*, defaults to `False`):
            Flag stored as `config.fuse_matmuls`. Presumably tells the Flax attention layers to use fused projection
            matmuls instead of separate query/key/value projections; confirm against the modeling code.

    Example:

    ```python
    >>> from transformers import Wav2Vec2Model, Wav2Vec2Config

    >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
    >>> configuration = Wav2Vec2Config()

    >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
    >>> model = Wav2Vec2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "wav2vec2"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_dropout=0.0,
        feat_quantizer_dropout=0.0,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        apply_spec_augment=True,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_time_min_masks=2,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        mask_feature_min_masks=0,
        num_codevectors_per_group=320,
        num_codevector_groups=2,
        contrastive_logits_temperature=0.1,
        num_negatives=100,
        codevector_dim=256,
        proj_codevector_dim=256,
        diversity_loss_weight=0.1,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        tdnn_dim=(512, 512, 512, 512, 1500),
        tdnn_kernel=(5, 3, 3, 1, 1),
        tdnn_dilation=(1, 2, 3, 1, 1),
        xvector_output_dim=512,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        add_adapter=False,
        adapter_kernel_size=3,
        adapter_stride=2,
        num_adapter_layers=3,
        output_hidden_size=None,
        use_scan=False,
        fuse_matmuls=False,
        **kwargs
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_activation = feat_extract_activation
        # Stored as lists (not tuples) regardless of the sequence type passed in.
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        # The number of feature-encoder layers is derived from `conv_dim`, never passed explicitly.
        self.num_feat_extract_layers = len(self.conv_dim)
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.use_weighted_layer_sum = use_weighted_layer_sum
        self.use_scan = use_scan
        self.fuse_matmuls = fuse_matmuls

        # `num_feat_extract_layers` is `len(conv_dim)` by construction, so only the stride and kernel sequences
        # need to be checked against it.
        if (
            len(self.conv_stride) != self.num_feat_extract_layers
            or len(self.conv_kernel) != self.num_feat_extract_layers
        ):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. "
                "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
                f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
                f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # parameters for pretraining with codevector quantized representations
        self.num_codevectors_per_group = num_codevectors_per_group
        self.num_codevector_groups = num_codevector_groups
        self.contrastive_logits_temperature = contrastive_logits_temperature
        self.feat_quantizer_dropout = feat_quantizer_dropout
        self.num_negatives = num_negatives
        self.codevector_dim = codevector_dim
        self.proj_codevector_dim = proj_codevector_dim
        self.diversity_loss_weight = diversity_loss_weight

        # ctc loss
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # adapter
        self.add_adapter = add_adapter
        self.adapter_kernel_size = adapter_kernel_size
        self.adapter_stride = adapter_stride
        self.num_adapter_layers = num_adapter_layers
        # Falls back to `hidden_size` when `output_hidden_size` is not given (or is falsy).
        self.output_hidden_size = output_hidden_size or hidden_size

        # SequenceClassification-specific parameter. Feel free to ignore for other classes.
        self.classifier_proj_size = classifier_proj_size

        # XVector-specific parameters. Feel free to ignore for other classes.
        self.tdnn_dim = list(tdnn_dim)
        self.tdnn_kernel = list(tdnn_kernel)
        self.tdnn_dilation = list(tdnn_dilation)
        self.xvector_output_dim = xvector_output_dim

    @property
    def inputs_to_logits_ratio(self):
        """Return the product of all `conv_stride` entries, i.e. the combined stride of the feature encoder."""
        return functools.reduce(operator.mul, self.conv_stride, 1)
|
models/modeling_flax_bart.py
ADDED
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Flax Bart model."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
import random
|
19 |
+
from functools import partial
|
20 |
+
from typing import Optional, Tuple
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import flax.linen as nn
|
25 |
+
import jax
|
26 |
+
import jax.numpy as jnp
|
27 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
28 |
+
from flax.linen import combine_masks, make_causal_mask
|
29 |
+
from flax.linen import partitioning as nn_partitioning
|
30 |
+
from flax.linen.attention import dot_product_attention_weights
|
31 |
+
from jax import lax
|
32 |
+
from jax.random import PRNGKey
|
33 |
+
|
34 |
+
from transformers.modeling_flax_outputs import (
|
35 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
36 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
37 |
+
)
|
38 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
39 |
+
|
40 |
+
from models import BartConfig
|
41 |
+
|
42 |
+
|
43 |
+
# Short module-level aliases for the Flax partitioning helpers of the same names.
scan_with_axes = nn_partitioning.scan_with_axes
remat = nn_partitioning.remat
|
45 |
+
|
46 |
+
|
47 |
+
def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.

    Args:
        input_ids: 2-D array-like of token ids, indexed as `[row, position]`. Anything accepted by `np.asarray`
            (NumPy array, nested list, ...) works; the input itself is never modified.
        pad_token_id: Id written wherever the shifted result contains `-100`.
        decoder_start_token_id: Id placed in the first position of every row.

    Returns:
        A new NumPy array with the same shape as `input_ids`: each row starts with `decoder_start_token_id`,
        followed by the original row without its last element.
    """
    # Coerce once so that plain lists (which do not support `[:, :-1]` indexing) are accepted as well.
    input_ids = np.asarray(input_ids)

    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    # Replace every `-100` entry (presumably the label-masking value used by the caller) with the pad token id.
    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
|
57 |
+
|
58 |
+
|
59 |
+
class FlaxBartAttention(nn.Module):
    """
    Multi-head attention layer.

    Acts as self-attention by default and as cross-attention when `key_value_states` is passed to `__call__`. When
    `config.fuse_matmuls` is set, the projections are computed with fused dense layers (`fused_proj` for
    self-attention, `q_proj` + `fused_key_value` for cross-attention) instead of the separate `q_proj` / `k_proj` /
    `v_proj` layers.
    """

    config: BartConfig
    embed_dim: int  # total projection width; must be divisible by `num_heads`
    num_heads: int
    dropout: float = 0.0  # dropout rate applied to the attention weights
    causal: bool = False  # if True, a causal mask is built in `setup` and applied in `__call__`
    bias: bool = True  # whether the projection layers use a bias term
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        """Create the projection layers, the dropout layer and (for causal attention) the causal mask."""
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        # Fused query/key/value projection, used for self-attention when `config.fuse_matmuls` is set.
        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # Fused key/value projection, used for cross-attention when `config.fuse_matmuls` is set.
        self.fused_key_value = nn.Dense(
            self.embed_dim * 2,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        """Reshape the last axis into `(num_heads, head_dim)`, keeping the first two axes unchanged."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the trailing axes back into a single `embed_dim` axis."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        Returns the (possibly cache-extended) key and value states together with the updated attention mask. On the
        call that first creates the cache variables, the inputs are returned unchanged.
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # Only the batch dims and the cache length are needed; the head dims are ignored.
            *batch_dims, max_length, _, _ = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those
            # key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: Inputs the queries are projected from (and keys/values too, for self-attention).
            key_value_states: If given, keys and values are projected from it and the layer acts as cross-attention.
            attention_mask: Optional mask; positions whose value is `> 0` are attended to.
            init_cache: If True on a causal layer, the key/value cache variables are created.
            deterministic: If False (and `dropout > 0`), dropout is applied to the attention weights.

        Returns:
            A tuple `(attn_output, attn_weights)`.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        if self.config.fuse_matmuls:
            if is_cross_attention:
                # queries come from the decoder states, keys/values from one fused projection of the encoder states
                query_states = self.q_proj(hidden_states)
                attention_states = self.fused_key_value(key_value_states)
                key_states, value_states = jnp.split(attention_states, 2, axis=-1)
            else:
                # self-attention: a single fused projection yields query, key and value
                attention_states = self.fused_proj(hidden_states)
                query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)

        else:
            # get query proj
            query_states = self.q_proj(hidden_states)
            # get key, value proj
            if is_cross_attention:
                # cross_attentions
                key_states = self.k_proj(key_value_states)
                value_states = self.v_proj(key_value_states)
            else:
                # self_attention
                key_states = self.k_proj(hidden_states)
                value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # With a cache, slice the mask rows starting at the current cache index and keep all cached columns.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # Masked positions get the most negative finite value of the compute dtype rather than `-inf`:
            # a row in which every position is masked would otherwise turn into NaN in the softmax.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
252 |
+
|
253 |
+
|
254 |
+
class FlaxBartDecoderLayer(nn.Module):
    """One BART decoder block: causal self-attention, optional cross-attention and a feed-forward network.

    Every sub-block follows the same pattern: sub-layer -> dropout -> residual add -> LayerNorm.

    When `config.use_scan` is set, the layer is driven by a scan over layers: `hidden_states` then arrives
    wrapped in a tuple (the scan carry) and the return value is `(carry, None)`.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.encoder_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # The decoder feed-forward width is governed by `decoder_ffn_dim`, not by the encoder's setting.
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Apply the decoder block.

        Args:
            hidden_states: input to the block (a 1-tuple holding it when `config.use_scan` is set).
            attention_mask: mask forwarded to the causal self-attention.
            encoder_hidden_states: encoder output; cross-attention is skipped when this is `None`.
            encoder_attention_mask: mask forwarded to the cross-attention.
            init_cache: forwarded to the self-attention to initialise its decoding cache.
            output_attentions: when true, the self- and cross-attention weights are appended to the output.
            deterministic: disables every dropout layer when true.

        Returns:
            `(hidden_states,)`, extended with `(self_attn_weights, cross_attn_weights)` when
            `output_attentions` is set. With `config.use_scan` the tuple is wrapped as `(outputs, None)`.
        """

        if self.config.use_scan:
            # under scan the carry is a tuple whose first element is the hidden states
            hidden_states = hidden_states[0]

        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if self.config.use_scan:
            # scan expects `(carry, per-step output)`; no per-step output is collected
            outputs = (outputs, None)

        return outputs
|
347 |
+
|
348 |
+
|
349 |
+
class FlaxBartDecoderLayerCollection(nn.Module):
    """Stack of `config.decoder_layers` decoder layers.

    The stack is either unrolled as a Python loop of layers named "0", "1", ... or, when `config.use_scan`
    is set, expressed as a single scanned layer named "FlaxBartDecoderLayers". When
    `config.gradient_checkpointing` is set, each layer is wrapped in `remat`.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run all decoder layers.

        Args:
            hidden_states: input to the first decoder layer.
            attention_mask: mask forwarded to every layer's self-attention.
            encoder_hidden_states: encoder output forwarded to every layer's cross-attention.
            encoder_attention_mask: mask forwarded to every layer's cross-attention.
            deterministic: disables dropout and LayerDrop when true.
            init_cache: forwarded to every layer to initialise the decoding cache.
            output_attentions: collect per-layer attention weights (not supported with `config.use_scan`).
            output_hidden_states: collect per-layer hidden states (not supported with `config.use_scan`).
            return_dict: return a `FlaxBaseModelOutputWithPastAndCrossAttentions` instead of a tuple.

        Returns:
            The model output object, or a tuple of its non-`None` fields when `return_dict` is false.
        """
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        num_decoder_layers = self.config.decoder_layers
        BlockDecoderLayer = (
            remat(
                FlaxBartDecoderLayer,
                static_argnums=(4, 5, 6),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartDecoderLayer
        )

        if self.config.use_scan:
            # since all decoder layers are the same, we use nn.scan directly
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
            # the scan carry is a tuple; the layer unwraps and re-wraps it
            hidden_states = (hidden_states,)

            # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
            hidden_states, _ = scan_with_axes(
                BlockDecoderLayer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_decoder_layers,
            )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]

        else:
            for layer in range(num_decoder_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
                # NOTE(review): this is Python-level randomness; under `jax.jit` the draw presumably happens
                # once at trace time rather than on every step - confirm.
                dropout_probability = random.uniform(0, 1)
                if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
                    # A dropped layer acts as the identity: the hidden states pass through unchanged
                    # (propagating `None` would break every following layer and the final output).
                    layer_outputs = (hidden_states, None, None)
                else:
                    layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        init_cache,
                        output_attentions,
                        deterministic,
                    )

                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
447 |
+
|
448 |
+
|
449 |
+
class FlaxBartDecoder(nn.Module):
    """BART decoder: token + learned position embeddings, LayerNorm, dropout, then the decoder layer stack."""

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        self.padding_idx = cfg.pad_token_id
        self.max_target_positions = cfg.max_position_embeddings
        # token embeddings are optionally scaled by sqrt(d_model)
        if cfg.scale_embedding:
            self.embed_scale = math.sqrt(cfg.d_model)
        else:
            self.embed_scale = 1.0

        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        self.embed_positions = nn.Embed(
            cfg.max_position_embeddings + self.offset,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )

        self.layers = FlaxBartDecoderLayerCollection(cfg, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Embed `input_ids` and run the decoder layers.

        Returns the layer stack's output as-is when `return_dict` is false, otherwise a
        `FlaxBaseModelOutputWithPastAndCrossAttentions` built from it.
        """
        # collapse any leading dimensions so the ids are (-1, last_dim)
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])

        token_embeds = self.embed_tokens(flat_ids) * self.embed_scale
        # position ids are shifted by the embedding-table offset before lookup
        position_embeds = self.embed_positions(position_ids + self.offset)

        hidden_states = self.layernorm_embedding(token_embeds + position_embeds)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        return outputs
|
521 |
+
|
522 |
+
|
523 |
+
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
    """Base class for standalone BART decoder models; subclasses set `module_class`."""

    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        # the config is marked as a decoder-only model before the module is built
        config.is_decoder = True
        config.is_encoder_decoder = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module on dummy inputs of shape `input_shape` and return its "params" collection."""
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        # dummy encoder inputs so that the cross-attention parameters are created as well
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
        return module_init_outputs["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.

        Returns:
            The unfrozen "cache" collection produced by initialising the module with `init_cache=True`.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Args:
            input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
                Indices of decoder input sequence tokens in the vocabulary.
            attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
                Mask over the decoder inputs. Defaults to all ones, i.e. every position is attended to.
            position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
                Indices of positions of each decoder input sequence tokens in the position embeddings. Defaults to
                `0 .. sequence_length - 1` for every row.
            encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
                A sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
            encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                Defaults to all ones when `encoder_hidden_states` is given.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.return_dict`.
            train (`bool`):
                When true, the module is run non-deterministically (dropout enabled).
            params (`dict`, *optional*):
                Parameters to use instead of `self.params`.
            past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
                Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used
                for fast auto-regressive decoding. When given, the updated cache is returned with the outputs.
            dropout_rng (`PRNGKey`, *optional*):
                PRNG key used for dropout.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # prepare decoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
        # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
        # changed by FlaxBartAttention module
        # The same predicate decides both whether the cache is made mutable and whether the result is unpacked
        # as `(outputs, mutated_variables)`, so an empty cache dict is handled like "no cache".
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if use_cache and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif use_cache and not return_dict:
            outputs, past_key_values = outputs
            # the cache is inserted right after the first output element
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs
|
698 |
+
|
699 |
+
|
700 |
+
class FlaxBartDecoderWrapper(nn.Module):
    """
    Thin wrapper that nests the decoder under a `decoder` attribute so that pretrained checkpoints load correctly
    when the causal language model is used in combination with the [`EncoderDecoderModel`] framework.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # token embedding table handed to the decoder as its `embed_tokens` field
        token_embedding = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=token_embedding, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        # pure pass-through to the wrapped decoder
        return self.decoder(*args, **kwargs)
|
720 |
+
|
721 |
+
|
722 |
+
class FlaxBartForCausalLMModule(nn.Module):
    """Bart decoder module with a language modeling head on top (a bias-free linear layer whose kernel is the
    transposed input embedding when `config.tie_word_embeddings` is set), e.g. for autoregressive tasks.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the decoder and project its last hidden states to vocabulary logits.

        Returns a `FlaxCausalLMOutputWithCrossAttentions`, or `(lm_logits, *rest_of_decoder_outputs)` when
        `return_dict` is false.
        """

        decoder_outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden = decoder_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(last_hidden)
        else:
            # reuse the (transposed) token embedding matrix as the projection kernel
            tied_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": tied_embedding.T}}, last_hidden)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )

        return (lm_logits,) + decoder_outputs[1:]
|
783 |
+
|
784 |
+
|
785 |
+
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
    e.g. for autoregressive tasks.
    """

    module_class = FlaxBartForCausalLMModule

    # `jnp.ndarray` is used for the annotation (as elsewhere in this file): it is evaluated when the class is
    # defined, and `jnp.DeviceArray` no longer exists in recent JAX releases.
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.ndarray] = None):
        """Build the initial keyword arguments for auto-regressive generation.

        Args:
            input_ids: prompt token ids of shape `(batch_size, seq_length)`.
            max_length: length of the decoding cache and of the returned attention mask.
            attention_mask: optional mask for the prompt; it is copied into the start of the full-length mask
                and its cumulative sum (minus one) is used as the position ids.

        Returns:
            A dict with `past_key_values`, `attention_mask` and `position_ids`.
        """
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the position ids by one step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        # only the last position is kept, incremented for the next token
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
|
models/modeling_flax_speech_encoder_decoder.py
ADDED
@@ -0,0 +1,1245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Classes to support Flax Speech-Encoder-Decoder architectures"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from functools import partial
|
19 |
+
from typing import Optional, Tuple, Union, Dict
|
20 |
+
|
21 |
+
import flax
|
22 |
+
import flax.linen as nn
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
26 |
+
from jax import lax
|
27 |
+
from jax.random import PRNGKey
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
31 |
+
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
32 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
|
33 |
+
from transformers.generation_flax_utils import FlaxLogitsProcessorList
|
34 |
+
from models import (
|
35 |
+
FlaxWav2Vec2Model,
|
36 |
+
FlaxWav2Vec2Module,
|
37 |
+
FlaxBartForCausalLM,
|
38 |
+
FlaxBartForCausalLMModule,
|
39 |
+
BartConfig,
|
40 |
+
Wav2Vec2Config,
|
41 |
+
SpeechEncoderDecoderConfig,
|
42 |
+
)
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)
|
45 |
+
|
46 |
+
_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
|
47 |
+
|
48 |
+
SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
|
49 |
+
This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
|
50 |
+
autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
|
51 |
+
loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
|
52 |
+
[`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
|
53 |
+
and should be fine-tuned on a downstream generative task, like summarization.
|
54 |
+
|
55 |
+
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
|
56 |
+
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
|
57 |
+
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
|
58 |
+
Zhou, Wei Li, Peter J. Liu.
|
59 |
+
|
60 |
+
Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
|
61 |
+
Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
|
62 |
+
translation yields a significant performance improvement.
|
63 |
+
|
64 |
+
After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
|
65 |
+
models (see the examples for more information).
|
66 |
+
|
67 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
68 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
69 |
+
etc.)
|
70 |
+
|
71 |
+
This model is also a Flax Linen
|
72 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
73 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
|
77 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
78 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
79 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
80 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
81 |
+
`jax.numpy.bfloat16` (on TPUs).
|
82 |
+
|
83 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
84 |
+
specified all the computation will be performed with the given `dtype`.
|
85 |
+
|
86 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
87 |
+
parameters.**
|
88 |
+
|
89 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
90 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
91 |
+
"""
|
92 |
+
|
93 |
+
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
94 |
+
Args:
|
95 |
+
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
96 |
+
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
97 |
+
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
98 |
+
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
99 |
+
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
100 |
+
*jnp.ndarray*.
|
101 |
+
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
102 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
103 |
+
|
104 |
+
- 1 for tokens that are **not masked**,
|
105 |
+
- 0 for tokens that are **masked**.
|
106 |
+
|
107 |
+
[What are attention masks?](../glossary#attention-mask)
|
108 |
+
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
109 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
110 |
+
|
111 |
+
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
112 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
113 |
+
|
114 |
+
[What are input IDs?](../glossary#input-ids)
|
115 |
+
|
116 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
117 |
+
`past_key_values`).
|
118 |
+
|
119 |
+
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
|
120 |
+
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
|
121 |
+
and prepending them with the `decoder_start_token_id`.
|
122 |
+
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
123 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
124 |
+
be used by default.
|
125 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
126 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
127 |
+
range `[0, config.decoder.max_position_embeddings - 1]`.
|
128 |
+
output_hidden_states (`bool`, *optional*):
|
129 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
130 |
+
more detail.
|
131 |
+
return_dict (`bool`, *optional*):
|
132 |
+
If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
|
133 |
+
"""
|
134 |
+
|
135 |
+
SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
|
136 |
+
Args:
|
137 |
+
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
138 |
+
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
139 |
+
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
140 |
+
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
141 |
+
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
142 |
+
*jnp.ndarray*.
|
143 |
+
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
144 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
145 |
+
|
146 |
+
- 1 for tokens that are **not masked**,
|
147 |
+
- 0 for tokens that are **masked**.
|
148 |
+
|
149 |
+
[What are attention masks?](../glossary#attention-mask)
|
150 |
+
output_attentions (`bool`, *optional*):
|
151 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
152 |
+
tensors for more detail.
|
153 |
+
output_hidden_states (`bool`, *optional*):
|
154 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
155 |
+
more detail.
|
156 |
+
return_dict (`bool`, *optional*):
|
157 |
+
If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
|
158 |
+
"""
|
159 |
+
|
160 |
+
SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
|
161 |
+
Args:
|
162 |
+
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
163 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
164 |
+
|
165 |
+
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
166 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
167 |
+
|
168 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
169 |
+
|
170 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
171 |
+
`past_key_values`).
|
172 |
+
|
173 |
+
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
|
174 |
+
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
|
175 |
+
and prepending them with the `decoder_start_token_id`.
|
176 |
+
encoder_outputs (`tuple(tuple(jnp.ndarray))`):
|
177 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
178 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
179 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
180 |
+
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
181 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
182 |
+
|
183 |
+
- 1 for tokens that are **not masked**,
|
184 |
+
- 0 for tokens that are **masked**.
|
185 |
+
|
186 |
+
[What are attention masks?](../glossary#attention-mask)
|
187 |
+
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
188 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
189 |
+
be used by default.
|
190 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
191 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
192 |
+
range `[0, config.decoder.max_position_embeddings - 1]`.
|
193 |
+
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
194 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
195 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
196 |
+
output_attentions (`bool`, *optional*):
|
197 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
198 |
+
tensors for more detail.
|
199 |
+
output_hidden_states (`bool`, *optional*):
|
200 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
201 |
+
more detail.
|
202 |
+
return_dict (`bool`, *optional*):
|
203 |
+
If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
|
204 |
+
plain tuple.
|
205 |
+
"""
|
206 |
+
|
207 |
+
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax base class for the outputs of generation models decoded with beam search.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jnp.ndarray = None
    scores: jnp.ndarray = None
|
222 |
+
|
223 |
+
|
224 |
+
@flax.struct.dataclass
class BeamSearchState:
    """
    Immutable record of the values carried from one beam-search decoding step to the next.

    NOTE(review): the per-field comments below are inferred from the field names only; the loop that
    produces and consumes this state is not part of this excerpt, so confirm them against that code.
    """

    # presumably the current decoding position (number of tokens generated so far)
    cur_len: jnp.ndarray
    # presumably the token ids and scores of beams that are still being extended
    running_sequences: jnp.ndarray
    running_scores: jnp.ndarray
    # presumably the token ids and scores of the best beams found so far
    sequences: jnp.ndarray
    scores: jnp.ndarray
    # presumably a per-beam flag marking hypotheses that have finished
    is_sent_finished: jnp.ndarray
    # keyword arguments forwarded to the model at every step (name -> array)
    model_kwargs: Dict[str, jnp.ndarray]
|
233 |
+
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
class FlaxSpeechEncoderDecoderModule(nn.Module):
    """
    Flax module that runs a Wav2Vec2 speech encoder followed by a BART causal-LM decoder.

    The first encoder output is handed to the decoder as `encoder_hidden_states`. When the encoder and decoder
    hidden sizes differ and the decoder config does not set `cross_attention_hidden_size`, the encoder output
    is first passed through a dense layer that maps it to the decoder's hidden size.
    """

    config: SpeechEncoderDecoderConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the encoder, the decoder and, if required, the encoder-to-decoder projection."""
        encoder_config = self.config.encoder
        decoder_config = self.config.decoder

        # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
        encoder_module = FlaxWav2Vec2Module
        decoder_module = FlaxBartForCausalLMModule

        self.encoder = encoder_module(encoder_config, dtype=self.dtype)
        self.decoder = decoder_module(decoder_config, dtype=self.dtype)

        # encoder outputs might need to be projected to different dimension for decoder
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = nn.Dense(
                self.decoder.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
                dtype=self.dtype,
            )
        else:
            self.enc_to_dec_proj = None

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: length(s) of the raw input along the time axis.
            add_adapter: whether to also account for the adapter layers; defaults to
                `config.encoder.add_adapter` when `None`.

        Returns:
            The length(s) after applying every layer in `config.encoder.conv_kernel` / `conv_stride` and,
            optionally, `config.encoder.num_adapter_layers` adapter layers.
        """

        add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # each adapter layer is accounted for with kernel size 1 and stride `adapter_stride`
            for _ in range(self.config.encoder.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)

        return input_lengths

    def _get_encoder_module(self):
        """Return the encoder sub-module."""
        return self.encoder

    def _get_projection_module(self):
        """Return the encoder-to-decoder projection, or `None` when no projection is configured."""
        return self.enc_to_dec_proj

    def _get_decoder_module(self):
        """Return the decoder sub-module."""
        return self.decoder

    def __call__(
        self,
        inputs,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        encoder_outputs=None,
        extract_features=None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_features: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        freeze_feature_encoder: bool = False,
    ):
        """
        Run the encoder (unless `encoder_outputs` is supplied) and then the decoder.

        Args:
            inputs, attention_mask: encoder input and its mask.
            decoder_input_ids, decoder_attention_mask, decoder_position_ids: decoder inputs.
            encoder_outputs: pre-computed encoder outputs; when given, the encoder is not run.
            extract_features, output_features, freeze_feature_encoder: forwarded unchanged to the encoder.
            output_attentions, output_hidden_states, return_dict, deterministic: forwarded to both the
                encoder and the decoder.

        Returns:
            The encoder outputs as-is when `output_features` is set; otherwise the tuple
            `decoder_outputs + encoder_outputs` when `return_dict` is `False`, or a `FlaxSeq2SeqLMOutput`.
        """
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs,
                attention_mask=attention_mask,
                extract_features=extract_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_features=output_features,
                return_dict=return_dict,
                deterministic=deterministic,
                freeze_feature_encoder=freeze_feature_encoder,
            )

        # NOTE(review): the nesting of this early return was reconstructed; confirm that it should also
        # apply when `encoder_outputs` is supplied by the caller.
        if output_features:
            return encoder_outputs

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if self.enc_to_dec_proj is not None:
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # compute correct encoder attention mask
        if attention_mask is not None:
            encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None

        # run the decoder on the (optionally projected) encoder hidden states
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            # note: this is the hidden state *after* the optional projection above
            encoder_last_hidden_state=encoder_hidden_states,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
368 |
+
|
369 |
+
|
370 |
+
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
|
371 |
+
class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
372 |
+
r"""
|
373 |
+
[`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
374 |
+
with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
|
375 |
+
as decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
|
376 |
+
encoder and [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
|
377 |
+
"""
|
378 |
+
|
379 |
+
config_class = SpeechEncoderDecoderConfig
|
380 |
+
base_model_prefix: str = "speech_encoder_decoder"
|
381 |
+
module_class = FlaxSpeechEncoderDecoderModule
|
382 |
+
|
383 |
+
def __init__(
    self,
    config: SpeechEncoderDecoderConfig,
    input_shape: Optional[Tuple] = None,
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
    **kwargs
):
    """
    Build the Flax module from `config` and initialise the model.

    Args:
        config: combined encoder/decoder configuration. Its `tie_word_embeddings` attribute is set to
            `False` in place.
        input_shape: pair `(encoder_input_shape, decoder_input_shape)`. When `None`, an encoder input of
            shape `(1, 1024)` is assumed and the decoder length is derived from it with
            `_get_feat_extract_output_lengths`.
        seed: forwarded to `FlaxPreTrainedModel.__init__`.
        dtype: computation dtype passed to the module and to the parent constructor.
        _do_init: must be `True`.
        **kwargs: forwarded to the module constructor.

    Raises:
        ValueError: if `_do_init` is `False`, or if the decoder sets `cross_attention_hidden_size` to a
            value different from the encoder's `hidden_size`.
    """
    if not _do_init:
        raise ValueError(
            "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
        )

    # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
    cross_attention_hidden_size = config.decoder.cross_attention_hidden_size
    if cross_attention_hidden_size is not None and cross_attention_hidden_size != config.encoder.hidden_size:
        raise ValueError(
            "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
            "it has to be equal to the encoder's `hidden_size`. "
            f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
            f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
        )

    # make sure input & output embeddings are not tied
    config.tie_word_embeddings = False
    module = self.module_class(config=config, dtype=dtype, **kwargs)

    if input_shape is None:
        # speech encoders almost always downsample the sequence length dimension
        encoder_input_length = 1024
        decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
        input_shape = ((1, encoder_input_length), (1, decoder_input_length))

    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
419 |
+
|
420 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    """
    Initialise the module parameters by tracing it once with dummy inputs.

    Args:
        rng: PRNG key; it is split into a `params` key and a `dropout` key.
        input_shape: pair `(encoder_input_shape, decoder_input_shape)`, each `(batch_size, length)`.

    Returns:
        The `"params"` collection produced by `module.init`.

    Raises:
        ValueError: if the encoder and decoder batch sizes differ.
    """
    encoder_input_shape, decoder_input_shape = input_shape

    # init input DeviceArrays
    inputs = jnp.zeros(encoder_input_shape, dtype="f4")
    attention_mask = jnp.ones_like(inputs, dtype="i4")
    decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    # unpacking also enforces that both dummy inputs are 2-dimensional
    batch_size, _ = inputs.shape
    decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape

    if decoder_batch_size != batch_size:
        raise ValueError(
            f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
        )

    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
    )

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    return self.module.init(
        rngs,
        inputs,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
    )["params"]
|
451 |
+
|
452 |
+
def init_cache(self, batch_size, max_length, encoder_outputs):
    r"""
    Initialise the decoder cache used for fast auto-regressive decoding.

    Args:
        batch_size (`int`):
            batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
        max_length (`int`):
            maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
            cache.
        encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
            `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
            `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
            is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross-attention of the decoder.

    Returns:
        The unfrozen `"cache"` collection created by initialising the decoder with `init_cache=True`.
    """
    # init input variables to retrieve cache
    decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
    )

    def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
        decoder_module = module._get_decoder_module()
        return decoder_module(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            **kwargs,
        )

    init_variables = self.module.init(
        jax.random.PRNGKey(0),
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=decoder_position_ids,
        encoder_hidden_states=encoder_outputs[0],
        init_cache=True,
        method=_decoder_forward,  # we only need to call the decoder to init the cache
    )
    return unfreeze(init_variables["cache"])
|
492 |
+
|
493 |
+
def _get_feat_extract_output_lengths(
    self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
    """
    Compute the encoder output length(s) for the given raw input length(s).

    Thin wrapper that delegates to the method of the same name on `self.module`.
    """
    return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
|
497 |
+
|
498 |
+
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def encode(
    self,
    inputs: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    extract_features: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_features: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    freeze_feature_encoder: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import FlaxSpeechEncoderDecoderModel
    >>> import jax.numpy as jnp

    >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
    ... )

    >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
    >>> encoder_outputs = model.encode(inputs)
    ```"""
    # The docstring deliberately starts at `Returns:`; the two decorators above supply the argument
    # description and the return-type text.

    # fall back to the config for any output flag the caller left unset
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    if attention_mask is None:
        attention_mask = jnp.ones_like(inputs, dtype="i4")

    if extract_features is not None:
        extract_features = jnp.array(extract_features, dtype="f4")

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    def _encoder_forward(module, inputs, attention_mask, **kwargs):
        # run only the encoder sub-module of the combined module
        encode_module = module._get_encoder_module()
        return encode_module(inputs, attention_mask, **kwargs)

    outputs = self.module.apply(
        {"params": params or self.params},
        inputs=jnp.array(inputs, dtype="f4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        extract_features=extract_features,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_features=output_features,
        return_dict=return_dict,
        deterministic=not train,
        freeze_feature_encoder=freeze_feature_encoder,
        rngs=rngs,
        method=_encoder_forward,
    )

    # when `output_features` is set, the encoder result is returned unchanged
    if return_dict and not output_features:
        outputs = FlaxBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return outputs
|
574 |
+
|
575 |
+
@add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def decode(
    self,
    decoder_input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import FlaxSpeechEncoderDecoderModel
    >>> import jax.numpy as jnp

    >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
    ... )

    >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
    >>> encoder_outputs = model.encode(inputs)

    >>> decoder_start_token_id = model.config.decoder.bos_token_id
    >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id

    >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
    >>> logits = outputs.logits
    ```"""
    # The docstring deliberately starts at `Returns:`; the two decorators above supply the argument
    # description and the return-type text.

    # fall back to the config for any output flag the caller left unset
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    encoder_hidden_states = encoder_outputs[0]
    if encoder_attention_mask is None:
        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

    batch_size, sequence_length = decoder_input_ids.shape
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones((batch_size, sequence_length))

    if decoder_position_ids is None:
        # positions cannot be inferred from the ids alone once a cache is in use
        if past_key_values is not None:
            raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    variables = {"params": params or self.params}

    # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
    # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
    # it can be changed by FlaxBartAttention module
    #
    # A single flag drives both the `mutable` setting and the unpacking of the result below, so that an
    # empty `past_key_values` cannot request an immutable apply and then be unpacked as `(outputs, state)`.
    use_cache = bool(past_key_values)
    if use_cache:
        variables["cache"] = past_key_values
        mutable = ["cache"]
    else:
        mutable = False

    def _decoder_forward(
        module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
    ):
        projection_module = module._get_projection_module()
        decoder_module = module._get_decoder_module()

        # optionally project encoder_hidden_states
        if projection_module is not None:
            encoder_hidden_states = projection_module(encoder_hidden_states)

        return decoder_module(
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            encoder_hidden_states,
            **kwargs,
        )

    outputs = self.module.apply(
        variables,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    if not use_cache:
        return outputs

    # add updated cache to model output: with a mutable collection, `apply` returns (outputs, state)
    outputs, past = outputs
    updated_cache = unfreeze(past["cache"])
    if return_dict:
        outputs["past_key_values"] = updated_cache
        return outputs

    # tuple output: insert the cache directly after the first element
    return outputs[:1] + (updated_cache,) + outputs[1:]
|
699 |
+
|
700 |
+
@add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
    self,
    inputs: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    extract_features: Optional[jnp.ndarray] = None,
    decoder_input_ids: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_features: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    freeze_feature_encoder: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Examples:

    ```python
    >>> import jax.numpy as jnp
    >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer

    >>> # load a fine-tuned wav2vec2-2-bart model
    >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
    >>> # load output tokenizer
    >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")

    >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)

    >>> # use bart's special bos, pad and eos tokens
    >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
    >>> model.config.pad_token_id = model.decoder.config.pad_token_id
    >>> model.config.eos_token_id = model.decoder.config.eos_token_id

    >>> outputs = model.generate(inputs)
    ```
    """
    # NOTE: the bare "Returns:" line above is kept on purpose; the
    # `replace_return_docstrings` decorator is applied to this docstring.

    # Per-call flags fall back to the values stored on the model config.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    # prepare encoder inputs
    if attention_mask is None:
        # The default mask is derived from `inputs`, so it has to be built before `inputs` may be dropped below.
        attention_mask = jnp.ones_like(inputs, dtype="i4")

    if extract_features is not None:
        inputs = None  # we can omit passing the inputs to the model to save memory
        extract_features = jnp.array(extract_features, dtype="f4")
    else:
        inputs = jnp.array(inputs, dtype="f4")

    # prepare decoder inputs
    if decoder_input_ids is None:
        # The message names the argument that is actually missing (`decoder_input_ids`).
        raise ValueError(
            "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."
        )
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    if decoder_position_ids is None:
        # Default positions are 0..sequence_length-1, repeated for every batch row.
        batch_size, sequence_length = decoder_input_ids.shape
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    return self.module.apply(
        # Explicitly passed `params` take precedence over the parameters stored on the model.
        {"params": params or self.params},
        inputs=inputs,
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        extract_features=extract_features,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_features=output_features,
        return_dict=return_dict,
        deterministic=not train,
        freeze_feature_encoder=freeze_feature_encoder,
        rngs=rngs,
    )
|
792 |
+
|
793 |
+
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    max_length,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    encoder_outputs=None,
    **kwargs
):
    """
    Build the keyword arguments used for the first decoding step of generation.

    Args:
        decoder_input_ids: 2-D array of prompt token ids; only its `(batch_size, seq_length)` shape is read.
        max_length: length the cache and the decoder attention mask are allocated for.
        attention_mask: encoder attention mask, forwarded unchanged as `encoder_attention_mask`.
        decoder_attention_mask: optional mask over the prompt tokens.
        encoder_outputs: encoder outputs, forwarded to `self.init_cache` and returned unchanged.
        **kwargs: accepted and ignored.

    Returns:
        A dict with the keys `past_key_values`, `encoder_outputs`, `encoder_attention_mask`,
        `decoder_attention_mask` and `decoder_position_ids`.
    """
    # initializing the cache
    batch_size, seq_length = decoder_input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
    # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
    # But since the decoder uses a causal mask, those positions are masked anyways.
    # Thus we can create a single static attention_mask here, which is more efficient for compilation
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if decoder_attention_mask is not None:
        # Positions count only attended tokens: the running sum of the mask, shifted to start at 0.
        decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        # Copy the supplied mask into the leading columns of the all-ones mask.
        extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
    else:
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )

    return {
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "encoder_attention_mask": attention_mask,
        "decoder_attention_mask": extended_attention_mask,
        "decoder_position_ids": decoder_position_ids,
    }
|
825 |
+
|
826 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
    """
    Advance the generation kwargs by one decoding step.

    Stores the cache returned by the last forward pass and replaces the position ids with the
    last position of every row plus one. `model_kwargs` is modified in place and also returned.
    """
    model_kwargs["past_key_values"] = model_outputs.past_key_values

    # Keep only the final column (as a 2-D slice) and move it one position forward.
    previous_positions = model_kwargs["decoder_position_ids"]
    model_kwargs["decoder_position_ids"] = previous_positions[:, -1:] + 1

    return model_kwargs
|
830 |
+
|
831 |
+
@classmethod
def from_encoder_decoder_pretrained(
    cls,
    encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    *model_args,
    **kwargs
) -> FlaxPreTrainedModel:
    r"""
    Build a speech encoder-decoder model from a pretrained encoder checkpoint and a pretrained decoder checkpoint.

    Params:
        encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
            Where to load the encoder from. Either:

                - the *model id* of a pretrained model hosted in a model repo on huggingface.co (root-level such
                  as `bert-base-uncased`, or namespaced such as `dbmdz/bert-base-german-cased`), or
                - the path to a *directory* holding weights written by
                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

        decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
            Where to load the decoder from; accepts the same two forms as the encoder argument.

        model_args (remaining positional arguments, *optional*):
            All remaining positional arguments are forwarded when the encoder is loaded.

        kwargs (remaining dictionary of keyword arguments, *optional*):
            Used to update the configuration objects (after loading) and to initiate the models (e.g.,
            `output_attentions=True`).

            - Prefix a parameter with *encoder_* to update the encoder configuration.
            - Prefix a parameter with *decoder_* to update the decoder configuration.
            - Leave a parameter un-prefixed to update the parent model configuration.

            An already instantiated sub-model can be supplied as `encoder_model` / `decoder_model`, in which
            case the corresponding checkpoint is not loaded. Behaves differently depending on whether a
            `config` is provided or automatically loaded.

    Example:

    ```python
    >>> from transformers import FlaxSpeechEncoderDecoderModel

    >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
    ... )
    >>> # saving model after fine-tuning
    >>> model.save_pretrained("./wav2vec2-2-bart-large")
    >>> # load fine-tuned model
    >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
    ```"""
    encoder_prefix, decoder_prefix = "encoder_", "decoder_"

    # Route every prefixed keyword argument to the sub-model it configures, with the prefix stripped.
    kwargs_encoder = {}
    kwargs_decoder = {}
    for argument, value in kwargs.items():
        if argument.startswith(encoder_prefix):
            kwargs_encoder[argument[len(encoder_prefix) :]] = value
        if argument.startswith(decoder_prefix):
            kwargs_decoder[argument[len(decoder_prefix) :]] = value

    # Only the un-prefixed arguments remain for the parent model.
    kwargs = {
        argument: value
        for argument, value in kwargs.items()
        if not argument.startswith((encoder_prefix, decoder_prefix))
    }

    # Load and initialize the encoder and decoder
    # The distinction between encoder and decoder at the model level is made
    # by the value of the flag `is_decoder` that we need to set correctly.
    encoder = kwargs_encoder.pop("model", None)
    if encoder is None:
        if encoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_encoder:
            # TODO: AutoConfig .from_pretrained
            encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
                encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
            )
            loaded_as_decoder = encoder_config.is_decoder is True or encoder_config.add_cross_attention is True
            if loaded_as_decoder:
                logger.info(
                    f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                    "from a decoder model. Cross-attention and casual mask are disabled."
                )
                encoder_config.is_decoder = False
                encoder_config.add_cross_attention = False

            kwargs_encoder["config"] = encoder_config

        # TODO: FlaxAutoModel .from_pretrained
        encoder = FlaxWav2Vec2Model.from_pretrained(
            encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
        )

    decoder = kwargs_decoder.pop("model", None)
    if decoder is None:
        if decoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_decoder:
            # TODO: AutoConfig .from_pretrained
            decoder_config, kwargs_decoder = BartConfig.from_pretrained(
                decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
            )
            lacks_decoder_setup = decoder_config.is_decoder is False or decoder_config.add_cross_attention is False
            if lacks_decoder_setup:
                logger.info(
                    f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                    f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
                    f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                    "cross attention layers."
                )
                decoder_config.is_decoder = True
                decoder_config.add_cross_attention = True

            kwargs_decoder["config"] = decoder_config

        # A caller-supplied decoder config is used as-is; only warn when it is not set up as a decoder.
        active_decoder_config = kwargs_decoder["config"]
        if active_decoder_config.is_decoder is False or active_decoder_config.add_cross_attention is False:
            logger.warning(
                f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
            )

        # TODO: FlaxAutoModelForCausalLM .from_pretrained
        decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

    # instantiate config with corresponding kwargs
    dtype = kwargs.pop("dtype", jnp.float32)
    config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

    # make sure input & output word embeddings are not tied
    config.tie_word_embeddings = False

    # init model, then overwrite its encoder/decoder parameters with the loaded ones
    model = cls(config, dtype=dtype)
    model.params["encoder"] = encoder.params
    model.params["decoder"] = decoder.params

    return model
|
985 |
+
|
986 |
+
def _beam_search(
    self,
    input_ids: None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    logits_processor: Optional[FlaxLogitsProcessorList] = None,
    trace: bool = True,
    params: Optional[Dict[str, jnp.ndarray]] = None,
    model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
    """
    Run beam search and return every beam of each batch item together with the best score.

    This beam search function is heavily inspired by Flax's official example:
    https://github.com/google/flax/blob/master/examples/wmt/train.py#L254

    Args:
        input_ids: prompt token ids; unpacked as `(batch_size, num_beams, cur_len)`.
        max_length, pad_token_id, eos_token_id, length_penalty, early_stopping: search settings; each one
            falls back to the value on `self.config` when `None`.
        logits_processor: callable applied as `logits_processor(sequences, log_probs, cur_len)` at every step.
        trace: when `True` the loop runs through `lax.while_loop`, otherwise through `self._run_loop_in_debug`.
        params: parameters forwarded to the model call.
        model_kwargs: extra model inputs; `encoder_outputs["last_hidden_state"]` and `attention_mask` are
            flattened over the beam dimension in place when present.

    Returns:
        `FlaxBeamSearchOutput` whose `sequences` hold all beams per batch item and whose `scores` hold the
        score stored in the last beam slot of each batch item.
    """

    def flatten_beam_dim(tensor):
        """Flattens the first two dimensions of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0 or tensor.ndim == 1:
            return tensor
        elif tensor.ndim == 6:
            # Keep axis 0 and merge axes 1 and 2 instead.
            # NOTE(review): presumably cache arrays stacked along a leading axis - confirm against the cache layout.
            return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
        return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    def unflatten_beam_dim(tensor, batch_size, num_beams):
        """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0 or tensor.ndim == 1:
            return tensor
        if tensor.ndim == 5:
            # Counterpart of the 6-D branch in `flatten_beam_dim`: axis 0 is kept, axis 1 is split.
            return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
        return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

    def gather_beams(nested, beam_indices, batch_size, new_num_beams):
        """
        Gathers the beam slices indexed by beam_indices into new beam array.
        """
        batch_indices = jnp.reshape(
            jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
        )

        def gather_fn(tensor):
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0 or tensor.ndim == 1:
                return tensor
            if tensor.ndim == 6:
                # batch and beam live on axes 1 and 2 for these arrays
                return tensor[:, batch_indices, beam_indices]
            return tensor[batch_indices, beam_indices]

        return jax.tree_util.tree_map(gather_fn, nested)

    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

    batch_size, num_beams, cur_len = input_ids.shape

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch,beam-item holding current token in loop.
    sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
    # running sequences start as the padded buffer with the prompt written at the front
    running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

    # per batch,beam-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

    # per batch,beam-item score, logprobs
    # only the first beam starts at 0 so that identical initial beams do not all get selected
    running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
    scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

    # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
    # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # flatten beam dim
    if "encoder_outputs" in model_kwargs:
        model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
            model_kwargs["encoder_outputs"]["last_hidden_state"]
        )
    if "attention_mask" in model_kwargs:
        model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

    # initialize state
    state = BeamSearchState(
        cur_len=cur_len,
        running_sequences=running_sequences,
        running_scores=running_scores,
        sequences=sequences,
        scores=scores,
        is_sent_finished=is_sent_finished,
        model_kwargs=model_kwargs,
    )

    def beam_search_cond_fn(state):
        """beam search state termination condition fn."""

        # 1. is less than max length?
        not_max_length_yet = state.cur_len < max_length

        # 2. can the new beams still improve?
        # (beams are kept in ascending score order, so the best running beam is the last one)
        best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
        worst_finished_score = jnp.where(
            state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
        )
        improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

        # 3. is there still a beam that has not finished?
        still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

        return not_max_length_yet & still_open_beam & improvement_still_possible

    def beam_search_body_fn(state, input_ids_length=1):
        """beam search state update fn."""
        # 1. Forward current tokens
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model. Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # unflatten beam dimension
        # Unflatten beam dimension in attention cache arrays
        input_token = flatten_beam_dim(
            lax.dynamic_slice(
                state.running_sequences,
                (0, 0, state.cur_len - input_ids_length),
                (batch_size, num_beams, input_ids_length),
            )
        )
        model_outputs = model(input_token, params=params, **state.model_kwargs)

        logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
        cache = jax.tree_util.tree_map(
            lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
        )

        # adapt logits for FlaxMarianMTModel
        logits = self._adapt_logits_for_beam_search(logits)

        # 2. Compute log probs
        # get log probabilities from logits,
        # process logits with processors (*e.g.* min_length, ...), and
        # add new logprobs to existing running logprobs scores.
        log_probs = jax.nn.log_softmax(logits)
        # The processors must see the sequences of the *current* step, i.e. the loop state,
        # not the initial prompt-only buffer captured from the enclosing scope.
        log_probs = logits_processor(
            flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
        )
        log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
        log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
        vocab_size = log_probs.shape[2]
        log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

        # 3. Retrieve top-K
        # Each item in batch has num_beams * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        # Gather the top 2*K scores from _all_ beams.
        # Gather 2*k top beams.
        # Recover the beam index by floor division.
        # Recover token id by modulo division and expand Id array for broadcasting.
        # Update sequences for the 2*K top-k new sequences.
        beams_to_keep = 2 * num_beams
        topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
        topk_beam_indices = topk_indices // vocab_size
        topk_running_sequences = gather_beams(
            state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
        )
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

        # 4. Check which sequences have ended
        # Update current sequences:
        # Did any of these sequences reach an end marker?
        # To prevent these just finished sequences from being added to the current sequences
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
        running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
        # 5. Get running sequences scores for next
        # Determine the top k beam indices (from top 2*k beams) from log probs
        # and gather top k beams (from top 2*k beams).
        # `jnp.flip` reverses the descending `top_k` order, so the best beam ends up last.
        next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
        next_running_sequences, next_running_scores = gather_beams(
            [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
        )

        # 6. Process topk logits
        # Further process log probs:
        # - add length penalty
        # - make sure no scores can be added anymore if beam is full
        # - make sure still running sequences cannot be chosen as finalized beam
        topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
        beams_in_batch_are_full = (
            jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
            & early_stopping
        )
        add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
        topk_log_probs += add_penalty * np.array(-1.0e7)

        # 7. Get scores, sequences, is sentence finished for next.
        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams
        merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
        merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
        merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
        topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
        next_sequences, next_scores, next_is_sent_finished = gather_beams(
            [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
        )

        # 8. Update model kwargs.
        # Determine the top k beam indices from the original set of all beams.
        # With these, gather the top k beam-associated caches.
        next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
        next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
        model_outputs["past_key_values"] = jax.tree_util.tree_map(flatten_beam_dim, next_cache)
        next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

        return BeamSearchState(
            cur_len=state.cur_len + 1,
            running_scores=next_running_scores,
            running_sequences=next_running_sequences,
            scores=next_scores,
            sequences=next_sequences,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
        )

    # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
    if input_ids.shape[-1] > 1:
        state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

    if not trace:
        state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
    else:
        state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return running sequences for that batch item.
    any_finished = jnp.any(state.is_sent_finished, axis=1)
    sequences = jnp.where(any_finished[:, None, None], state.sequences, state.running_sequences)
    scores = jnp.where(any_finished[:, None], state.scores, state.running_scores)

    # return all beams for each batch and the best score
    # (sequences keep their beam axis; the best score sits in the last beam slot)
    scores = scores[:, -1]

    return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
|
models/modeling_flax_wav2vec2.py
ADDED
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Flax Wav2Vec2 model."""
|
16 |
+
|
17 |
+
from functools import partial
|
18 |
+
from typing import Optional, Tuple, Union
|
19 |
+
|
20 |
+
import flax
|
21 |
+
import flax.linen as nn
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
from flax.core.frozen_dict import FrozenDict
|
25 |
+
from flax.linen import partitioning as nn_partitioning
|
26 |
+
from flax.linen.attention import dot_product_attention_weights
|
27 |
+
from jax import lax
|
28 |
+
|
29 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
30 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
31 |
+
from transformers.utils import ModelOutput
|
32 |
+
|
33 |
+
from models import Wav2Vec2Config
|
34 |
+
|
35 |
+
scan_with_axes = nn_partitioning.scan_with_axes
|
36 |
+
remat = nn_partitioning.remat
|
37 |
+
|
38 |
+
|
39 |
+
@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output container for the base Flax Wav2Vec2 model, optionally carrying per-layer hidden states and attention
    weights.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last layer of the model.
        extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
            Feature vectors emitted by the last convolutional layer of the model; `last_conv_dim` is the output
            dimension of that layer.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`: one entry for the
            embedding output followed by one entry per layer.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one per layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            These are the attention weights after the softmax, i.e. the weights used to average the values in the
            self-attention heads.
    """

    last_hidden_state: jnp.ndarray = None
    extract_features: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
|
67 |
+
|
68 |
+
|
69 |
+
WAV_2_VEC_2_START_DOCSTRING = r"""
|
70 |
+
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
71 |
+
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
|
72 |
+
Auli.
|
73 |
+
|
74 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
75 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
76 |
+
etc.)
|
77 |
+
|
78 |
+
This model is also a Flax Linen
|
79 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
80 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
81 |
+
|
82 |
+
Finally, this model supports inherent JAX features such as:
|
83 |
+
|
84 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
85 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
86 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
87 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
88 |
+
|
89 |
+
Parameters:
|
90 |
+
config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
|
91 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
92 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
93 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
94 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
95 |
+
`jax.numpy.bfloat16` (on TPUs).
|
96 |
+
|
97 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
98 |
+
specified all the computation will be performed with the given `dtype`.
|
99 |
+
|
100 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
101 |
+
parameters.**
|
102 |
+
|
103 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
104 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
105 |
+
"""
|
106 |
+
|
107 |
+
|
108 |
+
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
109 |
+
Args:
|
110 |
+
input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
111 |
+
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
|
112 |
+
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
113 |
+
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
|
114 |
+
and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
|
115 |
+
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
116 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
117 |
+
1]`:
|
118 |
+
|
119 |
+
- 1 for tokens that are **not masked**,
|
120 |
+
- 0 for tokens that are **masked**.
|
121 |
+
|
122 |
+
[What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
|
123 |
+
if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
|
124 |
+
has `config.return_attention_mask == False`, such as
|
125 |
+
[wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
|
126 |
+
passed to avoid degraded performance when doing batched inference. For such models `input_values` should
|
127 |
+
simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
|
128 |
+
different results depending on whether `input_values` is padded or not.
|
129 |
+
mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
130 |
+
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
131 |
+
masked extracted features in *config.proj_codevector_dim* space.
|
132 |
+
output_attentions (`bool`, *optional*):
|
133 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
134 |
+
tensors for more detail.
|
135 |
+
output_hidden_states (`bool`, *optional*):
|
136 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
137 |
+
more detail.
|
138 |
+
return_dict (`bool`, *optional*):
|
139 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
140 |
+
"""
|
141 |
+
|
142 |
+
|
143 |
+
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    """
    One stage of the convolutional feature encoder: 1D convolution, then LayerNorm, then the activation named by
    `config.feat_extract_activation`.

    Attributes:
        config: Model configuration; `conv_dim`, `conv_kernel`, `conv_stride` and `conv_bias` are indexed with
            `layer_id` to configure this stage.
        layer_id: Position of this stage inside the feature encoder.
        dtype: Computation dtype handed to the convolution and the layer norm.
    """

    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        conv_dims = self.config.conv_dim
        # The first stage consumes the single-channel waveform; every later stage consumes the output channels
        # of the stage *before* it, hence `layer_id - 1` (the original indexed `layer_id`, which is this stage's
        # own output width).
        self.in_conv_dim = conv_dims[self.layer_id - 1] if self.layer_id > 0 else 1
        self.out_conv_dim = conv_dims[self.layer_id]

        self.conv = nn.Conv(
            features=self.out_conv_dim,
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        """Apply convolution, layer norm and activation to `hidden_states` and return the result."""
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
|
169 |
+
|
170 |
+
|
171 |
+
class FlaxConvWithWeightNorm(nn.Module):
    """
    Grouped 1D convolution whose kernel is re-parameterised as a direction (`weight_v`) times a magnitude
    (`weight_g`).

    The `nn.Conv` created in `setup` is used only as a stateless function: its kernel and bias are supplied
    explicitly through `apply` on every call, built from the `weight_v`, `weight_g` and `bias` parameters of this
    module.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        hidden_size = self.config.hidden_size
        kernel_width = self.config.num_conv_pos_embeddings
        num_groups = self.config.num_conv_pos_embedding_groups

        self.conv = nn.Conv(
            features=hidden_size,
            kernel_size=(kernel_width,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=num_groups,
            dtype=self.dtype,
        )

        def init_weight_g(_):
            # the magnitude starts out as the norm of the direction tensor
            return jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]

        # NOTE: the declaration order of the parameters below is kept as-is on purpose.
        weight_shape = (hidden_size, hidden_size // num_groups, kernel_width)
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        self.weight_g = self.param("weight_g", init_weight_g)
        self.bias = self.param("bias", jax.nn.initializers.zeros, (hidden_size,))
        # amount of zero padding added on both sides of the time axis before the "VALID" convolution
        self.prev_padding = kernel_width // 2

    def _get_normed_weights(self):
        """Return `weight_v` divided by its norm over axes (0, 1) and rescaled by `weight_g`."""
        v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        unit_direction = jnp.divide(self.weight_v, v_norm)
        return jnp.multiply(unit_direction, self.weight_g)

    def __call__(self, hidden_states):
        """Zero-pad axis 1 of `hidden_states` on both sides and convolve it with the weight-normalised kernel."""
        pad = self.prev_padding
        # the kernel is stored as (features, features // groups, width) and transposed before it is given to the conv
        conv_params = {"kernel": self._get_normed_weights().T, "bias": self.bias}
        padded_states = jnp.pad(hidden_states, ((0, 0), (pad, pad), (0, 0)))
        return self.conv.apply({"params": conv_params}, padded_states)
|
205 |
+
|
206 |
+
|
207 |
+
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    """
    Convolutional positional embedding: a weight-normalised convolution followed by the activation named by
    `config.feat_extract_activation`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        # An even kernel width with symmetric padding yields one extra output frame, which is dropped in
        # `__call__`.
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        """Return the positional embedding computed from `hidden_states`.

        The original code wrapped this computation in two `transpose((0, 1, 2))` calls; that permutation is the
        identity, so they have been removed.
        """
        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            # trim the surplus trailing frame(s) along axis 1
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]

        return self.activation(hidden_states)
|
227 |
+
|
228 |
+
|
229 |
+
class FlaxConvLayersCollection(nn.Module):
    """
    Ordered stack of the feature-encoder convolution layers.

    Only `config.feat_extract_norm == "layer"` is implemented; `"group"` raises `NotImplementedError` and any other
    value raises `ValueError`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
            BlockLayer = (
                remat(FlaxWav2Vec2LayerNormConvLayer)
                if self.config.gradient_checkpointing
                else FlaxWav2Vec2LayerNormConvLayer
            )
            self.layers = [
                BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            # the message previously misspelled the config key as `feat_extact_norm`
            raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

    def __call__(self, hidden_states):
        """Feed `hidden_states` through every convolution layer in order and return the final output."""
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states
|
252 |
+
|
253 |
+
|
254 |
+
class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Turn a raw audio waveform into a sequence of convolutional features."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        """
        Run the convolution stack over `input_values`.

        A trailing axis of size one is appended to the input before the convolutions. When
        `freeze_feature_encoder` is truthy the output is wrapped in `stop_gradient`, so no gradient reaches the
        convolution parameters through it.
        """
        features = self.conv_layers(input_values[:, :, None])
        if not freeze_feature_encoder:
            return features
        return jax.lax.stop_gradient(features)
|
269 |
+
|
270 |
+
|
271 |
+
class FlaxWav2Vec2FeatureProjection(nn.Module):
    """Layer-normalise the extracted features and project them to `config.hidden_size`, followed by dropout."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        """
        Returns:
            A pair `(projected, normed)`: the projected (and dropout-regularised) hidden states, and the
            layer-normalised input features before projection.
        """
        normed = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(normed), deterministic=deterministic)
        return projected, normed
|
289 |
+
|
290 |
+
|
291 |
+
class FlaxWav2Vec2Attention(nn.Module):
    """
    Multi-head self-attention.

    Queries, keys and values are all projected from `hidden_states`, either through three separate dense layers or
    through a single fused dense layer when `config.fuse_matmuls` is set.

    Attributes:
        config: Model configuration (`fuse_matmuls` and `initializer_range` are read here).
        embed_dim: Size of the input and output feature axis; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        dropout: Dropout rate applied to the attention weights.
        bias: Whether the dense projections use a bias term.
        dtype: Computation dtype.
    """

    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        # single projection producing query, key and value at once; used when `config.fuse_matmuls` is set
        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        """Reshape `(batch, time, embed_dim)` into `(batch, time, num_heads, head_dim)`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the head axes back into `embed_dim`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: Input sequence from which queries, keys and values are projected.
            key_value_states: Accepted for interface compatibility only; it is not read by this implementation.
            attention_mask: Optional mask; positions with a value `> 0` may be attended to, all others are
                suppressed.
            deterministic: When `False` and `dropout > 0`, dropout is applied to the attention weights using
                the `"dropout"` RNG stream.

        Returns:
            A pair `(attn_output, attn_weights)`.
        """

        if self.config.fuse_matmuls:
            attention_states = self.fused_proj(hidden_states)
            query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        attention_bias = None
        if attention_mask is not None:
            # insert singleton head and query axes so the mask broadcasts over them
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

            # Convert the mask to an additive bias. The most negative finite value is used instead of `-inf`:
            # a row whose keys are all masked would otherwise make the softmax evaluate to NaN, while rows with
            # at least one visible key still give masked keys a weight of zero.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
392 |
+
|
393 |
+
|
394 |
+
class FlaxWav2Vec2FeedForward(nn.Module):
    """
    Position-wise feed-forward block: dense -> activation -> dropout -> dense -> dropout.

    `config.hidden_act` may be either the name of an activation registered in `ACT2FN` or a callable.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        hidden_act = self.config.hidden_act

        self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
        self.intermediate_dense = nn.Dense(self.config.intermediate_size, kernel_init=kernel_init, dtype=self.dtype)
        # resolve a string name through the activation registry; anything else is used as given
        self.intermediate_act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act

        self.output_dense = nn.Dense(self.config.hidden_size, kernel_init=kernel_init, dtype=self.dtype)
        self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the feed-forward block; dropout is active only when `deterministic` is `False`."""
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded, deterministic=deterministic)

        projected = self.output_dense(expanded)
        return self.output_dropout(projected, deterministic=deterministic)
|
426 |
+
|
427 |
+
|
428 |
+
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    """
    Pre-LayerNorm transformer encoder layer: each of the attention and feed-forward sub-blocks normalises its
    input and adds its output back onto the residual stream.

    When `config.use_scan` is set, the layer follows a scan-body calling convention: the incoming
    `hidden_states` is a 1-tuple carry and the return value is `(carry, None)`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxWav2Vec2Attention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        """
        Returns:
            `(hidden_states,)`, with the attention weights appended when `output_attentions` is set. Under
            `config.use_scan` that tuple is returned as the first element of a `(carry, None)` pair.
        """
        scanning = self.config.use_scan
        if scanning:
            # unwrap the scan carry
            hidden_states = hidden_states[0]

        # attention sub-block
        residual = hidden_states
        attn_output, attn_weights = self.attention(
            self.layer_norm(hidden_states), attention_mask=attention_mask, deterministic=deterministic
        )
        attn_output = self.dropout(attn_output, deterministic=deterministic)
        hidden_states = residual + attn_output

        # feed-forward sub-block
        ff_output = self.feed_forward(self.final_layer_norm(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + ff_output

        outputs = (hidden_states, attn_weights) if output_attentions else (hidden_states,)

        if scanning:
            return (outputs, None)
        return outputs
|
468 |
+
|
469 |
+
|
470 |
+
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    """
    Stack of `config.num_hidden_layers` encoder layers.

    The layers are either instantiated one by one (named `"0"`, `"1"`, ...) or, when `config.use_scan` is set,
    built as a single scanned module named `"FlaxWav2Vec2EncoderLayers"`. With `config.gradient_checkpointing`
    the layer class is wrapped in `remat`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run `hidden_states` through all encoder layers.

        Returns:
            A `FlaxBaseModelOutput` when `return_dict` is true; otherwise a tuple of the non-`None` values among
            `(hidden_states, all_hidden_states, all_attentions)`.

        Raises:
            AssertionError: if `config.use_scan` is combined with `output_attentions` or `output_hidden_states`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        num_layers = self.config.num_hidden_layers

        layer_cls = FlaxWav2Vec2EncoderLayerStableLayerNorm
        if self.config.gradient_checkpointing:
            layer_cls = remat(
                layer_cls,
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )

        if self.config.use_scan:
            # since all decoder layers are the same, we use nn.scan directly
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"

            scanned_layers = scan_with_axes(
                layer_cls,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_layers,
            )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers")

            # the carry is a 1-tuple holding the hidden states
            carry, _ = scanned_layers((hidden_states,), attention_mask, deterministic, output_attentions)
            hidden_states = carry[0]

        else:
            for layer_idx in range(num_layers):
                if output_hidden_states:
                    # record the input of every layer
                    all_hidden_states += (hidden_states,)

                encoder_layer = layer_cls(self.config, dtype=self.dtype, name=str(layer_idx))
                layer_outputs = encoder_layer(hidden_states, attention_mask, deterministic, output_attentions)

                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_attentions += (layer_outputs[1],)

            if output_hidden_states:
                # ... plus the output of the last layer
                all_hidden_states += (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
            )

        return tuple(v for v in (hidden_states, all_hidden_states, all_attentions) if v is not None)
|
542 |
+
|
543 |
+
|
544 |
+
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    """
    Transformer encoder: adds convolutional positional embeddings to the input, applies dropout, runs the stack of
    encoder layers and finishes with a layer norm.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Args:
            hidden_states: Projected features to encode.
            attention_mask: Optional mask; hidden states at positions where it is zero are set to 0 before the
                positional embedding is computed, and the mask is forwarded to the layers.
            deterministic: Disables dropout when `True`. It is forwarded to the encoder layers as well as used
                for the dropout applied here.
            output_attentions / output_hidden_states: Forwarded to the layer stack.
            return_dict: Return a `FlaxBaseModelOutput` instead of a plain tuple.

        Returns:
            A `FlaxBaseModelOutput`, or a tuple of its non-`None` fields when `return_dict` is false.
        """

        if attention_mask is not None:
            # zero the hidden states at padded positions
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # `deterministic` has to be passed on explicitly: the layer stack defaults to `deterministic=True`, so
        # omitting it silently disabled every dropout inside the transformer layers during training.
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.layer_norm(outputs[0])

        # update the last element in `hidden_states` after applying `layernorm` above
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )
|
598 |
+
|
599 |
+
|
600 |
+
class FlaxWav2Vec2Adapter(nn.Module):
    """
    Adapter applied to the encoder output: an optional dense projection plus layer norm (present only when
    `config.output_hidden_size` differs from `config.hidden_size`), followed by the adapter convolution layers.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # hidden_states require down-projection if feature dims don't match
        if self.config.output_hidden_size == self.config.hidden_size:
            self.proj = None
            self.proj_layer_norm = None
        else:
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        """Project `hidden_states` if a projection exists, then run the adapter layers.

        `deterministic` is part of the signature but is not read by this method.
        """
        has_projection = not (self.proj is None or self.proj_layer_norm is None)
        if has_projection:
            # down-project hidden_states if required
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        return self.layers(hidden_states)
|
627 |
+
|
628 |
+
|
629 |
+
class FlaxWav2Vec2AdapterLayer(nn.Module):
    """
    Single adapter layer: a strided 1D convolution producing `2 * config.output_hidden_size` channels, followed by
    a gated linear unit over axis 2.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        adapter_cfg = self.config
        self.conv = nn.Conv(
            features=2 * adapter_cfg.output_hidden_size,
            kernel_size=(adapter_cfg.adapter_kernel_size,),
            strides=(adapter_cfg.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(adapter_cfg.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        """Convolve `hidden_states` and gate the result with `nn.glu`."""
        conv_output = self.conv(hidden_states)
        return nn.glu(conv_output, axis=2)
|
648 |
+
|
649 |
+
|
650 |
+
class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    """Ordered stack of `config.num_adapter_layers` adapter layers, named `"0"`, `"1"`, ... ."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.gradient_checkpointing:
            # rematerialise activations in the backward pass instead of storing them
            BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer)
        else:
            BlockAdapterLayer = FlaxWav2Vec2AdapterLayer

        self.layers = [
            BlockAdapterLayer(self.config, name=str(layer_idx), dtype=self.dtype)
            for layer_idx in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states):
        """Apply every adapter layer in order and return the final hidden states."""
        for adapter_layer in self.layers:
            hidden_states = adapter_layer(hidden_states)

        return hidden_states
|
666 |
+
|
667 |
+
|
668 |
+
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
|
669 |
+
"""
|
670 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
671 |
+
models.
|
672 |
+
"""
|
673 |
+
|
674 |
+
config_class = Wav2Vec2Config
|
675 |
+
base_model_prefix: str = "wav2vec2"
|
676 |
+
main_input_name = "input_values"
|
677 |
+
module_class: nn.Module = None
|
678 |
+
|
679 |
+
def __init__(
|
680 |
+
self,
|
681 |
+
config: Wav2Vec2Config,
|
682 |
+
input_shape: Tuple = (1, 1024),
|
683 |
+
seed: int = 0,
|
684 |
+
dtype: jnp.dtype = jnp.float32,
|
685 |
+
_do_init: bool = True,
|
686 |
+
**kwargs,
|
687 |
+
):
|
688 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
689 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
690 |
+
|
691 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
692 |
+
# init input tensors
|
693 |
+
input_values = jnp.zeros(input_shape, dtype="i4")
|
694 |
+
attention_mask = jnp.ones_like(input_values)
|
695 |
+
params_rng, dropout_rng = jax.random.split(rng, 2)
|
696 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
697 |
+
|
698 |
+
return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
|
699 |
+
|
700 |
+
def __call__(
|
701 |
+
self,
|
702 |
+
input_values,
|
703 |
+
attention_mask=None,
|
704 |
+
mask_time_indices=None,
|
705 |
+
extract_features=None,
|
706 |
+
params: dict = None,
|
707 |
+
dropout_rng: jax.random.PRNGKey = None,
|
708 |
+
train: bool = False,
|
709 |
+
output_attentions: Optional[bool] = None,
|
710 |
+
output_hidden_states: Optional[bool] = None,
|
711 |
+
output_features: Optional[bool] = None,
|
712 |
+
freeze_feature_encoder: bool = False,
|
713 |
+
return_dict: Optional[bool] = None,
|
714 |
+
):
|
715 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
716 |
+
output_hidden_states = (
|
717 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
718 |
+
)
|
719 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
720 |
+
|
721 |
+
if attention_mask is None:
|
722 |
+
batch_size, sequence_length = input_values.shape
|
723 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
724 |
+
|
725 |
+
if extract_features is not None:
|
726 |
+
extract_features = jnp.array(extract_features, dtype="f4")
|
727 |
+
|
728 |
+
# Handle any PRNG if needed
|
729 |
+
rngs = {}
|
730 |
+
if dropout_rng is not None:
|
731 |
+
rngs["dropout"] = dropout_rng
|
732 |
+
|
733 |
+
inputs = {"params": params or self.params}
|
734 |
+
|
735 |
+
return self.module.apply(
|
736 |
+
inputs,
|
737 |
+
jnp.array(input_values, dtype="f4"),
|
738 |
+
jnp.array(attention_mask, dtype="i4"),
|
739 |
+
mask_time_indices,
|
740 |
+
extract_features,
|
741 |
+
not train,
|
742 |
+
output_attentions,
|
743 |
+
output_hidden_states,
|
744 |
+
output_features,
|
745 |
+
freeze_feature_encoder,
|
746 |
+
return_dict,
|
747 |
+
rngs=rngs,
|
748 |
+
)
|
749 |
+
|
750 |
+
def _get_feat_extract_output_lengths(
    self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
    """Delegate to the wrapped module's ``_get_feat_extract_output_lengths``.

    Args:
        input_lengths: Input length(s) to convert.
        add_adapter: Forwarded unchanged as a keyword argument.

    Returns:
        The value returned by the module's implementation.
    """
    module = self.module
    return module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
|
754 |
+
|
755 |
+
def _get_feature_vector_attention_mask(
    self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
):
    """Delegate to the wrapped module's ``_get_feature_vector_attention_mask``.

    Args:
        feature_vector_length: Length of the feature-vector axis of the mask
            to build.
        attention_mask: Input-level attention mask.
        add_adapter: Forwarded unchanged as a keyword argument.

    Returns:
        The value returned by the module's implementation.
    """
    module = self.module
    return module._get_feature_vector_attention_mask(
        feature_vector_length, attention_mask, add_adapter=add_adapter
    )
|
759 |
+
|
760 |
+
|
761 |
+
class FlaxWav2Vec2Module(nn.Module):
    """Base Wav2Vec2 module.

    Pipeline built in ``setup`` and applied in ``__call__``: feature encoder ->
    feature projection -> optional masking with a learned embedding ->
    stable-layer-norm encoder -> optional adapter.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate sub-modules and the learned mask embedding."""
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        # Vector of size ``hidden_size`` substituted at masked time steps.
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
        )

        # Only the stable-layer-norm encoder variant is implemented here.
        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Encode ``input_values`` (or precomputed ``extract_features``).

        Args:
            input_values: Inputs for the feature encoder; only used when
                ``extract_features`` is ``None``.
            attention_mask: Optional input-level mask; reduced to the
                feature-vector resolution before the encoder sees it.
            mask_time_indices: Optional mask indexed as ``[:, :, None]`` and
                broadcast to the hidden states; where it is truthy the hidden
                state is replaced by ``masked_spec_embed``.
            extract_features: Optional precomputed feature-encoder output.
            deterministic: Forwarded to the projection and the encoder.
            output_attentions, output_hidden_states: Forwarded to the encoder.
            output_features: When truthy, return the feature-encoder output
                immediately and skip everything else.
            freeze_feature_encoder: Forwarded to the feature encoder.
            return_dict: When falsy (including the default ``None``) a tuple is
                returned instead of ``FlaxWav2Vec2BaseModelOutput``.

        Returns:
            The raw extracted features (``output_features``), a tuple
            ``(hidden_states, extract_features, *encoder_outputs[1:])``, or a
            ``FlaxWav2Vec2BaseModelOutput``.
        """
        # forward pass through the feature extractor if features not specified
        if extract_features is None:
            extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        if output_features:
            return extract_features

        # make sure that no loss is computed on padded inputs
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            # (before any adapter, hence add_adapter=False)
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: Input length(s) before the convolutional layers.
            add_adapter: Whether to also account for the adapter layers;
                ``None`` means "use ``config.add_adapter``".
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # each adapter layer is treated as kernel size 1 with ``adapter_stride``
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Build a boolean mask of shape ``(batch, feature_vector_length)``.

        Positions up to each example's convolution output length are ``True``.
        """
        # Last element of the running sum == number of non-padded inputs
        # per example (equivalent to attention_mask.sum(-1)).
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations make sure that all values
        # before the output lengths indices are attended to:
        # 1) mark the last valid position, 2) reverse-cumsum fills everything before it
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
|
879 |
+
|
880 |
+
|
881 |
+
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    """Pretrained-model wrapper whose underlying module is ``FlaxWav2Vec2Module``."""

    module_class = FlaxWav2Vec2Module
|
883 |
+
|
884 |
+
|
885 |
+
class FlaxWav2Vec2ForCTCModule(nn.Module):
    """Wav2Vec2 encoder followed by dropout and a linear head over the vocabulary."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the base encoder, the final dropout and the output projection."""
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Compute vocabulary logits for every encoder frame.

        Args:
            input_values: Inputs for the base encoder.
            attention_mask: Optional input-level attention mask.
            mask_time_indices: Optional masking indices, forwarded to the encoder.
            extract_features: Optional precomputed feature-encoder output,
                forwarded to the encoder so it can skip feature extraction.
            deterministic: Disables dropout when ``True``.
            output_attentions, output_hidden_states: Forwarded to the encoder.
            output_features: Accepted so the positional call from the
                pretrained-model wrapper lines up; not used by this head.
            freeze_feature_encoder: Forwarded to the encoder.
            return_dict: When falsy (including ``None``) a tuple is returned.

        Returns:
            ``(logits, *outputs[2:])`` or a ``FlaxCausalLMOutput``.
        """
        # NOTE(review): ``output_features`` is deliberately not forwarded; the base
        # module would then return a bare array, which the indexing below does not
        # expect. Confirm whether feature-only output is needed for the CTC head.
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            # Previously dropped: precomputed features were silently ignored.
            extract_features=extract_features,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            # outputs[1] is the projected extract_features; it is not part of the CTC output
            return (logits,) + outputs[2:]

        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: Input length(s) before the convolutional layers.
            add_adapter: Whether to also account for the adapter layers;
                ``None`` means "use ``config.add_adapter``".
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # each adapter layer is treated as kernel size 1 with ``adapter_stride``
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Build a boolean mask of shape ``(batch, feature_vector_length)``.

        Positions up to each example's convolution output length are ``True``.
        """
        # Last element of the running sum == number of non-padded inputs
        # per example (equivalent to attention_mask.sum(-1)).
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations make sure that all values
        # before the output lengths indices are attended to:
        # 1) mark the last valid position, 2) reverse-cumsum fills everything before it
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
|
972 |
+
|
973 |
+
|
974 |
+
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    """Pretrained-model wrapper whose underlying module is ``FlaxWav2Vec2ForCTCModule``."""

    module_class = FlaxWav2Vec2ForCTCModule
|
preprocessor_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"do_normalize": true,
|
3 |
+
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
4 |
+
"feature_size": 1,
|
5 |
+
"padding_side": "right",
|
6 |
+
"padding_value": 0.0,
|
7 |
+
"processor_class": "Wav2Vec2Processor",
|
8 |
+
"return_attention_mask": true,
|
9 |
+
"sampling_rate": 16000
|
10 |
+
}
|
run_flax_speech_recognition_ctc.py
ADDED
@@ -0,0 +1,1398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
|
17 |
+
"""
|
18 |
+
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
19 |
+
|
20 |
+
import logging
|
21 |
+
import math
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
import time
|
25 |
+
from dataclasses import dataclass, field
|
26 |
+
from pathlib import Path
|
27 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
28 |
+
|
29 |
+
import datasets
|
30 |
+
import numpy as np
|
31 |
+
from datasets import DatasetDict, load_dataset, load_metric
|
32 |
+
from tqdm import tqdm
|
33 |
+
|
34 |
+
import flax
|
35 |
+
import jax
|
36 |
+
import jax.numpy as jnp
|
37 |
+
import optax
|
38 |
+
import transformers
|
39 |
+
import wandb as wandb
|
40 |
+
from flax import core, jax_utils, struct, traverse_util
|
41 |
+
from flax.jax_utils import unreplicate, pad_shard_unpad
|
42 |
+
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
43 |
+
from huggingface_hub import Repository
|
44 |
+
from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
|
45 |
+
from optax._src import linear_algebra
|
46 |
+
from transformers import (
|
47 |
+
AutoFeatureExtractor,
|
48 |
+
AutoProcessor,
|
49 |
+
AutoTokenizer,
|
50 |
+
HfArgumentParser,
|
51 |
+
TrainingArguments,
|
52 |
+
is_tensorboard_available,
|
53 |
+
)
|
54 |
+
from transformers.file_utils import get_full_repo_name
|
55 |
+
from transformers.utils import check_min_version
|
56 |
+
from transformers.utils.versions import require_version
|
57 |
+
|
58 |
+
|
59 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
60 |
+
check_min_version("4.17.0.dev0")
|
61 |
+
|
62 |
+
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
63 |
+
|
64 |
+
logger = logging.getLogger(__name__)
|
65 |
+
|
66 |
+
|
67 |
+
@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # --- which checkpoint / config / tokenizer / feature extractor to load ---
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    # --- fine-tuning behaviour and regularisation ---
    freeze_feature_encoder: bool = field(
        default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
    )
    activation_dropout: float = field(
        default=0.1,
        metadata={
            "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
        },
    )
    hidden_dropout: float = field(
        default=0.1,
        metadata={
            "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
        },
    )
    feat_proj_dropout: float = field(
        default=0.0,
        metadata={
            "help": "The feat proj dropout probability for feature encoder representations."
        },
    )
    mask_time_prob: float = field(
        default=0.1,
        metadata={
            "help": "The spec aug dropout probability for feature encoder representations."
        },
    )
|
131 |
+
|
132 |
+
|
133 |
+
@flax.struct.dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # --- dataset selection and caching ---
    dataset_name: str = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # --- optional truncation of each split ---
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    # --- column names ---
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
    )
    # --- length-based filtering ---
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"}
    )
    max_label_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be filtered."
        },
    )
    min_label_length: Optional[int] = field(
        default=0,
        metadata={
            "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
            "than this will be filtered."
        },
    )
    max_eval_duration_in_seconds: float = field(
        default=None,
        metadata={
            "help": "Filter audio files in the eval/test set that are longer than `max_eval_duration_in_seconds` seconds"
        },
    )
    # --- padding (fixed shapes avoid recompilation) ---
    pad_input_to_multiple_of: Optional[int] = field(
        default=32000,
        metadata={
            "help": "If set will pad the input sequence to a multiple of the provided value. "
            "This is important to avoid triggering recompilations on TPU."
        },
    )
    pad_target_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set will pad the target sequence to a multiple of the provided value. "
            "This is important to avoid triggering recompilations on TPU."
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. "
            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
            "so that the cached datasets can consequently be loaded in distributed training"
        },
    )
    # --- split names ---
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="validation",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    # --- Weights & Biases logging ---
    wandb_project: str = field(
        default="flax-speech-recognition-ctc",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_name: str = field(
        default=None,
        metadata={"help": "The name of the wandb run."},
    )
    wandb_job_type: str = field(
        default="CTC",
        metadata={"help": "The name of the wandb job type."},
    )
    test_split_name: str = field(
        default="test",
        metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
    )
|
268 |
+
|
269 |
+
|
270 |
+
# @flax.struct.dataclass
|
271 |
+
@dataclass
class FlaxTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with precision and gradient-accumulation options."""

    precision: str = field(
        default="full",
        metadata={
            # The two sentences were previously concatenated without a separator.
            "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision. "
            "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
        },
    )
    matmul_precision: str = field(
        default="default",
        metadata={
            "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
            "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
            "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
            "it only changes the behaviors of calls with no such argument provided. "
            "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
        },
    )
    multisteps: bool = field(
        default=False,
        metadata={
            "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
            "a custom gradient accumulation implementation will be employed."
        },
    )
|
297 |
+
|
298 |
+
|
299 |
+
def to_fp32(t):
    """Upcast every bfloat16 leaf of pytree ``t`` to float32.

    Leaves of any other dtype are returned unchanged. Every leaf is expected
    to expose ``.dtype`` and ``.astype``.
    """
    # ``jax.tree_map`` is deprecated/removed in recent JAX releases;
    # ``jax.tree_util.tree_map`` is the stable equivalent.
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
|
301 |
+
|
302 |
+
|
303 |
+
def to_bf16(t):
    """Downcast every float32 leaf of pytree ``t`` to bfloat16.

    Leaves of any other dtype are returned unchanged. Every leaf is expected
    to expose ``.dtype`` and ``.astype``.
    """
    # ``jax.tree_map`` is deprecated/removed in recent JAX releases;
    # ``jax.tree_util.tree_map`` is the stable equivalent.
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
|
305 |
+
|
306 |
+
|
307 |
+
class MixedPrecisionTrainState(struct.PyTreeNode):
|
308 |
+
"""Train state for use with a single Optax optimizer.
|
309 |
+
Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
|
310 |
+
|
311 |
+
Synopsis::
|
312 |
+
|
313 |
+
state = TrainState.create(
|
314 |
+
apply_fn=model.apply,
|
315 |
+
params=variables['params'],
|
316 |
+
tx=tx)
|
317 |
+
grad_fn = jax.grad(make_loss_fn(state.apply_fn))
|
318 |
+
for batch in data:
|
319 |
+
grads = grad_fn(state.params, batch)
|
320 |
+
state = state.apply_gradients(grads=grads)
|
321 |
+
|
322 |
+
Args:
|
323 |
+
step: Counter starts at 0 and is incremented by every call to
|
324 |
+
`.apply_gradients()`.
|
325 |
+
apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
|
326 |
+
convenience to have a shorter params list for the `train_step()` function
|
327 |
+
in your training loop.
|
328 |
+
params: The parameters to be updated by `tx` and used by `apply_fn`.
|
329 |
+
tx: An Optax gradient transformation.
|
330 |
+
opt_state: The state for `tx`.
|
331 |
+
dropout_rng: PRNG key for stochastic operations.
|
332 |
+
bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
|
333 |
+
"""
|
334 |
+
|
335 |
+
step: int
|
336 |
+
apply_fn: Callable = struct.field(pytree_node=False)
|
337 |
+
get_attention_mask_fn: Callable = struct.field(pytree_node=False)
|
338 |
+
params: core.FrozenDict[str, Any]
|
339 |
+
tx: optax.GradientTransformation = struct.field(pytree_node=False)
|
340 |
+
opt_state: optax.OptState
|
341 |
+
dropout_rng: jnp.ndarray
|
342 |
+
max_grad_norm: Optional[float] = 1.0
|
343 |
+
|
344 |
+
def apply_gradients(self, *, grads, to_dtype, **kwargs):
|
345 |
+
"""Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
|
346 |
+
|
347 |
+
Note that internally this function calls `.tx.update()` followed by a call
|
348 |
+
to `optax.apply_updates()` to update `params` and `opt_state`.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
grads: Gradients that have the same pytree structure as `.params`.
|
352 |
+
**kwargs: Additional dataclass attributes that should be `.replace()`-ed.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
An updated instance of `self` with `step` incremented by one, `params`
|
356 |
+
and `opt_state` updated by applying `grads`, and additional attributes
|
357 |
+
replaced as specified by `kwargs`.
|
358 |
+
"""
|
359 |
+
|
360 |
+
# clip gradients by global l2 norm
|
361 |
+
casted_max_grad_norm = to_dtype(self.max_grad_norm)
|
362 |
+
g_norm = linear_algebra.global_norm(grads)
|
363 |
+
g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
|
364 |
+
grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
|
365 |
+
|
366 |
+
# perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
|
367 |
+
# grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
|
368 |
+
updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
|
369 |
+
|
370 |
+
new_params = optax.apply_updates(self.params, updates)
|
371 |
+
return self.replace(
|
372 |
+
step=self.step + 1,
|
373 |
+
params=new_params,
|
374 |
+
opt_state=to_dtype(new_opt_state),
|
375 |
+
**kwargs,
|
376 |
+
)
|
377 |
+
|
378 |
+
@classmethod
def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
    """Build a fresh state at `step=0`; `opt_state` is initialised from `tx` when one is given, else `None`."""
    if tx is None:
        initial_opt_state = None
    else:
        # the optimizer state is initialised from the `to_dtype`-cast params
        # (downcast to bf16 if mixed-precision training)
        initial_opt_state = tx.init(to_dtype(params))

    return cls(
        step=0,
        apply_fn=apply_fn,
        params=params,
        tx=tx,
        opt_state=initial_opt_state,
        **kwargs,
    )
|
391 |
+
|
392 |
+
def replicate(self):
    """Return the `jax_utils.replicate`-d state with `dropout_rng` replaced by its `shard_prng_key` split."""
    replicated_state = jax_utils.replicate(self)
    sharded_dropout_rng = shard_prng_key(self.dropout_rng)
    return replicated_state.replace(dropout_rng=sharded_dropout_rng)
|
394 |
+
|
395 |
+
|
396 |
+
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs and the labels received.

    Args:
        processor ([`Wav2Vec2Processor`])
            The processor used for processing the data. Its `feature_extractor` pads the `input_values` and its
            `tokenizer` pads the labels.
        input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
            Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        label_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'max_length'`):
            Select a strategy to pad the returned label sequences (according to the model's padding side and padding index).
            See above for details.
        pad_input_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the input sequence to a multiple of the provided value.
        pad_to_multiple_of_label (:obj:`int`, `optional`):
            If set will pad the label sequence to a multiple of the provided value.
        max_input_length (:obj:`float`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_label_length (:obj:`float`, `optional`):
            Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
    """

    processor: Any
    input_padding: Union[bool, str] = "longest"
    label_padding: Union[bool, str] = "max_length"
    pad_input_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_label: Optional[int] = None
    max_input_length: Optional[float] = None
    max_label_length: Optional[float] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        """Pad a list of examples into one batch of NumPy arrays.

        Args:
            features: Examples, each providing an ``"input_values"`` entry and a ``"labels"`` entry.

        Returns:
            The padded output of the feature extractor with an additional ``"labels"`` entry in which every
            padded label position is set to -100.
        """
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # reformat list to dict and return NumPy arrays
        batch = self.processor.feature_extractor.pad(
            input_features,
            max_length=self.max_input_length,
            padding=self.input_padding,
            pad_to_multiple_of=self.pad_input_to_multiple_of,
            return_tensors="np",
        )

        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_label_length,
            padding=self.label_padding,
            pad_to_multiple_of=self.pad_to_multiple_of_label,
            return_tensors="np",
        )

        # keep the label ids wherever the tokenizer's attention mask is 1 and mark every other
        # (padded) position with -100
        batch["labels"] = np.where(labels_batch["attention_mask"] == 1, labels_batch["input_ids"], -100)

        return batch
|
467 |
+
|
468 |
+
|
469 |
+
def get_grouped_indices(
    dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
) -> np.ndarray:
    """
    Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
    Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
    lengths. To do this, the indices are:

    - randomly permuted (if a JAX rng is specified)
    - grouped in mega-batches of size `mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the mega-batch containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.

    Args:
        dataset: Object whose `"input_length"` entry holds one length per sample.
        batch_size: Number of samples per batch.
        rng: JAX PRNG key used to permute the indices; the indices keep their natural order when `None`.
        mega_batch_mult: Number of batches per mega-batch. Defaults to 50 or the number that yields
            4 mega-batches, whichever is smaller (and at least 1).

    Returns:
        1-D array of sample indices. It is empty when the dataset has no samples.
    """
    lengths = dataset["input_length"]
    num_samples = len(lengths)

    # Nothing to group; without this guard `np.argmax` below raises on an empty sequence.
    if num_samples == 0:
        return np.array([], dtype=int)

    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(num_samples // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
    indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)

    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, num_samples, megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = np.argmax(megabatch_maximums).item()
    # Switch to put the longest batch in first position
    # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
    megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]

    return np.array([i for megabatch in megabatches for i in megabatch])
|
512 |
+
|
513 |
+
|
514 |
+
def generate_batch_splits(
    samples_idx: np.ndarray, batch_size: int, drop_last=True
) -> Union[np.ndarray, List[np.ndarray]]:
    """Generate batches of data for a specified batch size from sample indices.

    Args:
        samples_idx: 1-D array of sample indices.
        batch_size: Number of samples per batch.
        drop_last: If `True`, trailing samples that do not fill a whole batch are dropped. If `False`, all
            samples are kept and split into `ceil(num_samples / batch_size)` batches with `np.array_split`,
            i.e. the batches are of near-equal size and never larger than `batch_size`.

    Returns:
        A `(num_batches, batch_size)` array when `drop_last` is `True`, otherwise a list of 1-D arrays
        (an empty list when there are no samples).
    """
    num_samples = len(samples_idx)

    if drop_last:
        samples_to_remove = num_samples % batch_size
        if samples_to_remove != 0:
            samples_idx = samples_idx[:-samples_to_remove]
        sections_split = num_samples // batch_size
        return samples_idx.reshape((sections_split, batch_size))

    # `np.array_split` rejects a split into 0 sections, so handle the empty case explicitly
    if num_samples == 0:
        return []

    sections_split = math.ceil(num_samples / batch_size)
    return np.array_split(samples_idx, sections_split)
|
528 |
+
|
529 |
+
|
530 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """Write the elapsed train time and every accumulated train metric as scalars.

    The values of each metric are written at consecutive steps ending at `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    collected_metrics = get_metrics(train_metrics)
    for metric_name, metric_values in collected_metrics.items():
        scalar_tag = f"train_{metric_name}"
        num_values = len(metric_values)
        for offset, metric_value in enumerate(metric_values):
            # the last value lands on `step`, earlier values on the steps just before it
            summary_writer.scalar(scalar_tag, metric_value, step - num_values + offset + 1)
|
538 |
+
|
539 |
+
|
540 |
+
def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
    """Write every eval metric as an `eval_`-prefixed scalar and, if given, the predictions as text."""
    for name, metric_value in eval_metrics.items():
        scalar_tag = f"eval_{name}"
        summary_writer.scalar(scalar_tag, metric_value, step)

    if pred_str is None:
        return

    # write output actual predictions for debugging, one prediction per line
    joined_predictions = "\n".join(pred_str)
    summary_writer.text("eval_predictions", joined_predictions, step)
|
547 |
+
|
548 |
+
|
549 |
+
def write_wandb_log(metrics, step, prefix=None):
    """Log `metrics` to wandb at `step`; does nothing unless `jax.process_index()` is 0.

    Keys containing "layer" are logged as `"<key>/"`; any other key is logged as
    `"<prefix>/<key>"` when a prefix is given and unchanged otherwise.
    """
    if jax.process_index() != 0:
        return

    def _wandb_key(name):
        if "layer" in name:
            return f"{name}/"
        if prefix is not None:
            return f"{prefix}/{name}"
        return name

    renamed_metrics = {_wandb_key(name): value for name, value in metrics.items()}
    wandb.log(renamed_metrics, step)
|
560 |
+
|
561 |
+
|
562 |
+
def write_wandb_pred(pred_str, label_str, step, final_step=False, prefix="eval"):
    """Log label/prediction pairs to wandb as a two-column table; does nothing unless `jax.process_index()` is 0.

    Intermediate steps log the first 50 rows only; the final step logs every row under a key
    suffixed with `_all`.
    """
    if jax.process_index() != 0:
        return

    # convert str data to a wandb compatible format: one [label, prediction] row per prediction
    rows = [[label_str[idx], prediction] for idx, prediction in enumerate(pred_str)]

    if final_step:
        # we'll log all predictions for the last epoch
        table_key = f"{prefix}/step_{int(step / 1000)}k_all"
    else:
        # we'll log the first 50 predictions for each intermediate epoch
        table_key = f"{prefix}/step_{int(step / 1000)}k"
        rows = rows[:50]

    table = wandb.Table(columns=["label_str", "pred_str"], data=rows)
    wandb.log({table_key: table}, step)
|
586 |
+
|
587 |
+
|
588 |
+
def create_learning_rate_fn(
    num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a schedule that warms up linearly from 0 to `learning_rate`, then decays linearly to 0."""
    num_decay_steps = num_train_steps - num_warmup_steps
    warmup_then_decay = [
        optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps),
        optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_decay_steps),
    ]
    # the decay schedule takes over once `num_warmup_steps` steps have elapsed
    return optax.join_schedules(schedules=warmup_then_decay, boundaries=[num_warmup_steps])
|
598 |
+
|
599 |
+
|
600 |
+
def ctc_loss(
    logits,
    logits_attention_mask,
    labels,
    blank_id,
    loss_reduction="mean",
    output_emission_dict=False,
    log_epsilon=-100000.0,
):
    """Computes CTC loss.
    This function performs forward computation over an FSA with `N * 2` states
    where `N` is the max number of labels. The states are split into two groups:
    Phi states and emission states. a phi-state accepts repetition of
    phi (blank)-symbols and transits to emission state when the correct label is
    observed. An emission state accepts repetition of the label and transits to
    the next phi states at any time (so called epsilon-transition).
    Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
    and `N` denotes the time steps in `labels`.
    Args:
      logits: (B, T, K)-array containing the class scores; they are normalised
        with `log_softmax` inside this function.
      logits_attention_mask: (B, T)-array that is non-zero/True for valid frames
        and zero/False for padded frames. Boolean and integer masks are accepted.
      labels: (B, N)-array containing reference integer labels. Negative entries
        (-100) mark padding. `labels` must be right-padded, i.e. each row must be
        the real labels followed by the padding entries.
      blank_id: Id for blank token.
      loss_reduction: one of "mean", "sum"; any other value applies no reduction.
        - "mean": output loss will be divided by target lengths (clamped to a
          minimum of 1) and then the mean over the batch is taken.
        - "sum": output loss are summed over batch
      output_emission_dict: whether to output additional information about the emission probs
      log_epsilon: Large negative number used as the log-probability of
        impossible states/transitions.
    Returns:
      `loss` when `output_emission_dict` is False, otherwise a pair `(loss, aux)`.
      loss:
        Scalar for the "mean" and "sum" reductions, otherwise a (B,)-array
        containing loss values for each sequence in the batch.
      aux: Dictionary containing interim variables used for computing losses.
        aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
          phi-state corresponding to the n-th label.
        aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
          emission-state corresponding to the n-th label.
        aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
          corresponding to each time frame.
        aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
          corresponding to each time frame.
    """
    # label paddings are indicated by -100
    labelpaddings = labels < 0
    # logit paddings are the inverse of attention_mask. `logical_not` is used instead of `~`, because `~`
    # is a bitwise complement that maps an integer 0/1 mask to -1/-2 instead of inverting it.
    logitpaddings = jnp.logical_not(logits_attention_mask).astype(jnp.float32)

    # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
    batchsize, unused_maxinputlen, num_classes = logits.shape
    batchsize_, maxlabellen = labels.shape

    logprobs = jax.nn.log_softmax(logits)
    labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)

    # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
    repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
    repeat = jnp.pad(repeat, ((0, 0), (0, 1)))

    logprobs_phi = logprobs[:, :, blank_id : blank_id + 1]  # [B, T, 1]
    logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]

    one_hot = jax.nn.one_hot(labels, num_classes=num_classes)  # [B, N, K]
    logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
    logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]

    logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon  # [B, N+1]
    logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
    logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon  # [B, N]

    def loop_body(prev, x):
        prev_phi, prev_emit = prev
        # emit-to-phi epsilon transition, except if the next label is repetition
        prev_phi_orig = prev_phi
        prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))

        logprob_emit, logprob_phi, pad = x

        # phi-to-emit transition
        next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
        # self-loop transition
        next_phi = prev_phi + logprob_phi
        # emit-to-phi blank transition only when the next label is repetition
        next_phi = next_phi.at[:, 1:].set(
            jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
        )

        # padded frames carry the previous state forward unchanged
        pad = pad.reshape((batchsize, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)

    xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
    _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)

    # last row needs to be updated with the last epsilon transition
    logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
    logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)

    # extract per_seq_loss: pick the phi-state reached after consuming all real labels of each sequence
    one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1)  # [B, N+1]
    per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)

    if loss_reduction == "mean":
        # clamp to 1 so that a sequence without any label does not divide by zero
        target_lengths = jnp.maximum(labellens, 1)
        loss = (per_seq_loss / target_lengths).mean()
    elif loss_reduction == "sum":
        loss = per_seq_loss.sum()
    else:
        loss = per_seq_loss

    if not output_emission_dict:
        return loss

    return loss, {
        "logalpha_phi": logalpha_phi,
        "logalpha_emit": logalpha_emit,
        "logprobs_phi": logprobs_phi,
        "logprobs_emit": logprobs_emit,
    }
|
724 |
+
|
725 |
+
|
726 |
+
def main():
|
727 |
+
# 1. Parse input arguments
|
728 |
+
# See all possible arguments in src/transformers/training_args.py
|
729 |
+
# or by passing the --help flag to this script.
|
730 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
731 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
|
732 |
+
|
733 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
734 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
735 |
+
# let's parse it to get our arguments.
|
736 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
737 |
+
else:
|
738 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
739 |
+
|
740 |
+
# 2. Setup logging
|
741 |
+
# Make one log on every process with the configuration for debugging.
|
742 |
+
logging.basicConfig(
|
743 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
744 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
745 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
746 |
+
)
|
747 |
+
# Set the verbosity to info of the Transformers logger.
|
748 |
+
# We only want one process per machine to log things on the screen.
|
749 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
750 |
+
if jax.process_index() == 0:
|
751 |
+
datasets.utils.logging.set_verbosity_warning()
|
752 |
+
transformers.utils.logging.set_verbosity_info()
|
753 |
+
else:
|
754 |
+
datasets.utils.logging.set_verbosity_error()
|
755 |
+
transformers.utils.logging.set_verbosity_error()
|
756 |
+
|
757 |
+
# Set up wandb run
|
758 |
+
if jax.process_index() == 0:
|
759 |
+
wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
|
760 |
+
|
761 |
+
logger.info("Training/evaluation parameters %s", training_args)
|
762 |
+
|
763 |
+
# Set the default TPU matmul precision and display the number of devices
|
764 |
+
jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
|
765 |
+
logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
|
766 |
+
|
767 |
+
# 4. Load dataset
|
768 |
+
raw_datasets = DatasetDict()
|
769 |
+
|
770 |
+
if training_args.do_train:
|
771 |
+
raw_datasets["train"] = load_dataset(
|
772 |
+
data_args.dataset_name,
|
773 |
+
data_args.dataset_config_name,
|
774 |
+
split=data_args.train_split_name,
|
775 |
+
cache_dir=data_args.dataset_cache_dir,
|
776 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
777 |
+
)
|
778 |
+
|
779 |
+
if training_args.do_eval:
|
780 |
+
raw_datasets["eval"] = load_dataset(
|
781 |
+
data_args.dataset_name,
|
782 |
+
data_args.dataset_config_name,
|
783 |
+
split=data_args.eval_split_name,
|
784 |
+
cache_dir=data_args.dataset_cache_dir,
|
785 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
786 |
+
)
|
787 |
+
|
788 |
+
if training_args.do_predict:
|
789 |
+
test_split = data_args.test_split_name.split("+")
|
790 |
+
for split in test_split:
|
791 |
+
raw_datasets[split] = load_dataset(
|
792 |
+
data_args.dataset_name,
|
793 |
+
data_args.dataset_config_name,
|
794 |
+
split=split,
|
795 |
+
cache_dir=data_args.dataset_cache_dir,
|
796 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
797 |
+
)
|
798 |
+
|
799 |
+
if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
|
800 |
+
raise ValueError(
|
801 |
+
"Cannot not train, not do evaluation and not do prediction. At least one of "
|
802 |
+
"training, evaluation or prediction has to be done."
|
803 |
+
)
|
804 |
+
|
805 |
+
# if not training, there is no need to run multiple epochs
|
806 |
+
if not training_args.do_train:
|
807 |
+
training_args.num_train_epochs = 1
|
808 |
+
|
809 |
+
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
|
810 |
+
raise ValueError(
|
811 |
+
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
|
812 |
+
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
813 |
+
f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
|
814 |
+
)
|
815 |
+
|
816 |
+
if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
|
817 |
+
raise ValueError(
|
818 |
+
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
819 |
+
"Make sure to set `--text_column_name` to the correct text column - one of "
|
820 |
+
f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
|
821 |
+
)
|
822 |
+
|
823 |
+
# 5. Load pretrained model, tokenizer, and feature extractor
|
824 |
+
#
|
825 |
+
# Distributed training:
|
826 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
827 |
+
config = Wav2Vec2Config.from_pretrained(
|
828 |
+
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
829 |
+
cache_dir=model_args.cache_dir,
|
830 |
+
revision=model_args.model_revision,
|
831 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
832 |
+
)
|
833 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
834 |
+
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
835 |
+
cache_dir=model_args.cache_dir,
|
836 |
+
revision=model_args.model_revision,
|
837 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
838 |
+
)
|
839 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
840 |
+
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
841 |
+
cache_dir=model_args.cache_dir,
|
842 |
+
revision=model_args.model_revision,
|
843 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
844 |
+
)
|
845 |
+
# update config according to training args, model args, and tokenizer attributes
|
846 |
+
config.update(
|
847 |
+
{
|
848 |
+
"gradient_checkpointing": training_args.gradient_checkpointing,
|
849 |
+
"activation_dropout": model_args.activation_dropout,
|
850 |
+
"hidden_dropout": model_args.hidden_dropout,
|
851 |
+
"feat_proj_dropout": model_args.feat_proj_dropout,
|
852 |
+
"mask_time_prob": model_args.mask_time_prob,
|
853 |
+
"vocab_size": tokenizer.vocab_size,
|
854 |
+
}
|
855 |
+
)
|
856 |
+
|
857 |
+
if training_args.precision == "full_mixed":
|
858 |
+
dtype = jnp.bfloat16
|
859 |
+
training_args.mixed_precision = True
|
860 |
+
elif training_args.precision == "half_mixed":
|
861 |
+
dtype = jnp.bfloat16
|
862 |
+
training_args.mixed_precision = False
|
863 |
+
else:
|
864 |
+
dtype = jnp.float32
|
865 |
+
training_args.mixed_precision = False
|
866 |
+
|
867 |
+
model = FlaxWav2Vec2ForCTC.from_pretrained(
|
868 |
+
model_args.model_name_or_path,
|
869 |
+
config=config,
|
870 |
+
dtype=dtype,
|
871 |
+
cache_dir=model_args.cache_dir,
|
872 |
+
revision=model_args.model_revision,
|
873 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
874 |
+
)
|
875 |
+
|
876 |
+
# 6. Resample speech dataset ALWAYS
|
877 |
+
raw_datasets = raw_datasets.cast_column(
|
878 |
+
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
879 |
+
)
|
880 |
+
|
881 |
+
# 7. Preprocessing the datasets.
|
882 |
+
# We need to read the audio files as arrays and tokenize the targets.
|
883 |
+
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
884 |
+
min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
885 |
+
max_eval_input_length = int(data_args.max_eval_duration_in_seconds * feature_extractor.sampling_rate) if data_args.max_eval_duration_in_seconds else None
|
886 |
+
max_target_length = data_args.max_label_length
|
887 |
+
min_target_length = data_args.min_label_length
|
888 |
+
pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
|
889 |
+
audio_column_name = data_args.audio_column_name
|
890 |
+
num_workers = data_args.preprocessing_num_workers
|
891 |
+
text_column_name = data_args.text_column_name
|
892 |
+
model_input_name = feature_extractor.model_input_names[0]
|
893 |
+
|
894 |
+
if training_args.do_train and data_args.max_train_samples is not None:
|
895 |
+
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
896 |
+
|
897 |
+
if training_args.do_eval and data_args.max_eval_samples is not None:
|
898 |
+
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
899 |
+
|
900 |
+
if training_args.do_predict and data_args.max_test_samples is not None:
|
901 |
+
for split in test_split:
|
902 |
+
raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
|
903 |
+
|
904 |
+
def prepare_dataset(batch):
|
905 |
+
# Pre-process audio
|
906 |
+
sample = batch[audio_column_name]
|
907 |
+
# normalise audio (mean, std) to (0, 1)
|
908 |
+
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
909 |
+
# process audio length
|
910 |
+
batch[model_input_name] = inputs.input_values[0]
|
911 |
+
batch["input_length"] = len(batch["input_values"])
|
912 |
+
|
913 |
+
input_str = batch[text_column_name]
|
914 |
+
batch["labels"] = tokenizer(input_str).input_ids
|
915 |
+
batch["labels_length"] = len(batch["labels"])
|
916 |
+
return batch
|
917 |
+
|
918 |
+
vectorized_datasets = raw_datasets.map(
|
919 |
+
prepare_dataset,
|
920 |
+
remove_columns=next(iter(raw_datasets.values())).column_names,
|
921 |
+
num_proc=num_workers,
|
922 |
+
desc="preprocess dataset",
|
923 |
+
)
|
924 |
+
|
925 |
+
# filter training data with inputs longer than max_input_length
|
926 |
+
def is_audio_in_length_range(length):
|
927 |
+
return min_input_length < length < max_input_length
|
928 |
+
|
929 |
+
if training_args.do_train:
|
930 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].filter(
|
931 |
+
is_audio_in_length_range,
|
932 |
+
num_proc=num_workers,
|
933 |
+
input_columns=["input_length"],
|
934 |
+
)
|
935 |
+
|
936 |
+
# filter data with targets shorter than min_target_length or longer than max_target_length
|
937 |
+
def is_labels_in_length_range(length):
|
938 |
+
return min_target_length < length < max_target_length
|
939 |
+
|
940 |
+
if training_args.do_train:
|
941 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].filter(
|
942 |
+
is_labels_in_length_range,
|
943 |
+
num_proc=num_workers,
|
944 |
+
input_columns=["labels_length"],
|
945 |
+
)
|
946 |
+
|
947 |
+
|
948 |
+
if max_eval_input_length is not None:
|
949 |
+
# filter training data with inputs longer than max_input_length
|
950 |
+
def is_eval_audio_in_length_range(length):
|
951 |
+
return min_input_length < length < max_eval_input_length
|
952 |
+
|
953 |
+
if training_args.do_eval:
|
954 |
+
vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
|
955 |
+
is_eval_audio_in_length_range,
|
956 |
+
num_proc=num_workers,
|
957 |
+
input_columns=["input_length"],
|
958 |
+
)
|
959 |
+
|
960 |
+
if training_args.do_predict:
|
961 |
+
for split in test_split:
|
962 |
+
vectorized_datasets[split] = vectorized_datasets[split].filter(
|
963 |
+
is_eval_audio_in_length_range,
|
964 |
+
num_proc=num_workers,
|
965 |
+
input_columns=["input_length"],
|
966 |
+
)
|
967 |
+
|
968 |
+
# for large datasets it is advised to run the preprocessing on a
|
969 |
+
# single machine first with `args.preprocessing_only` since there will mostly likely
|
970 |
+
# be a timeout when running the script in distributed mode.
|
971 |
+
# In a second step `args.preprocessing_only` can then be set to `False` to load the
|
972 |
+
# cached dataset
|
973 |
+
if data_args.preprocessing_only:
|
974 |
+
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
|
975 |
+
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
|
976 |
+
return
|
977 |
+
|
978 |
+
# 8. Load Metrics
|
979 |
+
wer_metric = load_metric("wer")
|
980 |
+
cer_metric = load_metric("cer")
|
981 |
+
|
982 |
+
def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
|
983 |
+
padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
|
984 |
+
|
985 |
+
pred_str = tokenizer.batch_decode(pred_ids)
|
986 |
+
# we do not want to group tokens when computing the metrics
|
987 |
+
label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
|
988 |
+
|
989 |
+
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
990 |
+
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
991 |
+
|
992 |
+
return {"wer": wer, "cer": cer}, pred_str, label_str
|
993 |
+
|
994 |
+
# 9. save feature extractor, tokenizer and config
|
995 |
+
feature_extractor.save_pretrained(training_args.output_dir)
|
996 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
997 |
+
config.save_pretrained(training_args.output_dir)
|
998 |
+
|
999 |
+
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
1000 |
+
|
1001 |
+
data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
|
1002 |
+
processor=processor,
|
1003 |
+
input_padding="longest",
|
1004 |
+
pad_input_to_multiple_of=pad_input_to_multiple_of,
|
1005 |
+
max_label_length=data_args.max_label_length,
|
1006 |
+
)
|
1007 |
+
|
1008 |
+
# Enable tensorboard only on the master node
|
1009 |
+
has_tensorboard = is_tensorboard_available()
|
1010 |
+
if has_tensorboard and jax.process_index() == 0:
|
1011 |
+
try:
|
1012 |
+
from flax.metrics.tensorboard import SummaryWriter
|
1013 |
+
|
1014 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
1015 |
+
except ImportError as ie:
|
1016 |
+
has_tensorboard = False
|
1017 |
+
logger.warning(
|
1018 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
1019 |
+
)
|
1020 |
+
else:
|
1021 |
+
logger.warning(
|
1022 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
1023 |
+
"Please run `pip install tensorboard` to enable."
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
# 10. Handle the repository creation
|
1027 |
+
if training_args.push_to_hub:
|
1028 |
+
with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
|
1029 |
+
git_lfs_extensions = f.read()
|
1030 |
+
if "*.wandb" not in git_lfs_extensions:
|
1031 |
+
f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
|
1032 |
+
if training_args.hub_model_id is None:
|
1033 |
+
repo_name = get_full_repo_name(
|
1034 |
+
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
1035 |
+
)
|
1036 |
+
else:
|
1037 |
+
repo_name = training_args.hub_model_id
|
1038 |
+
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
1039 |
+
|
1040 |
+
# 11. Initialize our training
|
1041 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
1042 |
+
rng, dropout_rng = jax.random.split(rng)
|
1043 |
+
|
1044 |
+
# Store some constants
|
1045 |
+
max_steps = int(training_args.max_steps)
|
1046 |
+
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
|
1047 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
1048 |
+
batch_size_per_update = train_batch_size * gradient_accumulation_steps
|
1049 |
+
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
1050 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
1051 |
+
to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
|
1052 |
+
|
1053 |
+
if training_args.do_train:
|
1054 |
+
num_train_samples = len(vectorized_datasets["train"])
|
1055 |
+
steps_per_epoch = num_train_samples // batch_size_per_update
|
1056 |
+
if max_steps > 0:
|
1057 |
+
num_epochs = -(training_args.max_steps // -steps_per_epoch)
|
1058 |
+
total_train_steps = max_steps
|
1059 |
+
else:
|
1060 |
+
num_epochs = int(training_args.num_train_epochs)
|
1061 |
+
total_train_steps = steps_per_epoch * num_epochs
|
1062 |
+
|
1063 |
+
# Create learning rate schedule
|
1064 |
+
# (built from the total number of training steps, the warmup steps and the learning rate)
|
1065 |
+
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
1066 |
+
total_train_steps,
|
1067 |
+
training_args.warmup_steps,
|
1068 |
+
training_args.learning_rate,
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
1072 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
1073 |
+
# mask boolean with the same structure as the parameters.
|
1074 |
+
# The mask is True for parameters that should be decayed.
|
1075 |
+
# Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
|
1076 |
+
# For FlaxT5, one should correct the layer norm parameter naming
|
1077 |
+
# accordingly - see `run_t5_mlm_flax.py` e.g.
|
1078 |
+
def decay_mask_fn(params):
    """Return a boolean mask with the same structure as `params`.

    An entry is True when the parameter should receive weight decay. Parameters
    whose last path component is "bias", and the "scale" of the layer-norm
    modules listed below, are excluded.
    """
    no_decay_suffixes = {
        (module_name, "scale")
        for module_name in ("layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm")
    }

    flat_mask = {}
    for path in traverse_util.flatten_dict(params):
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] in no_decay_suffixes
        flat_mask[path] = not (is_bias or is_layer_norm_scale)

    return traverse_util.unflatten_dict(flat_mask)
|
1086 |
+
|
1087 |
+
if training_args.adafactor:
|
1088 |
+
# Create Adafactor optimizer
|
1089 |
+
optim = optax.adafactor(
|
1090 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
1091 |
+
dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
|
1092 |
+
weight_decay_rate=training_args.weight_decay,
|
1093 |
+
weight_decay_mask=decay_mask_fn,
|
1094 |
+
)
|
1095 |
+
else:
|
1096 |
+
# Create AdamW optimizer
|
1097 |
+
optim = optax.adamw(
|
1098 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
1099 |
+
b1=training_args.adam_beta1,
|
1100 |
+
b2=training_args.adam_beta2,
|
1101 |
+
eps=training_args.adam_epsilon,
|
1102 |
+
weight_decay=training_args.weight_decay,
|
1103 |
+
mask=decay_mask_fn,
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
# Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
|
1107 |
+
if training_args.multisteps and gradient_accumulation_steps > 1:
|
1108 |
+
optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
|
1109 |
+
else:
|
1110 |
+
num_epochs = 0
|
1111 |
+
total_train_steps = 0
|
1112 |
+
num_train_samples = 0
|
1113 |
+
optim = None
|
1114 |
+
|
1115 |
+
# Setup train state
|
1116 |
+
state = MixedPrecisionTrainState.create(
|
1117 |
+
apply_fn=model.__call__,
|
1118 |
+
get_attention_mask_fn=model._get_feature_vector_attention_mask,
|
1119 |
+
params=model.params,
|
1120 |
+
tx=optim,
|
1121 |
+
to_dtype=to_dtype,
|
1122 |
+
dropout_rng=dropout_rng,
|
1123 |
+
max_grad_norm=training_args.max_grad_norm,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
# Replicate the train state on each device
|
1127 |
+
state = state.replicate()
|
1128 |
+
blank_id = model.config.pad_token_id
|
1129 |
+
|
1130 |
+
# Define gradient update step fn
|
1131 |
+
def train_step(state, batch):
    """Compute gradients of the CTC loss on `batch`, apply them and return `(new_state, metrics)`.

    Two gradient paths exist:
      * `gradient_accumulation_steps == 1` or `training_args.multisteps`: a single
        value-and-grad call over the whole batch.
      * otherwise: the batch is reshaped to
        `(gradient_accumulation_steps, per_device_train_batch_size, ...)` and the
        per-minibatch gradients are summed with `jax.lax.scan`.

    The returned metrics (loss, learning rate, per-layer and global gradient and
    parameter norms) are averaged over the "batch" pmap axis.
    """
    # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params, minibatch):
        # split the labels off without mutating the dict that was passed in
        labels = minibatch["labels"]
        inputs = {key: value for key, value in minibatch.items() if key != "labels"}
        logits = state.apply_fn(
            **inputs,
            params=params,
            dropout_rng=dropout_rng,
            freeze_feature_encoder=model_args.freeze_feature_encoder,
            train=True,
        )[0]
        # The mask has to come from the minibatch that produced these logits. With custom
        # gradient accumulation the outer `batch` carries an extra leading accumulation
        # axis, so its attention mask does not line up with the minibatch logits.
        logits_mask = state.get_attention_mask_fn(logits.shape[1], minibatch["attention_mask"])
        loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")

        return loss

    grad_fn = jax.value_and_grad(compute_loss)

    if gradient_accumulation_steps == 1 or training_args.multisteps:
        loss, grad = grad_fn(to_dtype(state.params), batch)

    # Custom gradient accumulation
    else:
        # add a first dimension over gradient_accumulation_steps for minibatch slices
        batch = jax.tree_map(
            lambda x: x.reshape(
                gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
            ),
            batch,
        )

        def accum_minibatch_step(accum_grad, minibatch):
            # compute loss and grad over the minibatch and add the grad to the running sum
            loss, grad = grad_fn(to_dtype(state.params), minibatch)
            return jax.tree_map(jnp.add, accum_grad, grad), loss

        # the scan carry is the running gradient sum, initialised to zeros
        init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
        # loop accum minibatch step over the number of gradient accumulation steps
        # NOTE(review): `loss` is the stacked scan output here (one entry per minibatch),
        # not a scalar as in the branch above - confirm the logging code expects that.
        grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)

    # update state
    # NOTE(review): gradients are not explicitly reduced across devices in this function;
    # presumably `apply_gradients` takes care of it - confirm.
    new_state = state.apply_gradients(
        grads=grad,
        dropout_rng=new_dropout_rng,
        to_dtype=to_dtype,
    )

    # compute gradient norms over all layers and globally for detailed monitoring
    layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
    logs = {
        "layer_grad_norm": layer_grad_norm,
        "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
    }

    # compute parameter norms over all layers and globally for detailed monitoring
    layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
    logs["layer_param_norm"] = layer_param_norm
    logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics.update(logs)

    metrics = jax.lax.pmean(metrics, axis_name="batch")
    # metrics = to_fp32(metrics)

    return new_state, metrics
|
1200 |
+
|
1201 |
+
# Define eval fn
|
1202 |
+
def eval_step(params, batch):
    """Run the model in inference mode on `batch` and return `(metrics, pred_ids)`.

    `metrics` holds the CTC loss averaged over the "batch" pmap axis; `pred_ids`
    is the argmax of the logits over their last axis.
    """
    # the labels are only needed for the loss, everything else is fed to the model
    labels = batch.pop("labels")
    outputs = model(**batch, params=params, train=False)
    logits = outputs[0]

    logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
    loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")

    pred_ids = jnp.argmax(logits, axis=-1)

    # summarize metrics
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    # metrics = to_fp32(metrics)
    return metrics, pred_ids
|
1216 |
+
|
1217 |
+
# Create parallel version of the train and eval step
|
1218 |
+
if training_args.do_train:
|
1219 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
1220 |
+
|
1221 |
+
if training_args.do_eval or training_args.do_predict:
|
1222 |
+
p_eval_step = jax.pmap(eval_step, "batch")
|
1223 |
+
|
1224 |
+
def run_evaluation(step, final_step=False):
    """Evaluate on the "eval" split, report loss / WER / CER and log them to wandb.

    Does nothing unless `training_args.do_eval` is set.

    Args:
        step: training step shown in the progress description and passed to the
            wandb logging helpers.
        final_step: forwarded to `write_wandb_pred`.
    """
    if not training_args.do_eval:
        return

    # ======================== Evaluating ==============================
    eval_metrics = []
    eval_preds = []
    eval_labels = []

    # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
    eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
    eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

    for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
        samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
        batch = data_collator(samples)
        labels = batch["labels"]

        try:
            metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
        except TypeError as e:
            # Best-effort: a failing batch is skipped (predictions and labels alike),
            # but leave a trace so that dropped batches do not go unnoticed.
            logger.warning("Skipping an eval batch at step %s after TypeError: %s", step, e)
            continue
        eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
        eval_metrics.append(metrics)

        eval_labels.extend(labels)

    # normalize eval metrics
    eval_metrics = get_metrics(eval_metrics)
    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
    eval_metrics = to_fp32(eval_metrics)

    # always run compute metrics
    error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
    eval_metrics.update(error_rate_metric)
    error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])

    # Print metrics and update progress bar
    desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
    epochs.write(desc)
    epochs.desc = desc

    # Save metrics
    write_wandb_log(eval_metrics, step, prefix="eval")
    write_wandb_pred(pred_str, label_str, step, final_step=final_step)
    # if has_tensorboard and jax.process_index() == 0:
    # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
|
1269 |
+
|
1270 |
+
def save_checkpoint(step):
    """Save the model parameters and tokenizer to the output dir and optionally push to the hub.

    Only the process with index 0 does any work. The parameters saved are the
    first slice along the leading axis of every leaf in `state.params`.
    """
    if jax.process_index() != 0:
        return

    first_replica = jax.tree_map(lambda leaf: leaf[0], state.params)
    host_params = jax.device_get(first_replica)
    model.save_pretrained(training_args.output_dir, params=host_params)
    tokenizer.save_pretrained(training_args.output_dir)

    if not training_args.push_to_hub:
        return
    repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
|
1278 |
+
|
1279 |
+
logger.info("***** Running training *****")
|
1280 |
+
logger.info(f" Num examples = {num_train_samples}")
|
1281 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
1282 |
+
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
1283 |
+
logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
|
1284 |
+
logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
|
1285 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
1286 |
+
logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
|
1287 |
+
logger.info(f" Use scan: {config.use_scan}")
|
1288 |
+
logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
|
1289 |
+
|
1290 |
+
train_time = cur_step = 0
|
1291 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
1292 |
+
for epoch in epochs:
|
1293 |
+
if training_args.do_train:
|
1294 |
+
# ======================== Training ================================
|
1295 |
+
train_start = time.time()
|
1296 |
+
|
1297 |
+
# Create sampling rng
|
1298 |
+
rng, input_rng = jax.random.split(rng)
|
1299 |
+
|
1300 |
+
# Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
|
1301 |
+
train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
|
1302 |
+
train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
|
1303 |
+
|
1304 |
+
# Gather the indices for creating the batch and do a training step
|
1305 |
+
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
|
1306 |
+
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
|
1307 |
+
batch = data_collator(samples)
|
1308 |
+
batch = shard(batch.data)
|
1309 |
+
try:
|
1310 |
+
state, train_metric = p_train_step(state, batch)
|
1311 |
+
except TypeError as e:
|
1312 |
+
logger.warning("Encountered following error: \n", e)
|
1313 |
+
|
1314 |
+
cur_step = epoch * (num_train_samples // batch_size_per_update) + step
|
1315 |
+
|
1316 |
+
if cur_step % training_args.logging_steps == 0:
|
1317 |
+
# Save metrics
|
1318 |
+
train_metric = unreplicate(train_metric)
|
1319 |
+
train_time += time.time() - train_start
|
1320 |
+
# need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
|
1321 |
+
write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
|
1322 |
+
# we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
|
1323 |
+
# if has_tensorboard and jax.process_index() == 0:
|
1324 |
+
# write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
1325 |
+
|
1326 |
+
epochs.write(
|
1327 |
+
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
|
1328 |
+
)
|
1329 |
+
|
1330 |
+
if cur_step % total_train_steps == 0:
|
1331 |
+
break
|
1332 |
+
|
1333 |
+
if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
|
1334 |
+
run_evaluation(cur_step, final_step=False)
|
1335 |
+
|
1336 |
+
if cur_step % training_args.save_steps == 0:
|
1337 |
+
save_checkpoint(cur_step)
|
1338 |
+
|
1339 |
+
if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
|
1340 |
+
# run evaluation at the end of the epoch if eval steps are not specified
|
1341 |
+
run_evaluation(cur_step, final_step=False)
|
1342 |
+
save_checkpoint(cur_step)
|
1343 |
+
|
1344 |
+
if training_args.do_train:
|
1345 |
+
save_checkpoint(cur_step)
|
1346 |
+
|
1347 |
+
cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
|
1348 |
+
|
1349 |
+
if training_args.do_eval:
|
1350 |
+
run_evaluation(cur_step, final_step=True)
|
1351 |
+
|
1352 |
+
# TODO: collapse 'do_predict' into the run_evaluation function
|
1353 |
+
if training_args.do_predict:
|
1354 |
+
for split in test_split:
|
1355 |
+
# ======================== Evaluating ==============================
|
1356 |
+
eval_metrics = []
|
1357 |
+
eval_preds = []
|
1358 |
+
eval_labels = []
|
1359 |
+
|
1360 |
+
# Generate eval set by sequentially sampling indices from the test dataset and grouping by length
|
1361 |
+
eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
|
1362 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
1363 |
+
|
1364 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
|
1365 |
+
samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
|
1366 |
+
batch = data_collator(samples)
|
1367 |
+
labels = batch["labels"]
|
1368 |
+
|
1369 |
+
metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
|
1370 |
+
eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
|
1371 |
+
eval_metrics.append(metrics)
|
1372 |
+
|
1373 |
+
eval_labels.extend(labels)
|
1374 |
+
|
1375 |
+
# normalize eval metrics
|
1376 |
+
eval_metrics = get_metrics(eval_metrics)
|
1377 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
1378 |
+
eval_metrics = to_fp32(eval_metrics)
|
1379 |
+
|
1380 |
+
# always run compute metrics
|
1381 |
+
error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
|
1382 |
+
eval_metrics.update(error_rate_metric)
|
1383 |
+
error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
|
1384 |
+
|
1385 |
+
# Print metrics and update progress bar
|
1386 |
+
desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
|
1387 |
+
epochs.write(desc)
|
1388 |
+
epochs.desc = desc
|
1389 |
+
|
1390 |
+
# Save metrics
|
1391 |
+
write_wandb_log(eval_metrics, cur_step, prefix=split)
|
1392 |
+
write_wandb_pred(pred_str, label_str, cur_step, final_step=True, prefix=split)
|
1393 |
+
# if has_tensorboard and jax.process_index() == 0:
|
1394 |
+
# write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
|
1395 |
+
|
1396 |
+
|
1397 |
+
if __name__ == "__main__":
|
1398 |
+
main()
|
run_tedlium.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
# Launches run_flax_speech_recognition_ctc.py on the "tedlium" config of
# esc/esc-datasets, starting from the esc/wav2vec2-pretrained checkpoint.
# Schedule: 50000 steps, 5000 warmup steps, learning rate 3e-4; evaluation and
# checkpointing every 10000 steps, logging every 25 steps.
# Runs training, evaluation and prediction, writes to the current directory
# (overwriting it) and pushes to the hub.
# The README asks for get_ctc_tokenizer.py to be run first to train the CTC tokenizer.
|
2 |
+
python run_flax_speech_recognition_ctc.py \
|
3 |
+
--model_name_or_path="esc/wav2vec2-pretrained" \
|
4 |
+
--tokenizer_name="wav2vec2-ctc-tedlium-tokenizer" \
|
5 |
+
--dataset_name="esc/esc-datasets" \
|
6 |
+
--dataset_config_name="tedlium" \
|
7 |
+
--output_dir="./" \
|
8 |
+
--wandb_project="wav2vec2-ctc" \
|
9 |
+
--wandb_name="wav2vec2-ctc-tedlium" \
|
10 |
+
--max_steps="50000" \
|
11 |
+
--save_steps="10000" \
|
12 |
+
--eval_steps="10000" \
|
13 |
+
--learning_rate="3e-4" \
|
14 |
+
--logging_steps="25" \
|
15 |
+
--warmup_steps="5000" \
|
16 |
+
--preprocessing_num_workers="1" \
|
17 |
+
--hidden_dropout="0.2" \
|
18 |
+
--activation_dropout="0.2" \
|
19 |
+
--feat_proj_dropout="0.2" \
|
20 |
+
--do_train \
|
21 |
+
--do_eval \
|
22 |
+
--do_predict \
|
23 |
+
--overwrite_output_dir \
|
24 |
+
--gradient_checkpointing \
|
25 |
+
--freeze_feature_encoder \
|
26 |
+
--push_to_hub \
|
27 |
+
--use_auth_token
|
special_tokens_map.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<s>",
|
3 |
+
"eos_token": "</s>",
|
4 |
+
"pad_token": "<pad>",
|
5 |
+
"unk_token": "<unk>"
|
6 |
+
}
|
tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<s>",
|
3 |
+
"do_lower_case": false,
|
4 |
+
"eos_token": "</s>",
|
5 |
+
"name_or_path": "sanchit-gandhi/wav2vec2-ctc-tedlium-black-box-tokenizer",
|
6 |
+
"pad_token": "<pad>",
|
7 |
+
"replace_word_delimiter_char": " ",
|
8 |
+
"special_tokens_map_file": null,
|
9 |
+
"tokenizer_class": "Wav2Vec2CTCTokenizer",
|
10 |
+
"unk_token": "<unk>",
|
11 |
+
"word_delimiter_token": "|"
|
12 |
+
}
|
vocab.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"'": 14,
|
3 |
+
"</s>": 2,
|
4 |
+
"<pad>": 0,
|
5 |
+
"<s>": 1,
|
6 |
+
"<unk>": 3,
|
7 |
+
"a": 15,
|
8 |
+
"b": 11,
|
9 |
+
"c": 6,
|
10 |
+
"d": 31,
|
11 |
+
"e": 16,
|
12 |
+
"f": 22,
|
13 |
+
"g": 23,
|
14 |
+
"h": 30,
|
15 |
+
"i": 25,
|
16 |
+
"j": 18,
|
17 |
+
"k": 26,
|
18 |
+
"l": 24,
|
19 |
+
"m": 5,
|
20 |
+
"n": 13,
|
21 |
+
"o": 7,
|
22 |
+
"p": 28,
|
23 |
+
"q": 21,
|
24 |
+
"r": 17,
|
25 |
+
"s": 27,
|
26 |
+
"t": 10,
|
27 |
+
"u": 8,
|
28 |
+
"v": 20,
|
29 |
+
"w": 19,
|
30 |
+
"x": 29,
|
31 |
+
"y": 12,
|
32 |
+
"z": 4,
|
33 |
+
"|": 9
|
34 |
+
}
|