diff --git a/SpeechT5 b/SpeechT5
deleted file mode 160000
index 8b5ade783571e63450aaa5507444150dcb08fa94..0000000000000000000000000000000000000000
--- a/SpeechT5
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 8b5ade783571e63450aaa5507444150dcb08fa94
diff --git a/SpeechT5/CODE_OF_CONDUCT.md b/SpeechT5/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9ba8cf65f3e3104dd061c178066ec8247811f33
--- /dev/null
+++ b/SpeechT5/CODE_OF_CONDUCT.md
@@ -0,0 +1,9 @@
+# Microsoft Open Source Code of Conduct
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+Resources:
+
+- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
diff --git a/SpeechT5/LICENSE b/SpeechT5/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9e841e7a26e4eb057b24511e7b92d42b257a80e5
--- /dev/null
+++ b/SpeechT5/LICENSE
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
diff --git a/SpeechT5/README.md b/SpeechT5/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..aa825607e61ad8e7e9c0b47161b38df0961a0654
--- /dev/null
+++ b/SpeechT5/README.md
@@ -0,0 +1,267 @@
+# SpeechT5
+
+Unified-modal speech-text pre-training for spoken language processing:
+
+> [**SpeechT5**](https://arxiv.org/abs/2110.07205) (```ACL 2022```): **SpeechT5: Unified-Modal Encoder-Decoder Pre-training for Spoken Language Processing**
+
+> [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**
+
+> [**YiTrans**](https://arxiv.org/abs/2206.05777) (```IWSLT 2022```): **The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task**
+
+> [**SpeechUT**](https://arxiv.org/abs/2210.03730) (```EMNLP 2022```): **SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training**
+
+> [**SpeechLM**](https://arxiv.org/abs/2209.15329) (```Arxiv 2022```): **SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data**
+
+> [**Speech2S**](https://arxiv.org/abs/2210.17027) (```ICASSP 2023```): **Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**
+
+> [**Prosody-SpeechT5**](https://ieeexplore.ieee.org/document/10096530/) (```ICASSP 2023```): **Prosody-aware SpeechT5 for Expressive Neural TTS**
+
+> [**VATLM**](https://arxiv.org/abs/2211.11275) (```IEEE Transactions on Multimedia```): **VATLM: Visual-Audio-Text Pre-Training with Unified Masked Prediction for Speech Representation Learning**
+
+> [**VALL-E X**](https://arxiv.org/abs/2303.03926) (```Arxiv 2023```): **Speak Foreign Languages with Your Own Voice: Cross-Lingual Neural Codec Language Modeling**
+
+> [**VioLA**](https://arxiv.org/abs/2305.16107) (```Arxiv 2023```): **VioLA: Unified Codec Language Models for Speech Recognition, Synthesis, and Translation**
+
+
+
+
+## Update
+
+- May, 2023: VioLA [**Arxiv**](https://arxiv.org/abs/2305.16107).
+- May, 2023: [**VATLM**](https://arxiv.org/abs/2211.11275) was accepted by IEEE Transactions on Multimedia.
+- March, 2023: VALL-E X [**Arxiv**](https://arxiv.org/abs/2303.03926) and [**Demo**](https://aka.ms/vallex).
+- February, 2023: [**Speech2S**](https://arxiv.org/abs/2210.17027) and [**Prosody-SpeechT5**](https://ieeexplore.ieee.org/document/10096530/) were accepted by ICASSP 2023.
+- [HuggingFace Integration] February, 2023: [**SpeechT5**](https://aclanthology.org/2022.acl-long.393/) models are on [**HuggingFace**](https://huggingface.co/blog/speecht5) (see the loading sketch after this list).
+- [Model Release] November, 2022: [**VATLM**](https://github.com/microsoft/SpeechT5/tree/main/VATLM) models are released.
+- November, 2022: VATLM [**Arxiv**](https://arxiv.org/abs/2211.11275).
+- November, 2022: Speech2S [**Arxiv**](https://arxiv.org/abs/2210.17027).
+- [Model Release] October, 2022: [**SpeechUT**](https://github.com/microsoft/SpeechT5/tree/main/SpeechUT) models are released.
+- October, 2022: [**SpeechUT**](https://arxiv.org/abs/2210.03730) was accepted by EMNLP 2022.
+- [Model Release] October, 2022: [**SpeechLM**](https://github.com/microsoft/SpeechT5/tree/main/SpeechLM) models are released.
+- September, 2022: SpeechLM [**Arxiv**](https://arxiv.org/abs/2209.15329).
+- [Evaluation] June, 2022: The end-to-end ST system [**YiTrans**](https://arxiv.org/abs/2206.05777) achieved top results on [**IWSLT 2022**](https://iwslt.org/2022/offline) shared tasks.
+- June, 2022: [**Speech2C**](https://www.isca-speech.org/archive/interspeech_2022/ao22_interspeech.html) was accepted by INTERSPEECH 2022.
+- [Model Release] May, 2022: [**Speech2C**](https://github.com/microsoft/SpeechT5/tree/main/Speech2C) models are released.
+- [Model Release] April, 2022: [**SpeechT5**](https://github.com/microsoft/SpeechT5/tree/main/SpeechT5) models are released.
+- March, 2022: Speech2C [**Arxiv**](https://arxiv.org/abs/2203.17113).
+- February, 2022: [**SpeechT5**](https://aclanthology.org/2022.acl-long.393/) was accepted by ACL 2022.
+- October, 2021: SpeechT5 [**Arxiv**](https://arxiv.org/abs/2110.07205).
+
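+The HuggingFace integration above can be tried directly with the `transformers` library. The snippet below is a minimal TTS sketch, assuming the `microsoft/speecht5_tts` and `microsoft/speecht5_hifigan` checkpoints and a 512-dimensional x-vector speaker embedding (a random placeholder is used here; in practice compute one with a speaker-verification model).
+
+```python
+# Minimal sketch: load the HuggingFace SpeechT5 TTS checkpoint and synthesize speech.
+# The speaker embedding below is a random placeholder, not a real x-vector.
+import torch
+from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
+
+processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+
+inputs = processor(text="Hello, SpeechT5.", return_tensors="pt")
+speaker_embeddings = torch.randn(1, 512)  # placeholder 512-dim x-vector
+speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
+print(speech.shape)  # 1-D waveform tensor at 16 kHz
+```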
+
+## Pre-Trained Models
+
+
+| Model | Pre-training Dataset | Fine-tuning Dataset | Checkpoint |
+| :------: | :----------------------------------------------: | :-----------------: | :-----: |
+| SpeechT5 Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | - | [HuggingFace](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base.pt)<br>[Google Drive](https://drive.google.com/file/d/1Sq00uZ1pw6Z4OUaqhOWzQEJxIVWgAO5U/view?usp=sharing) |
+| SpeechT5 Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [HuggingFace](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base_asr.pt)<br>[Google Drive](https://drive.google.com/file/d/1qLKJ81JPWOGf1MHfjSmgtZyqqTqgI6kT/view?usp=sharing) |
+| SpeechT5 Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | - | [Google Drive](https://drive.google.com/file/d/1M79b1jetSPOVxWVMIX-y0URvDjNskZKp/view?usp=sharing) |
+| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
+| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
+| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1iJvhSGghNrMT-wAY1nwVu2YaYuTy1pxx/view?usp=sharing) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1mH3N7iKMWYk3rSBJErQPYf3x5ugqDq5x/view?usp=sharing) |
+| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1eblW8U8f9t-NTuCNRrNHwr-8BeLAUAmQ/view?usp=sharing) |
+| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1vXyO5DolbiWiTYZ6pkkKQsu2wJetaPlv/view?usp=sharing) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_ende.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_enca.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_enar.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_entr.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1QjLIgTJKIylVIp5hUkfSjGPtz8Xo7Lky/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1YZQDVv096o8Opt0RBnkRiZXYPRDqKZnP/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1qYygNWSc11TQbBI1OzC4ChlR-dNh8t9S/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/162U88mwso2aVfzzPkEM2nP_vwTpcb57T/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1lbTSRXewEeb2t45URunD6EiJcbniyjWW/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1Er4I_jHS175pQQph223yKtiiLQ378VvH/view?usp=sharing) |
+| SpeechUT Base (ASR) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4asr_32gpu_1accum/checkpoint_298_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A39%3A48Z&se=2024-03-09T01%3A39%3A00Z&sr=b&sp=r&sig=l3gJS1D%2BJfLfNfS3z33WjmSMGrOBJ63CvqGGewC6WeU%3D)|
+| SpeechUT Base (ASR) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/speechut_base_asr100h_checkpoint_best.pt?sv=2020-04-08&st=2023-03-08T01%3A41%3A22Z&se=2024-03-09T01%3A41%3A00Z&sr=b&sp=r&sig=%2B9lpGrqtZXa%2F6n1uZT%2Biey54ky31bYKSJytgfnBbbN4%3D)|
+| SpeechUT Large (ASR) | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/large_speechut4asr_32gpu_4accum/checkpoint_22_400k.pt?sv=2020-04-08&st=2023-03-08T01%3A42%3A10Z&se=2024-03-09T01%3A42%3A00Z&sr=b&sp=r&sig=TZNcsHQAqapyj%2BAvpHtl749kZy9flTkWm8P5L4W26qs%3D)|
+| SpeechUT Large (ASR) | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/speechut_large_asr960h_checkpoint_best.pt?sv=2020-04-08&st=2023-03-08T01%3A43%3A02Z&se=2024-03-09T01%3A43%3A00Z&sr=b&sp=r&sig=PmO%2BgSAMXRgMC7GfpS4c%2BrDPsfJGekqUzD5AJm7RrYU%3D)|
+| SpeechUT Base (En-De) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [408 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [4.6M Text](https://www.statmt.org/wmt16/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4ende_32gpu_1accum/checkpoint_217_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A43%3A47Z&se=2024-03-09T01%3A43%3A00Z&sr=b&sp=r&sig=XDEesMdGQ027j7YtpSql1kZtwgfv39gruOuWYlKlJ7w%3D)|
+| SpeechUT Base (En-De) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [408 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [4.6M Text](https://www.statmt.org/wmt16/) | [En-De MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4ende_32gpu_1accum/fineutne_ende_checkpoint_avg.pt?sv=2020-04-08&st=2023-03-08T01%3A44%3A15Z&se=2024-03-09T01%3A44%3A00Z&sr=b&sp=r&sig=8dcenahRg46EJdwiHUalVBJgKra6JoSN7tUxdLAwzOM%3D)|
+| SpeechUT Base (En-Es) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [504 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [15M Text](https://www.statmt.org/wmt13/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enes_32gpu_1accum/checkpoint_204_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A48%3A16Z&se=2024-03-09T01%3A48%3A00Z&sr=b&sp=r&sig=hWoCM0y0SGZTD4CznC%2F5CejFczkqDYTOaISmlhCAYAU%3D)|
+| SpeechUT Base (En-Es) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [504 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [15M Text](https://www.statmt.org/wmt13/) | [En-Es MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enes_32gpu_1accum/fineutne_enes_checkpoint_avg.pt?sv=2020-04-08&st=2023-03-08T01%3A48%3A41Z&se=2024-03-09T01%3A48%3A00Z&sr=b&sp=r&sig=KGfzgKfKkDVQI0JxxnS%2BsYdBQzhUqFLQAVYG0OSGBtk%3D)|
+| SpeechUT Base (En-Fr) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [492 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [40M Text](https://www.statmt.org/wmt14/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enfr_32gpu_1accum/checkpoint_297_600000.pt?sv=2020-04-08&st=2023-03-08T01%3A49%3A09Z&se=2024-03-09T01%3A49%3A00Z&sr=b&sp=r&sig=1eqpXMLCjWpfyd7AiOHGzfk%2B8ZYqWwVWdHk1GqXgoeg%3D)|
+| SpeechUT Base (En-Fr) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [492 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [40M Text](https://www.statmt.org/wmt14/) | [En-Fr MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enfr_32gpu_1accum/fineutne_enfr_checkpoint.pt?sv=2020-04-08&st=2023-03-08T01%3A49%3A34Z&se=2024-03-09T01%3A49%3A00Z&sr=b&sp=r&sig=i3jMAqvL1Vp7DRjACAbrdoQKhlv2Cmi40%2F14SJ6%2BoiU%3D)|
+
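+The checkpoints above are regular fairseq checkpoints. A quick way to sanity-check a download is to load it on CPU and inspect the top-level keys; the snippet below is a sketch assuming the usual fairseq layout (a `model` state dict plus `cfg`/`args` metadata) and an example file name.
+
+```python
+# Sketch: inspect a downloaded checkpoint (file name is an example).
+import torch
+
+ckpt = torch.load("speecht5_base.pt", map_location="cpu")
+print(sorted(ckpt.keys()))       # usually includes 'model' and 'cfg' or 'args'
+state_dict = ckpt["model"]
+print(len(state_dict), "parameter tensors")
+print(next(iter(state_dict)))    # name of the first parameter
+```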
+
+
+## SpeechT5 Introduction
+
+Motivated by the success of T5 (Text-To-Text Transfer Transformer) in pre-trained natural language processing models, we propose a unified-modal SpeechT5 framework that explores encoder-decoder pre-training for self-supervised speech/text representation learning.
+The SpeechT5 framework consists of a shared encoder-decoder network and six modal-specific (speech/text) pre/post-nets.
+After preprocessing the input speech/text through the pre-nets, the shared encoder-decoder network models the sequence-to-sequence transformation, and then the post-nets generate the output in the speech/text modality based on the output of the decoder.
+
+
+
+Leveraging large-scale unlabeled speech and text data, we pre-train SpeechT5 to learn a unified-modal representation, hoping to improve the modeling capability for both speech and text.
+To align the textual and speech information into this unified semantic space, we propose a cross-modal vector quantization approach that randomly mixes up speech/text states with latent units as the interface between encoder and decoder.
+Extensive evaluations show the superiority of the proposed SpeechT5 framework on a wide variety of spoken language processing tasks, including automatic speech recognition, speech synthesis, speech translation, voice conversion, speech enhancement, and speaker identification.
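+
+The sketch below is illustrative only (module names and sizes are invented, and the six pre/post-nets are collapsed into one pre-net and one post-net per modality); it shows the data flow of "modality-specific pre-net -> shared encoder-decoder -> modality-specific post-net", not the actual implementation.
+
+```python
+# Illustrative-only sketch of the SpeechT5 data flow; not the real model.
+import torch.nn as nn
+
+class UnifiedModalSketch(nn.Module):
+    def __init__(self, hidden=768, vocab=10000, n_mels=80):
+        super().__init__()
+        self.pre_nets = nn.ModuleDict({           # modality-specific pre-nets
+            "speech": nn.Linear(n_mels, hidden),  # e.g. log-Mel frames -> hidden
+            "text": nn.Embedding(vocab, hidden),  # token ids -> hidden
+        })
+        self.backbone = nn.Transformer(d_model=hidden, batch_first=True)  # shared encoder-decoder
+        self.post_nets = nn.ModuleDict({          # modality-specific post-nets
+            "speech": nn.Linear(hidden, n_mels),
+            "text": nn.Linear(hidden, vocab),
+        })
+
+    def forward(self, src, tgt, src_modality, tgt_modality):
+        enc_in = self.pre_nets[src_modality](src)
+        dec_in = self.pre_nets[tgt_modality](tgt)
+        hidden = self.backbone(enc_in, dec_in)    # sequence-to-sequence transformation
+        return self.post_nets[tgt_modality](hidden)
+```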
+
+
+
+## SpeechT5 Downstream Task Performance
+
+We evaluate our models on typical spoken language processing tasks, including automatic speech recognition, text-to-speech, speech-to-text translation, voice conversion, speech enhancement, and speaker identification.
+
+### Automatic Speech Recognition
+
+Evaluation on the [LibriSpeech](http://www.openslr.org/12)
+
+| Model |LM | dev-clean | dev-other | test-clean | test-other |
+| ------------- |------------- | ------| ----- | ----| ----|
+| wav2vec2.0 Base | - | 6.1 | 13.5 | 6.1 | 13.3 |
+| HuBERT Base | - | 5.5 | 13.1 | 5.8 | 13.3 |
+| Baseline (w/o CTC) | - | 5.8 | 12.3 | 6.2 | 12.3 |
+| Baseline | - | 4.9 | 11.7 | 5.0 | 11.9 |
+| SpeechT5 (w/o CTC) | - | 5.4 | 10.7 | 5.8 | 10.7 |
+| **SpeechT5** | - | **4.3** | **10.3** | **4.4** | **10.4** |
+| DiscreteBERT | 4-gram | 4.0 |10.9 |4.5 |12.1 |
+| wav2vec 2.0 Base | 4-gram | 2.7 |7.9 |3.4 |8.0 |
+| HuBERT Base | 4-gram | 2.7 |7.8 |3.4 |8.1 |
+| wav2vec 2.0 Base | Transf. | 2.2 |6.3 |2.6 |6.3 |
+| Baseline | Transf. | 2.3 |6.3 |2.5 |6.3 |
+| **SpeechT5** | Transf. | **2.1** |**5.5** |**2.4** |**5.8** |
+
+### Text-to-Speech
+
+Evaluation on the [LibriTTS](http://www.openslr.org/60/)
+
+
+| Model | Naturalness | MOS | CMOS |
+| ------------- |------------ | ------ | ----- |
+| Ground Truth | - | 3.87 | - |
+| Baseline | 2.76 | 3.56 | 0 |
+| **SpeechT5** | 2.91 | **3.65** | **+0.290** |
+
+### Speech Translation
+
+Evaluation on the [MUST-C v1](https://ict.fbk.eu/must-c/)
+
+| Model | EN-DE | EN-FR |
+| ------------- |------------ | ------ |
+| Fairseq ST | 22.70 | 32.90 |
+| ESPnet ST | 22.91 | 32.69 |
+| Adapter Tuning| 24.63 | 34.98 |
+| Baseline | 23.43 | 33.76 |
+| SpeechT5 (w/o initializing decoder) | 24.44 | 34.5 |
+| **SpeechT5** | **25.18** | **35.30** |
+
+
+### Voice Conversion
+
+Evaluation on the [CMU Arctic](http://www.festvox.org/cmu_arctic/)
+
+
+| Model | WER (bdl to slt) | WER (clb to slt) | MCD (bdl to slt) | MCD (clb to slt) |
+| ------------- | ------ | ----- | ---- | ----|
+| VTN w/ ASR | 11.1 | 10.9 | 6.5 | 6.11 |
+| VTN w/ TTS | 7.6 | 9.1 | 6.33 | 13.3 |
+| Many-to-many VTN | - | - | 6.13 | 5.97 |
+| Baseline | 21.5 | 10.8 | 6.26 | 6.16 |
+| **SpeechT5** | **7.8** | **6.4** | **5.93**| **5.87** |
+
+
+
+### Speech Enhancement
+
+Evaluation on the [WSJ0 Hipster AmbientMixtures (WHAM!)](http://wham.whisper.ai/)
+
+
+| Model | WER |
+| ------------- |------------ |
+| Ground Truth Speech | 3.2 |
+| Noisy Speech | 76.1 |
+| Baseline | 10.9 |
+| **SpeechT5** | **8.9** |
+
+
+### Speaker Identification
+
+Evaluation on the [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html)
+
+| Model | Acc |
+| ------------- |------------ |
+| SUPERB, wav2vec 2.0 Base | 75.18% |
+| SUPERB, HuBERT Base | 81.42% |
+| SUPERB, HuBERT Large | 90.33% |
+| SpeechNet, single task | 86.00% |
+| SpeechNet, multi-task with TTS | 87.90% |
+| Thin ResNet-34 | 89.00% |
+| Baseline | 91.92% |
+| **SpeechT5** | **96.49%** |
+
+## License
+
+This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
+Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [ESPnet](https://github.com/espnet/espnet) projects.
+
+[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
+
+### Reference
+
+If you find our work useful in your research, please cite the following papers:
+
+```bibtex
+@article{Ao2021SpeechT5,
+ title = {SpeechT5: Unified-Modal Encoder-Decoder Pre-training for Spoken Language Processing},
+ author = {Junyi Ao and Rui Wang and Long Zhou and Chengyi Wang and Shuo Ren and Yu Wu and Shujie Liu and Tom Ko and Qing Li and Yu Zhang and Zhihua Wei and Yao Qian and Jinyu Li and Furu Wei},
+ eprint={2110.07205},
+ archivePrefix={arXiv},
+ primaryClass={eess.AS},
+ year={2021}
+}
+```
+
+```bibtex
+@article{Ao2022Speech2C,
+ title = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
+ author = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
+ eprint={2203.17113},
+ archivePrefix={arXiv},
+ primaryClass={cs.SD},
+ year={2022}
+}
+```
+
+```bibtex
+@article{Zhang2022Yitrans,
+ title = {The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task},
+ author = {Zhang, Ziqiang and Ao, Junyi and Zhou, Long and Liu, Shujie and Wei, Furu and Li, Jinyu},
+ eprint={2206.05777},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ year={2022}
+}
+```
+
+```bibtex
+@article{zhang2022speechut,
+ title = {SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training},
+ author = {Zhang, Ziqiang and Zhou, Long and Ao, Junyi and Liu, Shujie and Dai, Lirong and Li, Jinyu and Wei, Furu},
+ eprint={2210.03730},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ year={2022}
+}
+```
+
+```bibtex
+@article{zhang2022speechlm,
+ title = {SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data},
+ author = {Zhang, Ziqiang and Chen, Sanyuan and Zhou, Long and Wu, Yu and Ren, Shuo and Liu, Shujie and Yao, Zhuoyuan and Gong, Xun and Dai, Lirong and Li, Jinyu and Wei, Furu},
+ eprint={2209.15329},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ year={2022}
+}
+```
+
+### Contact Information
+
+For help or issues using SpeechT5 models, please submit a GitHub issue.
+
+For other communications related to SpeechT5, please contact Long Zhou (`lozhou@microsoft.com`).
diff --git a/SpeechT5/SECURITY.md b/SpeechT5/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..869fdfe2b246991a053fab9cfec1bed3ab532ab1
--- /dev/null
+++ b/SpeechT5/SECURITY.md
@@ -0,0 +1,41 @@
+
+
+## Security
+
+Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
+
+If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
+
+## Reporting Security Issues
+
+**Please do not report security vulnerabilities through public GitHub issues.**
+
+Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
+
+If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
+
+You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
+
+Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+ * Full paths of source file(s) related to the manifestation of the issue
+ * The location of the affected source code (tag/branch/commit or direct URL)
+ * Any special configuration required to reproduce the issue
+ * Step-by-step instructions to reproduce the issue
+ * Proof-of-concept or exploit code (if possible)
+ * Impact of the issue, including how an attacker might exploit the issue
+
+This information will help us triage your report more quickly.
+
+If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
+
+## Preferred Languages
+
+We prefer all communications to be in English.
+
+## Policy
+
+Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
+
+
diff --git a/SpeechT5/Speech2C/README.md b/SpeechT5/Speech2C/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e568918c7ee624ba9bfe8c39f810a72af69f3f2
--- /dev/null
+++ b/SpeechT5/Speech2C/README.md
@@ -0,0 +1,145 @@
+# Speech2C
+
+> [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**
+
+## Pre-Trained and Fine-tuned Models
+
+| Model | Pre-training Dataset | Fine-tuning Dataset | Checkpoint |
+| :------: | :----------------------------------------------: | :-----------------: | :-----: |
+| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
+| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
+| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |
+
+
+## Language Model and Vocabulary
+| Model | Dataset | Checkpoint | Vocabulary |
+| :------: | :------: | :---: | :--------: |
+| LM | [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [Model](https://drive.google.com/file/d/1UDCcNJT1DlquSRw0iRAXH6GHlf6zK6-8/view?usp=sharing) | [Vocabulary](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt) |
+
+## Setup
+```
+git submodule update --init Speech2C/fairseq
+cd Speech2C/
+pip install --editable fairseq/
+```
+
+## Data Preparation
+Please follow the steps of data preparation for HuBERT in [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#data-preparation).
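+
+For reference, that recipe produces `{train,valid}.tsv` manifests (first line: the audio root directory; then one `relative_path<TAB>num_samples` line per utterance) and per-utterance label files such as `.km` k-means units. The snippet below writes a toy manifest under those assumptions, purely as an illustration of the expected layout.
+
+```python
+# Toy illustration of the HuBERT-style manifest/label layout
+# (paths and sample counts are placeholders).
+from pathlib import Path
+
+data_dir = Path("data")
+data_dir.mkdir(exist_ok=True)
+
+# train.tsv: audio root on the first line, then "<relative path>\t<num samples>"
+(data_dir / "train.tsv").write_text(
+    "/path/to/LibriSpeech\n"
+    "train-clean-100/19/198/19-198-0001.flac\t238480\n"
+)
+
+# train.km: one line of space-separated k-means unit ids per utterance,
+# in the same order as the tsv
+(data_dir / "train.km").write_text("21 21 5 5 88 88 17\n")
+```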
+
+## Pre-Training
+```
+DATA_DIR=
+LABEL_DIR=
+FAIRSEQ_PATH=
+
+python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
+ --config-dir speech2c/config \
+ --config-name speech2c_base_librispeech \
+ task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} task.labels='["km"]' \
+  model.label_rate=50 common.user_dir=SpeechT5/Speech2C/speech2c
+```
+
+## Finetune
+
+```
+DATA_DIR=
+LABEL_DIR=
+FAIRSEQ_PATH=
+W2V_PATH=
+CONFIG_NAME=
+
+python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
+ --config-dir speech2c/config \
+ --config-name ${CONFIG_NAME} \
+ task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} \
+  model.w2v_path=${W2V_PATH} common.user_dir=SpeechT5/Speech2C/speech2c
+```
+
+## Inference
+Note that joint CTC and decoder inference is only supported when the batch size is 1.
+
+```
+FAIRSEQ_PATH=
+DATA_DIR=
+LABEL_DIR=
+BEAM_SIZE=
+CTC_WEIGHT=
+TEST_SET=
+CHECKPOINT_PATH=
+W2V_PATH=
+
+
+python ${FAIRSEQ_PATH}/fairseq_cli/generate.py ${DATA_DIR} \
+ --label-dir ${LABEL_DIR} \
+ --path ${CHECKPOINT_PATH} \
+ --user-dir SpeechT5/Speech2C/speech2c \
+ --model-overrides "{'w2v_path': '${W2V_PATH}'}" \
+ --gen-subset ${TEST_SET} \
+ --task speech2c_pretraining \
+ --post-process letter \
+ --add-decoder \
+ --labels '["ltr"]' \
+ --fine-tuning \
+ --scoring wer \
+ --max-len-a 0 \
+ --max-len-b 620 \
+ --pad-audio \
+ --random-crop \
+ --ctc-weight ${CTC_WEIGHT} \
+ --max-tokens 8000000 \
+ --beam ${BEAM_SIZE} \
+  --single-target
+```
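+
+For intuition, `--ctc-weight` interpolates the attention-decoder score with the CTC score during beam search, roughly `score = (1 - w) * log_p_decoder + w * log_p_ctc`. The sketch below only illustrates that combination for a single hypothesis; the numbers are placeholders, and the real combination happens inside the beam search.
+
+```python
+# Illustration of combining decoder and CTC log-probabilities with a CTC weight.
+def joint_score(log_p_decoder: float, log_p_ctc: float, ctc_weight: float) -> float:
+    return (1.0 - ctc_weight) * log_p_decoder + ctc_weight * log_p_ctc
+
+print(joint_score(log_p_decoder=-3.2, log_p_ctc=-4.1, ctc_weight=0.3))  # -3.47
+```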
+
+## Results on LibriSpeech
+
+### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 10hr subset
+
+| Model |LM | test-clean | test-other |
+| ------------- |------------- | ----| ----|
+| wav2vec2.0 Base | - | 11.1 | 17.6 |
+| HuBERT Base | - | 10.1 | 16.8 |
+| **Speech2C** | - | **7.8** | **13.1** |
+| wav2vec 2.0 Base | 4-gram | 4.3 |9.5 |
+| wav2vec 2.0 Base | Transf. |3.2 |7.8 |
+| HuBERT Base | 4-gram |4.3 |9.4 |
+| **Speech2C** | **Transf.** | **3.1** | **7.0** |
+
+### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 100hr subset
+
+| Model |LM | test-clean | test-other |
+| ------------- |------------- | ----| ----|
+| wav2vec2.0 Base | - | 6.1 | 13.3 |
+| wav2vec2.0 Large | - | 4.7 | 9.0 |
+| HuBERT Base | - | 6.3 | 13.2 |
+| SpeechT5 | - | 4.4 | 10.4 |
+| Baseline | - | 5.0 | 11.9 |
+| **Speech2C** | - | **4.3** |**9.0** |
+| wav2vec 2.0 Base | 4-gram | 3.4 |8.0 |
+| wav2vec 2.0 Base | Transf. | 2.6 | 6.3 |
+| HuBERT Base | 4-gram | 3.4 |8.1 |
+| SpeechT5 | Transf. | 2.4 |5.8 |
+| Baseline | Transf. | 2.5 |6.3 |
+| **Speech2C** | **Transf.** | **2.4** |**5.2** |
+
+## License
+
+This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
+Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
+
+[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
+
+## Reference
+
+If you find our work useful in your research, please cite the following paper:
+
+```bibtex
+@article{Ao2022Speech2C,
+ title = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
+ author = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
+ eprint={2203.17113},
+ archivePrefix={arXiv},
+ primaryClass={cs.SD},
+ year={2022}
+}
+```
diff --git a/SpeechT5/Speech2C/speech2c/__init__.py b/SpeechT5/Speech2C/speech2c/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8994f9a368ae4b2eff720fffb134e2a5b813ee1c
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/__init__.py
@@ -0,0 +1 @@
+from . import data, tasks, criterions, models # noqa
\ No newline at end of file
diff --git a/SpeechT5/Speech2C/speech2c/config/base_100h.yaml b/SpeechT5/Speech2C/speech2c/config/base_100h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2af86af96e3719a1419a4dd49af156d4c61e9c49
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/config/base_100h.yaml
@@ -0,0 +1,93 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tblog
+ seed: 1337
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: dec_accuracy
+ maximize_best_checkpoint_metric: true
+
+distributed_training:
+ ddp_backend: c10d
+ find_unused_parameters: true
+ distributed_world_size: 1
+ distributed_port: 29671
+ nprocs_per_node: 8
+
+task:
+ _name: speech2c_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: false # must be consistent with pre-training
+ labels: ["ltr"]
+ single_target: true
+ add_decoder: true
+ pad_audio: true
+ random_crop: false
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ train_subset: train_100h
+ valid_subset: dev_other
+
+criterion:
+ _name: ctc_ce
+ zero_infinity: true
+
+optimization:
+ max_update: 80000
+ lr: [0.00004]
+ sentence_avg: true
+ update_freq: [1]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: speech2c_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.1
+ decoder_layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 25000
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2C/speech2c/config/base_10h.yaml b/SpeechT5/Speech2C/speech2c/config/base_10h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aaa4ed7a79998fc1a09480f2917e2557e8aba457
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/config/base_10h.yaml
@@ -0,0 +1,104 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tblog
+ seed: 1337
+
+checkpoint:
+ save_interval: 5
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: dec_accuracy
+ maximize_best_checkpoint_metric: true
+
+distributed_training:
+ ddp_backend: c10d
+ find_unused_parameters: true
+ distributed_world_size: 1
+ distributed_port: 29671
+ nprocs_per_node: 8
+
+task:
+ _name: speech2c_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: false # must be consistent with pre-training
+ labels: ["ltr"]
+ single_target: true
+ add_decoder: true
+ pad_audio: true
+ random_crop: false
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: ${model.freeze_finetune_updates}
+ validate_interval: 5
+ train_subset: train_10h
+ valid_subset: dev_other
+
+criterion:
+ _name: ctc_ce
+ zero_infinity: true
+
+optimization:
+ max_update: 25000
+ lr: [2e-5]
+ sentence_avg: true
+ update_freq: [1]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: speech2c_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_selection: static
+ mask_length: 10
+ mask_other: 0
+ mask_prob: 0.75
+ mask_channel_selection: static
+ mask_channel_length: 64
+ mask_channel_other: 0
+ mask_channel_prob: 0.5
+ layerdrop: 0.1
+ decoder_layerdrop: 0.1
+ dropout: 0.0
+ activation_dropout: 0.1
+ attention_dropout: 0.0
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml b/SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1f361375d8d11d6d3f7dc5573bbfc1e779930d52
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml
@@ -0,0 +1,100 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_world_size: 32
+ distributed_port: 29671
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: speech2c_pretraining
+ data: ???
+ label_dir: ???
+ labels: ???
+ label_rate: ${model.label_rate}
+ sample_rate: 16000
+ max_sample_size: 250000
+ min_sample_size: 32000
+ pad_audio: false
+ random_crop: true
+ normalize: false # must be consistent with extractor
+ add_decoder: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ validate_interval_updates: 10000
+
+criterion:
+ _name: speech2c
+ pred_masked_weight: 1.0
+ pred_nomask_weight: 0.0
+ loss_weights: [10,]
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: speech2c
+ label_rate: ???
+ skip_masked: false
+ skip_nomask: false
+ mask_prob: 0.80
+ extractor_mode: default
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+ final_dim: 256
+ encoder_layerdrop: 0.05
+ dropout_input: 0.1
+ dropout_features: 0.1
+ dropout: 0.1
+ attention_dropout: 0.1
+ feature_grad_mult: 0.1
+ untie_final_proj: true
+ activation_dropout: 0.0
+ use_rel_pos_enc: true
+ decoder_dict_size: -1
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2C/speech2c/criterions/__init__.py b/SpeechT5/Speech2C/speech2c/criterions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69fc7d7c6fa06ee16e28752119410410bf3e212f
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/criterions/__init__.py
@@ -0,0 +1,10 @@
+import importlib
+import os
+
+
+for file in os.listdir(os.path.dirname(__file__)):
+ if file.endswith(".py") and not file.startswith("_"):
+ criterion_name = file[: file.find(".py")]
+ importlib.import_module(
+ "speech2c.criterions." + criterion_name
+ )
diff --git a/SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py b/SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py
new file mode 100644
index 0000000000000000000000000000000000000000..39922924a1f22f6405f743cf262ca3609de59268
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py
@@ -0,0 +1,404 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+from argparse import Namespace
+from dataclasses import dataclass, field
+from omegaconf import II
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
+from fairseq.dataclass import FairseqDataclass
+from fairseq.data.data_utils import post_process
+from fairseq.tasks import FairseqTask
+from fairseq.logging.meters import safe_round
+
+
+@dataclass
+class CtcCeCriterionConfig(FairseqDataclass):
+ zero_infinity: bool = field(
+ default=False,
+ metadata={"help": "zero inf loss when source length <= target length"},
+ )
+ sentence_avg: bool = II("optimization.sentence_avg")
+ post_process: str = field(
+ default="letter",
+ metadata={
+ "help": "how to post process predictions into words. can be letter, "
+ "wordpiece, BPE symbols, etc. "
+ "See fairseq.data.data_utils.post_process() for full list of options"
+ },
+ )
+ wer_kenlm_model: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
+ },
+ )
+ wer_lexicon: Optional[str] = field(
+ default=None,
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
+ )
+ wer_lm_weight: float = field(
+ default=2.0,
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
+ )
+ wer_word_score: float = field(
+ default=-1.0,
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
+ )
+
+ wer_args: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
+ },
+ )
+
+ dec_weight: float = field(
+ default=0.5,
+ metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
+ )
+ report_accuracy: bool = field(
+ default=True,
+ metadata={"help": "report decoder accuracy metric"},
+ )
+ ignore_prefix_size: int = field(
+ default=0,
+ metadata={"help": "Ignore first N tokens"},
+ )
+ label_smoothing: float = field(
+ default=0.1,
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
+ )
+
+
+@register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
+class CtcCeCriterion(FairseqCriterion):
+ def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
+ super().__init__(task)
+ self.blank_idx = (
+ task.target_dictionary.index(task.blank_symbol)
+ if hasattr(task, "blank_symbol")
+ else 0
+ )
+ self.pad_idx = task.target_dictionary.pad()
+ self.eos_idx = task.target_dictionary.eos()
+ self.post_process = cfg.post_process
+
+ if cfg.wer_args is not None:
+ (
+ cfg.wer_kenlm_model,
+ cfg.wer_lexicon,
+ cfg.wer_lm_weight,
+ cfg.wer_word_score,
+ ) = eval(cfg.wer_args)
+
+ if cfg.wer_kenlm_model is not None:
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
+
+ dec_args = Namespace()
+ dec_args.nbest = 1
+ dec_args.criterion = "ctc"
+ dec_args.kenlm_model = cfg.wer_kenlm_model
+ dec_args.lexicon = cfg.wer_lexicon
+ dec_args.beam = 50
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
+ dec_args.lm_weight = cfg.wer_lm_weight
+ dec_args.word_score = cfg.wer_word_score
+ dec_args.unk_weight = -math.inf
+ dec_args.sil_weight = 0
+
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
+ else:
+ self.w2l_decoder = None
+
+ self.zero_infinity = cfg.zero_infinity
+ self.sentence_avg = cfg.sentence_avg
+
+ self.dec_weight = cfg.dec_weight
+ self.report_accuracy = cfg.report_accuracy
+ self.ignore_prefix_size = cfg.ignore_prefix_size
+ self.eps = cfg.label_smoothing
+
+ def forward(self, model, sample, reduce=True):
+ net_output = model(**sample["net_input"])
+ lprobs = model.get_normalized_probs(
+ net_output, log_probs=True
+ ).contiguous() # (T, B, C) from the encoder
+
+ if "src_lengths" in sample["net_input"]:
+ input_lengths = sample["net_input"]["src_lengths"]
+ else:
+ if net_output["padding_mask"] is not None:
+ non_padding_mask = ~net_output["padding_mask"]
+ input_lengths = non_padding_mask.long().sum(-1)
+ else:
+ input_lengths = lprobs.new_full(
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
+ )
+
+ pad_mask = (sample["target"] != self.pad_idx) & (
+ sample["target"] != self.eos_idx
+ )
+ targets_flat = sample["target"].masked_select(pad_mask)
+ if "target_lengths" in sample:
+ target_lengths = sample["target_lengths"]
+ else:
+ target_lengths = pad_mask.sum(-1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = F.ctc_loss(
+ lprobs,
+ targets_flat,
+ input_lengths,
+ target_lengths,
+ blank=self.blank_idx,
+ reduction="sum",
+ zero_infinity=self.zero_infinity,
+ )
+
+ ntokens = (
+ sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
+ )
+
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
+
+ logging_output = {}
+ if "decoder_target" in sample:
+ dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
+ dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
+ logging_output["ctc_loss"] = loss.item()
+ loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
+ logging_output["dec_loss"] = dec_loss.item()
+ logging_output["dec_nll_loss"] = dec_nll_loss.item()
+ logging_output["dec_sample_size"] = dec_sample_size
+
+ if self.report_accuracy:
+ n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
+ logging_output["dec_n_correct"] = utils.item(n_correct.data)
+ logging_output["total"] = utils.item(total.data)
+
+ logging_output = {
+ "loss": utils.item(loss.data), # * sample['ntokens'],
+ "ntokens": ntokens,
+ "nsentences": sample["id"].numel(),
+ "sample_size": sample_size,
+ **logging_output,
+ }
+
+ if not model.training:
+ import editdistance
+
+ with torch.no_grad():
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
+
+ c_err = 0
+ c_len = 0
+ w_errs = 0
+ w_len = 0
+ wv_errs = 0
+ for lp, t, inp_l in zip(
+ lprobs_t,
+ sample["target_label"]
+ if "target_label" in sample
+ else sample["target"],
+ input_lengths,
+ ):
+ lp = lp[:inp_l].unsqueeze(0)
+
+ decoded = None
+ if self.w2l_decoder is not None:
+ decoded = self.w2l_decoder.decode(lp)
+ if len(decoded) < 1:
+ decoded = None
+ else:
+ decoded = decoded[0]
+ if len(decoded) < 1:
+ decoded = None
+ else:
+ decoded = decoded[0]
+
+ p = (t != self.task.target_dictionary.pad()) & (
+ t != self.task.target_dictionary.eos()
+ )
+ targ = t[p]
+ targ_units = self.task.target_dictionary.string(targ)
+ targ_units_arr = targ.tolist()
+
+ toks = lp.argmax(dim=-1).unique_consecutive()
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
+
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
+ c_len += len(targ_units_arr)
+
+ targ_words = post_process(targ_units, self.post_process).split()
+
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
+ pred_words_raw = post_process(pred_units, self.post_process).split()
+
+ if decoded is not None and "words" in decoded:
+ pred_words = decoded["words"]
+ w_errs += editdistance.eval(pred_words, targ_words)
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
+ else:
+ dist = editdistance.eval(pred_words_raw, targ_words)
+ w_errs += dist
+ wv_errs += dist
+
+ w_len += len(targ_words)
+
+ logging_output["wv_errors"] = wv_errs
+ logging_output["w_errors"] = w_errs
+ logging_output["w_total"] = w_len
+ logging_output["c_errors"] = c_err
+ logging_output["c_total"] = c_len
+
+ return loss, sample_size, logging_output
+
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ loss, nll_loss = label_smoothed_nll_loss(
+ lprobs,
+ target,
+ self.eps,
+ ignore_index=self.pad_idx,
+ reduce=reduce,
+ )
+ return loss, nll_loss
+
+ def compute_accuracy(self, model, net_output, sample):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ mask = target.ne(self.pad_idx)
+ n_correct = torch.sum(
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+ )
+ total = torch.sum(mask)
+ return n_correct, total
+
+ def get_lprobs_and_target(self, model, net_output, sample):
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
+ target = sample["decoder_target"]
+ if self.ignore_prefix_size > 0:
+ if getattr(lprobs, "batch_first", False):
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+ target = target[:, self.ignore_prefix_size :].contiguous()
+ else:
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+ target = target[self.ignore_prefix_size :, :].contiguous()
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
+ nsentences = utils.item(
+ sum(log.get("nsentences", 0) for log in logging_outputs)
+ )
+ sample_size = utils.item(
+ sum(log.get("sample_size", 0) for log in logging_outputs)
+ )
+
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar("ntokens", ntokens)
+ metrics.log_scalar("nsentences", nsentences)
+ if sample_size != ntokens:
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
+ metrics.log_scalar("_c_errors", c_errors)
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
+ metrics.log_scalar("_c_total", c_total)
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
+ metrics.log_scalar("_w_errors", w_errors)
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
+ metrics.log_scalar("_wv_errors", wv_errors)
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
+ metrics.log_scalar("_w_total", w_total)
+
+ if c_total > 0:
+ metrics.log_derived(
+ "uer",
+ lambda meters: safe_round(
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
+ )
+ if meters["_c_total"].sum > 0
+ else float("nan"),
+ )
+ if w_total > 0:
+ metrics.log_derived(
+ "wer",
+ lambda meters: safe_round(
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+ )
+ if meters["_w_total"].sum > 0
+ else float("nan"),
+ )
+ metrics.log_derived(
+ "raw_wer",
+ lambda meters: safe_round(
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+ )
+ if meters["_w_total"].sum > 0
+ else float("nan"),
+ )
+
+ if "dec_loss" in logging_outputs[0]:
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
+ dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
+ dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
+ )
+ metrics.log_scalar(
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar(
+ "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
+ )
+ metrics.log_derived(
+ "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
+ )
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
+ if total > 0:
+ metrics.log_scalar("total", total)
+ n_correct = utils.item(
+ sum(log.get("dec_n_correct", 0) for log in logging_outputs)
+ )
+ metrics.log_scalar("dec_n_correct", n_correct)
+ metrics.log_derived(
+ "dec_accuracy",
+ lambda meters: round(
+ meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
+ )
+ if meters["total"].sum > 0
+ else float("nan"),
+ )
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return True
diff --git a/SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py b/SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6a695fc04024df3f2b5f8d87077484491c90d84
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py
@@ -0,0 +1,261 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import re
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
+from fairseq.criterions.hubert_criterion import HubertCriterionConfig
+
+@dataclass
+class Speech2cCriterionConfig(HubertCriterionConfig):
+ dec_weight: float = field(
+ default=1.0,
+ metadata={"help": "weights for decoder CE Loss, loss will be (hubert_loss + dec_weight * CE_Loss)"},
+ )
+ report_accuracy: bool = field(
+ default=True,
+ metadata={"help": "report decoder accuracy metric"},
+ )
+ ignore_prefix_size: int = field(
+ default=0,
+ metadata={"help": "Ignore first N tokens"},
+ )
+ label_smoothing: float = field(
+ default=0.0,
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
+ )
+
+
+@register_criterion("speech2c", dataclass=Speech2cCriterionConfig)
+class Speech2cCriterion(FairseqCriterion):
+ def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None, dec_weight=1.0, report_accuracy=False, ignore_prefix_size=0, label_smoothing=0.0):
+ super().__init__(task)
+ self.pred_masked_weight = pred_masked_weight
+ self.pred_nomask_weight = pred_nomask_weight
+ self.loss_weights = loss_weights
+ self.log_keys = [] if log_keys is None else log_keys
+ self.dec_weight = dec_weight
+ self.report_accuracy = report_accuracy
+ self.ignore_prefix_size = ignore_prefix_size
+ self.eps = label_smoothing
+ self.padding_idx = task.dictionaries[0].pad()
+
+ def forward(self, model, sample, reduce=True, log_pred=False):
+ """Compute the loss for the given sample.
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ net_output = model(target_list=sample["target_list"], **sample["net_input"])
+ loss = 0.0
+ sample_size = 0
+ logging_output = {}
+ reduction = "sum" if reduce else "none"
+
+ loss_m_list = []
+ logp_m_list = model.get_logits(net_output, True)
+ targ_m_list = model.get_targets(net_output, True)
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
+ loss_m_list.append(loss_m)
+ logging_output[f"loss_m_{i}"] = loss_m.detach().item()
+ if self.pred_masked_weight > 0:
+ loss += self.pred_masked_weight * sum(loss_m_list)
+ sample_size += targ_m_list[0].numel()
+
+ loss_u_list = []
+ logp_u_list = model.get_logits(net_output, False)
+ targ_u_list = model.get_targets(net_output, False)
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
+ loss_u_list.append(loss_u)
+ logging_output[f"loss_u_{i}"] = loss_u.detach().item()
+ if self.pred_nomask_weight > 0:
+ loss += self.pred_nomask_weight * sum(loss_u_list)
+ sample_size += targ_u_list[0].numel()
+
+ if self.loss_weights is not None:
+ assert hasattr(model, "get_extra_losses")
+ extra_losses, names = model.get_extra_losses(net_output)
+ if torch.is_tensor(extra_losses):
+ extra_losses = [extra_losses]
+ names = [names]
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+ assert len(extra_losses) == len(
+ self.loss_weights
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
+ if coef != 0 and p is not None:
+ p = coef * p.float() * sample_size
+ loss += p
+ logging_output[f"loss_{n}"] = p.item()
+
+ if "decoder_target" in sample:
+ dec_sample_size = sample["dec_ntokens"]
+ dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
+ loss = loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
+ logging_output["dec_loss"] = dec_loss.item()
+ logging_output["dec_nll_loss"] = dec_nll_loss.item()
+ logging_output["dec_sample_size"] = dec_sample_size
+
+ if self.report_accuracy:
+ n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
+ logging_output["dec_n_correct"] = utils.item(n_correct.data)
+ logging_output["total"] = utils.item(total.data)
+
+ logging_output = {
+ "loss": loss.item() if reduce else loss,
+ "ntokens": sample_size,
+ "nsentences": sample["id"].numel(),
+ "sample_size": sample_size,
+ **logging_output,
+ }
+
+ for lk in self.log_keys:
+ if lk in net_output:
+                logging_output[lk] = float(net_output[lk])
+
+ def compute_correct(logits):
+ if logits.numel() == 0:
+ return 0, 0
+ else:
+ assert logits.dim() > 1, logits.shape
+                # index 0 is the positive class in HuBERT-style logits; count
+                # frames where it scores highest, excluding degenerate frames
+                # where it is simultaneously the highest and the lowest
+                is_max = logits.argmax(-1) == 0
+                is_min = logits.argmin(-1) == 0
+                both = is_max & is_min
+                corr = is_max.long().sum().item() - both.long().sum().item()
+                count = is_max.numel()
+ return corr, count
+
+ with torch.no_grad():
+ for i, logp_m in enumerate(logp_m_list):
+ corr_m, count_m = compute_correct(logp_m)
+ logging_output[f"correct_m_{i}"] = corr_m
+ logging_output[f"count_m_{i}"] = count_m
+
+ for i, logp_u in enumerate(logp_u_list):
+ corr_u, count_u = compute_correct(logp_u)
+ logging_output[f"correct_u_{i}"] = corr_u
+ logging_output[f"count_u_{i}"] = count_u
+
+ return loss, sample_size, logging_output
+
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ loss, nll_loss = label_smoothed_nll_loss(
+ lprobs,
+ target,
+ self.eps,
+ ignore_index=self.padding_idx,
+ reduce=reduce,
+ )
+ return loss, nll_loss
+
+ def compute_accuracy(self, model, net_output, sample):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ mask = target.ne(self.padding_idx)
+ n_correct = torch.sum(
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+ )
+ total = torch.sum(mask)
+ return n_correct, total
+
+ def get_lprobs_and_target(self, model, net_output, sample):
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
+ target = sample["decoder_target"]
+ if self.ignore_prefix_size > 0:
+ if getattr(lprobs, "batch_first", False):
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+ target = target[:, self.ignore_prefix_size :].contiguous()
+ else:
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+ target = target[self.ignore_prefix_size :, :].contiguous()
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+ metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
+ if sample_size != ntokens:
+ metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
+ else:
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
+
+ counts = {}
+ for lk in logging_outputs[0].keys():
+ if lk.startswith("count_"):
+ val = sum(log[lk] for log in logging_outputs)
+ metrics.log_scalar(lk, val)
+ counts[lk] = val
+
+ for lk in logging_outputs[0].keys():
+ if lk.startswith("loss_"):
+ val = sum(log[lk] for log in logging_outputs)
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
+ elif lk.startswith("correct_"):
+ val = sum(log[lk] for log in logging_outputs)
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
+
+ if "dec_loss" in logging_outputs[0]:
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
+ dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
+ dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
+ )
+ metrics.log_scalar(
+ "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
+ )
+ metrics.log_derived(
+ "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
+ )
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
+ if total > 0:
+ metrics.log_scalar("total", total)
+ n_correct = utils.item(
+ sum(log.get("dec_n_correct", 0) for log in logging_outputs)
+ )
+ metrics.log_scalar("dec_n_correct", n_correct)
+ metrics.log_derived(
+ "dec_accuracy",
+ lambda meters: round(
+ meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
+ )
+ if meters["total"].sum > 0
+ else float("nan"),
+ )
+
+ @staticmethod
+ def aggregate_logging_outputs(logging_outputs):
+ """Aggregate logging outputs from data parallel training."""
+ raise NotImplementedError()
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return False
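
To make the `dec_weight` option in `Speech2cCriterionConfig` concrete, here is a minimal standalone sketch of how the masked-prediction loss and the decoder cross-entropy are combined, using toy tensors in place of the fairseq model outputs (shapes, class counts, and values are illustrative only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_classes, masked_frames, dec_tokens = 10, 6, 4
dec_weight = 1.0  # default value of the config field above

# stand-ins for model.get_logits(net_output, True) / get_targets(net_output, True)
logp_m = torch.randn(masked_frames, num_classes)
targ_m = torch.randint(num_classes, (masked_frames,))
hubert_loss = F.cross_entropy(logp_m, targ_m, reduction="sum")
sample_size = targ_m.numel()

# stand-ins for the decoder branch (net_output["decoder_out"] vs. decoder_target)
dec_logits = torch.randn(dec_tokens, num_classes)
dec_target = torch.randint(num_classes, (dec_tokens,))
dec_loss = F.cross_entropy(dec_logits, dec_target, reduction="sum")
dec_sample_size = dec_target.numel()

# the decoder loss is rescaled by sample_size / dec_sample_size before weighting,
# mirroring the forward() above
loss = hubert_loss + dec_weight * dec_loss * sample_size / dec_sample_size
print(float(loss))
```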
diff --git a/SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py b/SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7af1303b0faa145d19e0bdf1d0a1ed9db61ad625
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py
@@ -0,0 +1,145 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import logging
+from typing import Any, List, Optional, Union
+
+import torch
+from fairseq.data import data_utils, Dictionary
+from fairseq.data.audio.hubert_dataset import HubertDataset
+logger = logging.getLogger(__name__)
+
+
+class Speech2cDataset(HubertDataset):
+ def __init__(
+ self,
+ manifest_path: str,
+ sample_rate: float,
+ label_paths: List[str],
+ label_rates: Union[List[float], float], # -1 for sequence labels
+ pad_list: List[str],
+ eos_list: List[str],
+ label_processors: Optional[List[Any]] = None,
+ max_keep_sample_size: Optional[int] = None,
+ min_keep_sample_size: Optional[int] = None,
+ max_sample_size: Optional[int] = None,
+ shuffle: bool = True,
+ pad_audio: bool = False,
+ normalize: bool = False,
+ store_labels: bool = True,
+ random_crop: bool = False,
+ single_target: bool = False,
+ tgt_dict: Optional[Dictionary] = None,
+ add_decoder: bool = False,
+ fine_tuning: bool = False,
+ ):
+ super().__init__(
+ manifest_path,
+ sample_rate,
+ label_paths,
+ label_rates,
+ pad_list,
+ eos_list,
+ label_processors,
+ max_keep_sample_size,
+ min_keep_sample_size,
+ max_sample_size,
+ shuffle,
+ pad_audio,
+ normalize,
+ store_labels,
+ random_crop,
+ single_target
+ )
+
+ self.tgt_dict = tgt_dict
+ self.add_decoder = add_decoder
+ self.fine_tuning = fine_tuning
+
+ def collater(self, samples):
+ # target = max(sizes) -> random_crop not used
+ # target = max_sample_size -> random_crop used for long
+ samples = [s for s in samples if s["source"] is not None]
+ if len(samples) == 0:
+ return {}
+
+ audios = [s["source"] for s in samples]
+ audio_sizes = [len(s) for s in audios]
+ if self.pad_audio:
+ audio_size = min(max(audio_sizes), self.max_sample_size)
+ else:
+ audio_size = min(min(audio_sizes), self.max_sample_size)
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
+ audios, audio_size
+ )
+
+ targets_by_label = [
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
+ ]
+ targets_list, lengths_list, ntokens_list = self.collater_label(
+ targets_by_label, audio_size, audio_starts
+ )
+
+ if self.add_decoder:
+ if self.fine_tuning:
+ decoder_label = [
+ torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
+ for i in range(targets_list[0].size(0))
+ ]
+ else:
+ decoder_label = [
+ torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive(), torch.tensor([self.tgt_dict.eos()])), 0).long()
+ for i in range(targets_list[0].size(0))
+ ]
+ dec_ntokens = sum(x.size(0) for x in decoder_label)
+ decoder_target = data_utils.collate_tokens(
+ decoder_label,
+ self.tgt_dict.pad(),
+ self.tgt_dict.eos(),
+ left_pad=False,
+ move_eos_to_beginning=False,
+ )
+ decoder_target_lengths = torch.tensor(
+ [x.size(0) for x in decoder_label], dtype=torch.long
+ )
+ prev_output_tokens = data_utils.collate_tokens(
+ decoder_label,
+ self.tgt_dict.pad(),
+ self.tgt_dict.eos(),
+ left_pad=False,
+ move_eos_to_beginning=True,
+ )
+ net_input = {
+ "source": collated_audios,
+ "padding_mask": padding_mask,
+ "prev_output_tokens": prev_output_tokens,
+ }
+ batch = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": net_input,
+ "decoder_target": decoder_target,
+ "decoder_target_lengths": decoder_target_lengths,
+ "dec_ntokens": dec_ntokens,
+ }
+ else:
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
+ batch = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": net_input,
+ }
+
+ if self.single_target:
+ batch["target_lengths"] = lengths_list[0]
+ batch["ntokens"] = ntokens_list[0]
+ batch["target"] = targets_list[0]
+ else:
+ batch["target_lengths_list"] = lengths_list
+ batch["ntokens_list"] = ntokens_list
+ batch["target_list"] = targets_list
+ return batch
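
A toy sketch of how the collater derives the decoder targets during pre-training: frame-level unit labels are collapsed with `unique_consecutive`, an EOS token is appended, and `prev_output_tokens` is the same sequence shifted right with EOS moved to the front, which is what `collate_tokens(..., move_eos_to_beginning=True)` produces for a single unpadded sequence. The pad/eos indices below are made up for illustration:

```python
import torch

PAD, EOS = 1, 2  # hypothetical dictionary indices; PAD is only used when batching

# frame-level hidden-unit labels for one utterance
frame_labels = torch.tensor([5, 5, 5, 7, 7, 9, 9, 9, 9])

units = frame_labels.unique_consecutive()                  # tensor([5, 7, 9])
decoder_target = torch.cat([units, torch.tensor([EOS])])   # tensor([5, 7, 9, 2])

# teacher-forcing input: EOS first, then the target without its last token
prev_output_tokens = torch.cat([torch.tensor([EOS]), decoder_target[:-1]])

print(decoder_target.tolist())      # [5, 7, 9, 2]
print(prev_output_tokens.tolist())  # [2, 5, 7, 9]
```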
diff --git a/SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py b/SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..b42cbd819abf7bdd718bef3db3f553c8360ac384
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+
+# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import numpy as np
+import six
+
+
+class CTCPrefixScore(object):
+ """Compute CTC label sequence scores
+ which is based on Algorithm 2 in WATANABE et al.
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+    but extended to efficiently compute the probabilities of multiple labels
+ simultaneously
+ """
+
+ def __init__(self, x, blank, eos, xp):
+ self.xp = xp
+ self.logzero = -10000000000.0
+ self.blank = blank
+ self.eos = eos
+ self.input_length = len(x)
+ self.x = x
+
+ def initial_state(self):
+ """Obtain an initial CTC state
+ :return: CTC state
+ """
+ # initial CTC state is made of a frame x 2 tensor that corresponds to
+ # r_t^n() and r_t^b(), where 0 and 1 of axis=1 represent
+ # superscripts n and b (non-blank and blank), respectively.
+ r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
+ r[0, 1] = self.x[0, self.blank]
+ for i in six.moves.range(1, self.input_length):
+ r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
+ return r
+
+ def __call__(self, y, cs, r_prev):
+ """Compute CTC prefix scores for next labels
+ :param y : prefix label sequence
+ :param cs : array of next labels
+ :param r_prev: previous CTC state
+ :return ctc_scores, ctc_states
+ """
+ # initialize CTC states
+ output_length = len(y) - 1 # ignore sos
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
+ # that corresponds to r_t^n(h) and r_t^b(h).
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
+ xs = self.x[:, cs]
+ if output_length == 0:
+ r[0, 0] = xs[0]
+ r[0, 1] = self.logzero
+ else:
+ r[output_length - 1] = self.logzero
+
+ # prepare forward probabilities for the last label
+ r_sum = self.xp.logaddexp(
+ r_prev[:, 0], r_prev[:, 1]
+ ) # log(r_t^n(g) + r_t^b(g))
+ last = y[-1]
+ if output_length > 0 and last in cs:
+ log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
+ for i in six.moves.range(len(cs)):
+ log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
+ else:
+ log_phi = r_sum
+
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
+ # and log prefix probabilities log(psi)
+ start = max(output_length, 1)
+ log_psi = r[start - 1, 0]
+ for t in six.moves.range(start, self.input_length):
+ r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
+ r[t, 1] = (
+ self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
+ )
+ log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
+
+ # get P(...eos|X) that ends with the prefix itself
+ eos_pos = self.xp.where(cs == self.eos)[0]
+ if len(eos_pos) > 0:
+ log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
+
+ # exclude blank probs
+ blank_pos = self.xp.where(cs == self.blank)[0]
+ if len(blank_pos) > 0:
+ log_psi[blank_pos] = self.logzero
+
+ # return the log prefix probability and CTC states, where the label axis
+ # of the CTC states is moved to the first axis to slice it easily
+ return log_psi, self.xp.rollaxis(r, 2)
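
A minimal usage sketch with a random log-probability matrix, assuming the `speech2c` package is on the Python path; the scores are meaningless here, the point is the calling convention and output shapes:

```python
import numpy as np

from speech2c.models.modules.ctc_prefix_score import CTCPrefixScore

T, V = 20, 6                 # frames x vocabulary; index 0 = blank, V - 1 = eos
blank, eos = 0, V - 1
rng = np.random.default_rng(0)
x = np.log(rng.dirichlet(np.ones(V), size=T)).astype(np.float32)  # per-frame log-probs

scorer = CTCPrefixScore(x, blank, eos, np)   # xp=np selects the numpy backend
r_prev = scorer.initial_state()              # (T, 2) forward variables for the empty prefix

y = [eos]                                    # prefix so far (position 0 is the ignored sos)
cs = np.array([1, 2, 3, eos])                # candidate next labels to score
log_psi, r_new = scorer(y, cs, r_prev)

print(log_psi.shape, r_new.shape)            # (4,) (4, 20, 2)
```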
diff --git a/SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py b/SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b1c1445037ada5aef5b8cf9fd3b63b05d95aca1
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py
@@ -0,0 +1,341 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from torch import Tensor
+
+from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
+
+
+class MultiheadAttention(FairseqMultiheadAttention):
+ """Multi-headed attention.
+
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ q_noise=0.0,
+ qn_block_size=8,
+ ):
+ super().__init__(
+ embed_dim,
+ num_heads,
+ kdim,
+ vdim,
+ dropout,
+ bias,
+ add_bias_kv,
+ add_zero_attn,
+ self_attention,
+ encoder_decoder_attention,
+ q_noise,
+ qn_block_size,
+ )
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ position_bias: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: True).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ is_tpu = query.device.type == "xla"
+
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert key_bsz == bsz
+ assert value is not None
+                assert (src_len, bsz) == value.shape[:2]
+
+ if (
+ not self.onnx_trace
+ and not is_tpu # don't use PyTorch version on TPUs
+ and incremental_state is None
+ and not static_kv
+ # A workaround for quantization to work. Otherwise JIT compilation
+ # treats bias in linear module as method.
+ and not torch.jit.is_scripting()
+ and position_bias is None
+ ):
+ assert key is not None and value is not None
+ return F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training or self.dropout_module.apply_during_inference,
+ key_padding_mask,
+ need_weights,
+ attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+ ],
+ dim=1,
+ )
+
+ q = (
+ q.contiguous()
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if k is not None:
+ k = (
+ k.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if v is not None:
+ v = (
+ v.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.size(1) == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
+ key_padding_mask
+ ),
+ ],
+ dim=1,
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        if position_bias is not None:
+            # first-order relative position bias:
+            #   position_bias: (tgt_len, tgt_len, head_dim)
+            #   attn_weights:  (bsz * num_heads, tgt_len, src_len)
+            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)  # (tgt_len, bsz * num_heads, head_dim)
+            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))  # (tgt_len, bsz * num_heads, tgt_len)
+            B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1))
+            attn_weights += B
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ if self.onnx_trace:
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ if not is_tpu:
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+ else:
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = utils.softmax(
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
+ )
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ if self.onnx_trace and attn.size(1) == 1:
+ # when ONNX tracing a single decoder step (sequence length == 1)
+ # the transpose is a no-op copy before view, thus unnecessary
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
+ else:
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ attn_weights = attn_weights_float.view(
+ bsz, self.num_heads, tgt_len, src_len
+ ).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights
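
A shape-only sketch of the `position_bias` branch above, with made-up dimensions: the relative key embeddings are matmul'ed against the reshaped queries and the result is added to the `(bsz * num_heads, tgt_len, src_len)` attention logits:

```python
import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 16
src_len = tgt_len  # self-attention

q = torch.randn(bsz * num_heads, tgt_len, head_dim)            # queries, already scaled
k = torch.randn(bsz * num_heads, src_len, head_dim)
attn_weights = torch.bmm(q, k.transpose(1, 2))                 # (8, 5, 5)

position_bias = torch.randn(tgt_len, tgt_len, head_dim)        # pos_k from RelativePositionalEncoding

reshape_q = q.view(bsz * num_heads, tgt_len, head_dim).transpose(0, 1)   # (tgt_len, bsz*heads, head_dim)
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))             # (tgt_len, bsz*heads, tgt_len)
B = B.transpose(0, 1).reshape(bsz * num_heads, tgt_len, tgt_len)
attn_weights = attn_weights + B

print(attn_weights.shape)  # torch.Size([8, 5, 5])
```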
diff --git a/SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py b/SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a073ebf2893e9e9b092aa520bdaf927e9388c2b
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py
@@ -0,0 +1,35 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import torch
+
+class RelativePositionalEncoding(torch.nn.Module):
+ def __init__(self, d_model, maxlen=1000, embed_v=False):
+ super(RelativePositionalEncoding, self).__init__()
+
+ self.d_model = d_model
+ self.maxlen = maxlen
+ self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
+ if embed_v:
+ self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
+ self.embed_v = embed_v
+
+
+ def forward(self, pos_seq, incremental_state=None):
+ pos_seq[pos_seq < -self.maxlen] = -self.maxlen
+ pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
+ pos_seq = pos_seq + self.maxlen
+
+ if incremental_state is not None:
+ pos_seq = pos_seq[-1:]
+
+ if self.embed_v:
+ return self.pe_k(pos_seq), self.pe_v(pos_seq)
+ else:
+ return self.pe_k(pos_seq), None
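
Usage sketch, assuming the `speech2c` package is importable: build the matrix of pairwise offsets for a length-T sequence and look up per-offset key embeddings, as the encoder and decoder above do before passing `pos_k` into attention as `position_bias`:

```python
import torch

from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding

d_head, T = 64, 5                        # per-head dimension, sequence length
pos_emb = RelativePositionalEncoding(d_head, maxlen=160)

pos_seq = torch.arange(0, T).long()
pos_seq = pos_seq[:, None] - pos_seq[None, :]   # (T, T) matrix of offsets i - j
pos_k, pos_v = pos_emb(pos_seq)                 # pos_v is None unless embed_v=True

print(pos_k.shape)  # torch.Size([5, 5, 64])
```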
diff --git a/SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py b/SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaf4dce4ac717453bf4c37f3f393092ea53ef062
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py
@@ -0,0 +1,485 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from fairseq import utils
+from fairseq.distributed import fsdp_wrap
+from fairseq.models import FairseqIncrementalDecoder
+from fairseq.models.transformer import TransformerConfig
+from fairseq.models.transformer.transformer_decoder import module_name_fordropout, Linear
+from fairseq.modules import (
+ AdaptiveSoftmax,
+ BaseLayer,
+ FairseqDropout,
+ LayerDropModuleList,
+ LayerNorm,
+ PositionalEmbedding,
+ SinusoidalPositionalEmbedding,
+)
+from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+from torch import Tensor
+
+
+from speech2c.models.modules.transformer_decoder_layer import TransformerDecoderLayerBase
+from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
+
+
+class TransformerDecoderBase(FairseqIncrementalDecoder):
+ """
+ Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
+ is a :class:`TransformerDecoderLayer`.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
+ embed_tokens (torch.nn.Embedding): output embedding
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
+ (default: False).
+ """
+
+ def __init__(
+ self,
+ cfg,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ output_projection=None,
+ use_rel_pos_enc=False,
+ ):
+ self.cfg = cfg
+ super().__init__(dictionary)
+ self.register_buffer("version", torch.Tensor([3]))
+ self._future_mask = torch.empty(0)
+
+ self.dropout_module = FairseqDropout(
+ cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
+ )
+ self.decoder_layerdrop = cfg.decoder.layerdrop
+ self.share_input_output_embed = cfg.share_decoder_input_output_embed
+
+ input_embed_dim = embed_tokens.embedding_dim
+ embed_dim = cfg.decoder.embed_dim
+ self.embed_dim = embed_dim
+ self.output_embed_dim = cfg.decoder.output_dim
+
+ self.padding_idx = embed_tokens.padding_idx
+ self.max_target_positions = cfg.max_target_positions
+
+ self.embed_tokens = embed_tokens
+
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
+
+ if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
+ self.quant_noise = apply_quant_noise_(
+ nn.Linear(embed_dim, embed_dim, bias=False),
+ cfg.quant_noise.pq,
+ cfg.quant_noise.pq_block_size,
+ )
+ else:
+ self.quant_noise = None
+
+ self.project_in_dim = (
+ Linear(input_embed_dim, embed_dim, bias=False)
+ if embed_dim != input_embed_dim
+ else None
+ )
+ self.embed_positions = (
+ PositionalEmbedding(
+ self.max_target_positions,
+ embed_dim,
+ self.padding_idx,
+ learned=cfg.decoder.learned_pos,
+ )
+ if not cfg.no_token_positional_embeddings
+ else None
+ )
+ if cfg.layernorm_embedding:
+ self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
+ else:
+ self.layernorm_embedding = None
+
+ self.cross_self_attention = cfg.cross_self_attention
+
+ self.use_rel_pos_enc = use_rel_pos_enc
+ if self.decoder_layerdrop > 0.0:
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
+ else:
+ self.layers = nn.ModuleList([])
+ self.layers.extend(
+ [
+ self.build_decoder_layer(cfg, no_encoder_attn)
+ for _ in range(cfg.decoder.layers)
+ ]
+ )
+ self.num_layers = len(self.layers)
+
+ if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
+ self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
+ else:
+ self.layer_norm = None
+
+ self.project_out_dim = (
+ Linear(embed_dim, self.output_embed_dim, bias=False)
+ if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
+ else None
+ )
+
+ self.adaptive_softmax = None
+ self.output_projection = output_projection
+ if self.output_projection is None:
+ self.build_output_projection(cfg, dictionary, embed_tokens)
+
+ if self.use_rel_pos_enc:
+ self.pos_emb = RelativePositionalEncoding(self.embed_dim // cfg.decoder.attention_heads, 24)
+
+ def build_output_projection(self, cfg, dictionary, embed_tokens):
+ if cfg.adaptive_softmax_cutoff is not None:
+ self.adaptive_softmax = AdaptiveSoftmax(
+ len(dictionary),
+ self.output_embed_dim,
+ utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
+ dropout=cfg.adaptive_softmax_dropout,
+ adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
+ factor=cfg.adaptive_softmax_factor,
+ tie_proj=cfg.tie_adaptive_proj,
+ )
+ elif self.share_input_output_embed:
+ self.output_projection = nn.Linear(
+ self.embed_tokens.weight.shape[1],
+ self.embed_tokens.weight.shape[0],
+ bias=False,
+ )
+ self.output_projection.weight = self.embed_tokens.weight
+ else:
+ self.output_projection = nn.Linear(
+ self.output_embed_dim, len(dictionary), bias=False
+ )
+ nn.init.normal_(
+ self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
+ )
+ num_base_layers = cfg.base_layers
+ for i in range(num_base_layers):
+ self.layers.insert(
+ ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
+ BaseLayer(cfg),
+ )
+
+ def build_decoder_layer(self, cfg, no_encoder_attn=False):
+ layer = TransformerDecoderLayerBase(cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc)
+ checkpoint = cfg.checkpoint_activations
+ if checkpoint:
+ offload_to_cpu = cfg.offload_activations
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+ # if we are checkpointing, enforce that FSDP always wraps the
+ # checkpointed layer, regardless of layer size
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
+ return layer
+
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ features_only: bool = False,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ src_lengths: Optional[Any] = None,
+ return_all_hiddens: bool = False,
+ ):
+ """
+ Args:
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
+ `(batch, tgt_len)`, for teacher forcing
+ encoder_out (optional): output from the encoder, used for
+ encoder-side attention, should be of size T x B x C
+ incremental_state (dict): dictionary used for storing state during
+ :ref:`Incremental decoding`
+ features_only (bool, optional): only return features without
+ applying output layer (default: False).
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+
+ Returns:
+ tuple:
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
+ - a dictionary with any model-specific outputs
+ """
+
+ x, extra = self.extract_features(
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ incremental_state=incremental_state,
+ full_context_alignment=full_context_alignment,
+ alignment_layer=alignment_layer,
+ alignment_heads=alignment_heads,
+ )
+
+ if not features_only:
+ x = self.output_layer(x)
+ return x, extra
+
+ def extract_features_scriptable(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]],
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ """
+ Similar to *forward* but only return features.
+
+ Includes several features from "Jointly Learning to Align and
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
+
+ Args:
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+ alignment_layer (int, optional): return mean alignment over
+ heads at this layer (default: last layer).
+ alignment_heads (int, optional): only average alignment over
+ this many heads (default: all heads).
+
+ Returns:
+ tuple:
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+ - a dictionary with any model-specific outputs
+ """
+ bs, slen = prev_output_tokens.size()
+ if alignment_layer is None:
+ alignment_layer = self.num_layers - 1
+
+ enc: Optional[Tensor] = None
+ padding_mask: Optional[Tensor] = None
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
+ enc = encoder_out["encoder_out"][0]
+ assert (
+ enc.size()[1] == bs
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
+ padding_mask = encoder_out["encoder_padding_mask"][0]
+
+ # embed positions
+ positions = None
+ if self.embed_positions is not None:
+ positions = self.embed_positions(
+ prev_output_tokens, incremental_state=incremental_state
+ )
+
+ if incremental_state is not None:
+ prev_output_tokens = prev_output_tokens[:, -1:]
+ if positions is not None:
+ positions = positions[:, -1:]
+
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+ if self.quant_noise is not None:
+ x = self.quant_noise(x)
+
+ if self.project_in_dim is not None:
+ x = self.project_in_dim(x)
+
+ if positions is not None:
+ x += positions
+
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+
+ x = self.dropout_module(x)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ if self.use_rel_pos_enc:
+ pos_seq = torch.arange(0, slen).long().to(x.device)
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
+ pos_k, _ = self.pos_emb(pos_seq, incremental_state)
+ else:
+ pos_k = None
+
+ self_attn_padding_mask: Optional[Tensor] = None
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+
+ # decoder layers
+ attn: Optional[Tensor] = None
+ inner_states: List[Optional[Tensor]] = [x]
+ for idx, layer in enumerate(self.layers):
+ if incremental_state is None and not full_context_alignment:
+ self_attn_mask = self.buffered_future_mask(x)
+ else:
+ self_attn_mask = None
+
+ x, layer_attn, _ = layer(
+ x,
+ enc,
+ padding_mask,
+ incremental_state,
+ self_attn_mask=self_attn_mask,
+ self_attn_padding_mask=self_attn_padding_mask,
+ need_attn=bool((idx == alignment_layer)),
+ need_head_weights=bool((idx == alignment_layer)),
+ pos_bias=pos_k,
+ )
+ inner_states.append(x)
+ if layer_attn is not None and idx == alignment_layer:
+ attn = layer_attn.float().to(x)
+
+ if attn is not None:
+ if alignment_heads is not None:
+ attn = attn[:alignment_heads]
+
+ # average probabilities over heads
+ attn = attn.mean(dim=0)
+
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if self.project_out_dim is not None:
+ x = self.project_out_dim(x)
+
+ return x, {"attn": [attn], "inner_states": inner_states}
+
+ def output_layer(self, features):
+ """Project features to the vocabulary size."""
+ if self.adaptive_softmax is None:
+ # project back to size of vocabulary
+ return self.output_projection(features)
+ else:
+ return features
+
+ def max_positions(self):
+ """Maximum output length supported by the decoder."""
+ if self.embed_positions is None:
+ return self.max_target_positions
+ return min(self.max_target_positions, self.embed_positions.max_positions)
+
+ def buffered_future_mask(self, tensor):
+ dim = tensor.size(0)
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
+ if (
+ self._future_mask.size(0) == 0
+ or (not self._future_mask.device == tensor.device)
+ or self._future_mask.size(0) < dim
+ ):
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
+ )
+ self._future_mask = self._future_mask.to(tensor)
+ return self._future_mask[:dim, :dim]
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+ weights_key = "{}.embed_positions.weights".format(name)
+ if weights_key in state_dict:
+ del state_dict[weights_key]
+ state_dict[
+ "{}.embed_positions._float_tensor".format(name)
+ ] = torch.FloatTensor(1)
+
+ if f"{name}.output_projection.weight" not in state_dict:
+ if self.share_input_output_embed:
+ embed_out_key = f"{name}.embed_tokens.weight"
+ else:
+ embed_out_key = f"{name}.embed_out"
+ if embed_out_key in state_dict:
+ state_dict[f"{name}.output_projection.weight"] = state_dict[
+ embed_out_key
+ ]
+ if not self.share_input_output_embed:
+ del state_dict[embed_out_key]
+
+ for i in range(self.num_layers):
+ # update layer norms
+ layer_norm_map = {
+ "0": "self_attn_layer_norm",
+ "1": "encoder_attn_layer_norm",
+ "2": "final_layer_norm",
+ }
+ for old, new in layer_norm_map.items():
+ for m in ("weight", "bias"):
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
+ if k in state_dict:
+ state_dict[
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
+ ] = state_dict[k]
+ del state_dict[k]
+
+ version_key = "{}.version".format(name)
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
+ # earlier checkpoints did not normalize after the stack of layers
+ self.layer_norm = None
+ self.normalize = False
+ state_dict[version_key] = torch.Tensor([1])
+
+ return state_dict
+
+
+class TransformerDecoder(TransformerDecoderBase):
+ def __init__(
+ self,
+ args,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ output_projection=None,
+ ):
+ self.args = args
+ super().__init__(
+ TransformerConfig.from_namespace(args),
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=no_encoder_attn,
+ output_projection=output_projection,
+ use_rel_pos_enc=args.use_rel_pos_enc,
+ )
+
+ def build_output_projection(self, args, dictionary, embed_tokens):
+ super().build_output_projection(
+ TransformerConfig.from_namespace(args), dictionary, embed_tokens
+ )
+
+ def build_decoder_layer(self, args, no_encoder_attn=False):
+ return super().build_decoder_layer(
+ TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
+ )
+
+class TransformerDecoderScriptable(TransformerDecoder):
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ # call scriptable method from parent class
+ x, _ = self.extract_features_scriptable(
+ prev_output_tokens,
+ encoder_out,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ )
+ return x, None
+
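A standalone sketch of the causal mask that `buffered_future_mask` builds above: an upper-triangular matrix of `-inf` added to the self-attention logits so each target position can only attend to itself and earlier positions:

```python
import torch

dim = 4
future_mask = torch.triu(torch.full((dim, dim), float("-inf")), diagonal=1)
print(future_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```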
diff --git a/SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py b/SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..780bb43d8d3aaf456c0ae4cf5223b9b7eae599e8
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
@@ -0,0 +1,215 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+from typing import Dict, List, Optional
+
+import torch
+from torch import Tensor
+from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
+from fairseq.modules import LayerNorm
+
+from speech2c.models.modules.multihead_attention import MultiheadAttention
+
+
+class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
+ """Decoder layer block.
+
+ In the original paper each operation (multi-head attention, encoder
+ attention or FFN) is postprocessed with: `dropout -> add residual ->
+ layernorm`. In the tensor2tensor code they suggest that learning is more
+ robust when preprocessing each layer with layernorm and postprocessing with:
+ `dropout -> add residual`. We default to the approach in the paper, but the
+ tensor2tensor approach can be enabled by setting
+ *cfg.decoder.normalize_before* to ``True``.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
+ (default: False).
+ """
+
+ def __init__(
+ self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
+ ):
+ super().__init__(
+ cfg,
+ no_encoder_attn,
+ add_bias_kv,
+ add_zero_attn,
+ )
+
+ if has_relative_attention_bias:
+ self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
+
+ def build_self_attention(
+ self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
+ ):
+ return MultiheadAttention(
+ embed_dim,
+ cfg.decoder.attention_heads,
+ dropout=cfg.attention_dropout,
+ add_bias_kv=add_bias_kv,
+ add_zero_attn=add_zero_attn,
+ self_attention=not cfg.cross_self_attention,
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ )
+
+ def forward(
+ self,
+ x,
+ encoder_out: Optional[torch.Tensor] = None,
+ encoder_padding_mask: Optional[torch.Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ prev_self_attn_state: Optional[List[torch.Tensor]] = None,
+ prev_attn_state: Optional[List[torch.Tensor]] = None,
+ self_attn_mask: Optional[torch.Tensor] = None,
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
+ need_attn: bool = False,
+ need_head_weights: bool = False,
+ pos_bias=None,
+ ):
+ """
+ Args:
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
+ encoder_padding_mask (ByteTensor, optional): binary
+ ByteTensor of shape `(batch, src_len)` where padding
+ elements are indicated by ``1``.
+ need_attn (bool, optional): return attention weights
+ need_head_weights (bool, optional): return attention weights
+ for each head (default: return average over heads).
+ Returns:
+ encoded output of shape `(seq_len, batch, embed_dim)`
+ """
+ if need_head_weights:
+ need_attn = True
+
+ residual = x
+ if self.normalize_before:
+ x = self.self_attn_layer_norm(x)
+ if pos_bias is not None:
+ pos_bias = self.norm_k(pos_bias)
+ if prev_self_attn_state is not None:
+ prev_key, prev_value = prev_self_attn_state[:2]
+ saved_state: Dict[str, Optional[Tensor]] = {
+ "prev_key": prev_key,
+ "prev_value": prev_value,
+ }
+ if len(prev_self_attn_state) >= 3:
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
+ assert incremental_state is not None
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
+ _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
+ if self.cross_self_attention and not (
+ incremental_state is not None
+ and _self_attn_input_buffer is not None
+ and "prev_key" in _self_attn_input_buffer
+ ):
+ if self_attn_mask is not None:
+ assert encoder_out is not None
+ self_attn_mask = torch.cat(
+ (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
+ )
+ if self_attn_padding_mask is not None:
+ if encoder_padding_mask is None:
+ assert encoder_out is not None
+ encoder_padding_mask = self_attn_padding_mask.new_zeros(
+ encoder_out.size(1), encoder_out.size(0)
+ )
+ self_attn_padding_mask = torch.cat(
+ (encoder_padding_mask, self_attn_padding_mask), dim=1
+ )
+ assert encoder_out is not None
+ y = torch.cat((encoder_out, x), dim=0)
+ else:
+ y = x
+
+ x, attn = self.self_attn(
+ query=x,
+ key=y,
+ value=y,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ need_weights=False,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias,
+ )
+ if self.c_attn is not None:
+ tgt_len, bsz = x.size(0), x.size(1)
+ x = x.view(tgt_len, bsz, self.nh, self.head_dim)
+ x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
+ x = x.reshape(tgt_len, bsz, self.embed_dim)
+ if self.attn_ln is not None:
+ x = self.attn_ln(x)
+ x = self.dropout_module(x)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.self_attn_layer_norm(x)
+
+ if self.encoder_attn is not None and encoder_out is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.encoder_attn_layer_norm(x)
+ if prev_attn_state is not None:
+ prev_key, prev_value = prev_attn_state[:2]
+ saved_state: Dict[str, Optional[Tensor]] = {
+ "prev_key": prev_key,
+ "prev_value": prev_value,
+ }
+ if len(prev_attn_state) >= 3:
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
+ assert incremental_state is not None
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ need_weights=need_attn or (not self.training and self.need_attn),
+ need_head_weights=need_head_weights,
+ )
+ x = self.dropout_module(x)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.encoder_attn_layer_norm(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.final_layer_norm(x)
+
+ x = self.activation_fn(self.fc1(x))
+ x = self.activation_dropout_module(x)
+ if self.ffn_layernorm is not None:
+ x = self.ffn_layernorm(x)
+ x = self.fc2(x)
+ x = self.dropout_module(x)
+ if self.w_resid is not None:
+ residual = torch.mul(self.w_resid, residual)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.final_layer_norm(x)
+ if self.onnx_trace and incremental_state is not None:
+ saved_state = self.self_attn._get_input_buffer(incremental_state)
+ assert saved_state is not None
+ if self_attn_padding_mask is not None:
+ self_attn_state = [
+ saved_state["prev_key"],
+ saved_state["prev_value"],
+ saved_state["prev_key_padding_mask"],
+ ]
+ else:
+ self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
+ return x, attn, self_attn_state
+ return x, attn, None
+
+ def make_generation_fast_(self, need_attn: bool = False, **kwargs):
+ self.need_attn = need_attn
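
A minimal sketch of the two residual orderings described in the class docstring, with `sublayer` standing in for self-attention, encoder attention, or the FFN: post-norm (the paper default) applies `dropout -> add residual -> layernorm`, while pre-norm (`cfg.decoder.normalize_before=True`) normalizes first and adds the residual afterwards:

```python
import torch
import torch.nn as nn

def post_norm_block(x, sublayer, norm, dropout):
    # paper default: dropout -> add residual -> layernorm
    return norm(x + dropout(sublayer(x)))

def pre_norm_block(x, sublayer, norm, dropout):
    # normalize_before=True: layernorm -> sublayer -> dropout -> add residual
    return x + dropout(sublayer(norm(x)))

dim = 8
x = torch.randn(3, 2, dim)          # (seq_len, batch, embed_dim)
sublayer = nn.Linear(dim, dim)      # toy stand-in for attention / FFN
norm = nn.LayerNorm(dim)
dropout = nn.Dropout(p=0.1)

print(post_norm_block(x, sublayer, norm, dropout).shape)  # torch.Size([3, 2, 8])
print(pre_norm_block(x, sublayer, norm, dropout).shape)   # torch.Size([3, 2, 8])
```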
diff --git a/SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py b/SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6916c7960cf5bf6fc4fc60257ddb377bfea368fc
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py
@@ -0,0 +1,278 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.dataclass import ChoiceEnum
+from fairseq.modules import (
+    LayerNorm,
+    SamePad,
+)
+from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.utils import index_put
+from fairseq.distributed import fsdp_wrap
+from fairseq.models.wav2vec.utils import pad_to_multiple
+from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
+
+from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
+from speech2c.models.modules.multihead_attention import MultiheadAttention
+
+EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
+MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
+
+
+class TransformerEncoder(W2vTransformerEncoder):
+ def __init__(self, args):
+ super().__init__(args)
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+ self.required_seq_len_multiple = args.required_seq_len_multiple
+ self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
+
+ self.pos_conv = nn.Conv1d(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=args.conv_pos,
+ padding=args.conv_pos // 2,
+ groups=args.conv_pos_groups,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+ layers = []
+ for _ in range(args.encoder_layers):
+ layer = TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ has_relative_attention_bias=self.use_rel_pos_enc,
+ )
+ if args.checkpoint_activations:
+ layer = fsdp_wrap(layer)
+ layer = checkpoint_wrapper(layer)
+ layers.append(layer)
+ self.layers = nn.ModuleList(layers)
+
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+ if self.use_rel_pos_enc:
+ self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
+
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None, layer=None):
+ x, layer_results = self.extract_features(x, padding_mask, layer)
+
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+
+ return x, layer_results
+
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
+
+ if padding_mask is not None:
+ x = index_put(x, padding_mask, 0)
+
+ x_conv = self.pos_conv(x.transpose(1, 2))
+ x_conv = x_conv.transpose(1, 2)
+ x = x + x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ # pad to the sequence length dimension
+ x, pad_length = pad_to_multiple(
+ x, self.required_seq_len_multiple, dim=-2, value=0
+ )
+ if pad_length > 0 and padding_mask is None:
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+ padding_mask[:, -pad_length:] = True
+ else:
+ padding_mask, _ = pad_to_multiple(
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
+ )
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ if self.use_rel_pos_enc:
+ x_len = x.shape[0]
+ pos_seq = torch.arange(0, x_len).long().to(x.device)
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
+ pos_k, pos_v = self.pos_emb(pos_seq)
+ else:
+ pos_k = None
+
+ layer_results = []
+ r = None
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random()
+ if not self.training or (dropout_probability > self.layerdrop):
+ x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
+ if tgt_layer is not None:
+ # unpad if needed
+ if pad_length > 0:
+ layer_results.append(
+ (
+ x[:-pad_length],
+ z[:, :-pad_length, :-pad_length]
+ if z is not None
+ else z,
+ )
+ )
+ else:
+ layer_results.append((x, z))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+        # undo padding
+ if pad_length > 0:
+ x = x[:, :-pad_length]
+
+ return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ """
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+ models.
+ """
+
+ def __init__(
+ self,
+        embedding_dim: int = 768,
+        ffn_embedding_dim: int = 3072,
+        num_attention_heads: int = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ has_relative_attention_bias: bool = False,
+ ) -> None:
+
+ super().__init__()
+ # Initialize parameters
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ # Initialize blocks
+ self.activation_fn = utils.get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ if has_relative_attention_bias:
+ self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ att_args=None,
+ pos_bias=None,
+ ):
+        """
+        LayerNorm is applied either before or after the self-attention/ffn
+        modules, similar to the original Transformer implementation: pre-norm
+        when ``layer_norm_first`` is True, post-norm otherwise.
+        """
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ if pos_bias is not None:
+ pos_bias = self.norm_k(pos_bias)
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias,
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ position_bias=pos_bias,
+ )
+
+ x = self.dropout1(x)
+ x = residual + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x, attn
diff --git a/SpeechT5/Speech2C/speech2c/models/speech2c.py b/SpeechT5/Speech2C/speech2c/models/speech2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec69a679451172f8e32047c1bd2275932636e65
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/speech2c.py
@@ -0,0 +1,321 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import logging
+import copy
+import contextlib
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from dataclasses import dataclass, field
+from fairseq.data.dictionary import Dictionary
+from fairseq.models import register_model
+from fairseq.models.hubert import HubertConfig, HubertModel
+from fairseq.models.transformer import Embedding
+from torch import Tensor
+from speech2c.tasks.speech2c_pretraining import (
+ Speech2cPretrainingConfig,
+ Speech2cPretrainingTask,
+)
+
+from speech2c.models.modules.transformer_decoder import TransformerDecoderScriptable
+from speech2c.models.modules.transformer_encoder import TransformerEncoder
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Speech2cConfig(HubertConfig):
+ use_rel_pos_enc: bool = field(
+ default=False,
+ metadata={"help": "whether to use relative positional encoding"},
+ )
+
+ # decoder
+ decoder_layers: int = field(
+ default=6, metadata={"help": "num decoder layers in the transformer"}
+ )
+ decoder_embed_dim: int = field(
+ default=768, metadata={"help": "decoder embedding dimension"}
+ )
+ decoder_ffn_embed_dim: int = field(
+ default=3072, metadata={"help": "decoder embedding dimension for FFN"}
+ )
+ decoder_attention_heads: int = field(
+ default=12, metadata={"help": "num decoder attention heads"}
+ )
+ decoder_normalize_before: bool = field(
+ default=False,
+ metadata={"help": "apply layernorm before each decoder block"},
+ )
+ decoder_layerdrop: float = field(
+ default=0.0,
+        metadata={"help": "probability of dropping a transformer layer"},
+ )
+ share_decoder_input_output_embed: bool = field(
+ default=False,
+ metadata={"help": "share decoder input and output embeddings"},
+ )
+ decoder_output_dim: int = field(
+ default=768, metadata={"help": "decoder output dimension"}
+ )
+ max_target_positions: int = field(
+ default=3000, metadata={"help": "max target position"}
+ )
+    no_scale_embedding: bool = field(
+        default=False,
+        metadata={"help": "if True, do not scale embeddings"},
+    )
+    adaptive_input: bool = field(
+        default=False,
+        metadata={"help": "if True, use adaptive input representations"},
+    )
+    quant_noise_pq: int = field(
+        default=0, metadata={"help": "iterative PQ quantization noise at training time"}
+    )
+    decoder_learned_pos: bool = field(
+        default=False,
+        metadata={"help": "use learned positional embeddings in the decoder"},
+    )
+    no_token_positional_embeddings: bool = field(
+        default=False,
+        metadata={"help": "if True, disable token positional embeddings"},
+    )
+ decoder_dict_size: int = field(
+ default=-1,
+        metadata={"help": "decoder dictionary size; only used for fine-tuning"},
+ )
+
+ # FP16 optimization
+ required_seq_len_multiple: int = field(
+ default=1,
+ metadata={
+ "help": "pad the input to encoder such that the sequence length is divisible by multiple"
+ },
+ )
+ crop_seq_to_multiple: int = field(
+ default=1,
+ metadata={
+ "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
+ },
+ )
+
+
+@register_model("speech2c", dataclass=Speech2cConfig)
+class Speech2cModel(HubertModel):
+ def __init__(
+ self,
+ cfg: Speech2cConfig,
+ task_cfg: Speech2cPretrainingConfig,
+ dictionaries: List[Dictionary],
+ ) -> None:
+ super().__init__(cfg, task_cfg, dictionaries)
+ logger.info(f"Speech2cModel Config: {cfg}")
+
+ self.encoder = TransformerEncoder(cfg)
+
+ self.add_decoder = task_cfg.add_decoder
+ if task_cfg.add_decoder:
+ def build_embedding(dictionary, embed_dim):
+ num_embeddings = len(dictionary)
+ padding_idx = dictionary.pad()
+ return Embedding(num_embeddings, embed_dim, padding_idx)
+
+            # Make sure the decoder dict size matches the fine-tuning tgt_dict size
+ cut_dictionary = copy.deepcopy(dictionaries[0])
+ if cfg.decoder_dict_size != -1:
+ cut_dictionary.symbols = cut_dictionary.symbols[:cfg.decoder_dict_size]
+
+ decoder_embed_tokens = build_embedding(
+ cut_dictionary, cfg.decoder_embed_dim
+ )
+
+ self.decoder = TransformerDecoderScriptable(cfg, cut_dictionary, decoder_embed_tokens)
+
+
+ @classmethod
+ def build_model(cls, cfg: Speech2cConfig, task: Speech2cPretrainingTask):
+ """Build a new model instance."""
+
+ model = Speech2cModel(cfg, task.cfg, task.dictionaries)
+ return model
+
+ def get_normalized_probs(
+ self,
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+ log_probs: bool,
+ sample: Optional[Dict[str, Tensor]] = None,
+ ):
+ # net_output['encoder_out'] is a (B, T, D) tensor
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+ lprobs.batch_first = True
+ return lprobs
+
+ def forward(
+ self,
+ source: torch.Tensor,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ prev_output_tokens: Optional[torch.Tensor] = None,
+ ) -> Dict[str, torch.Tensor]:
+ """output layer is 1-based"""
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+
+ features_pen = features.float().pow(2).mean()
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ unmasked_features = features.clone()
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x = features
+ mask_indices = None
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x, _ = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1,
+ )
+
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features}
+
+ def compute_pred(proj_x, target, label_embs):
+ # compute logits for the i-th label set
+ y = torch.index_select(label_embs, 0, target.long())
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
+ if self.target_glu:
+ y = self.target_glu(y)
+ negs = self.target_glu(negs)
+ # proj_x: (S, D)
+ # y: (S, D)
+ # negs: (Neg, S, D)
+ return self.compute_nce(proj_x, y, negs)
+
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
+
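+        # HuBERT-style objective: logits are computed separately for masked frames
+        # (logit_m_list) and unmasked frames (logit_u_list); each projected frame
+        # is scored against its target label embedding, with all label embeddings
+        # serving as negatives, via compute_nce.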
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = self.final_proj(x[masked_indices])
+ if self.untie_final_proj:
+ proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
+ else:
+ proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
+ logit_m_list = [
+ compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
+ for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
+ ]
+ else:
+ logit_m_list = [None for _ in target_list]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = self.final_proj(x[nomask_indices])
+ if self.untie_final_proj:
+ proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
+ else:
+ proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
+
+ logit_u_list = [
+ compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
+ for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
+ ]
+ else:
+ logit_u_list = [None for _ in target_list]
+
+ result = {
+ "logit_m_list": logit_m_list,
+ "logit_u_list": logit_u_list,
+ "padding_mask": padding_mask,
+ "features_pen": features_pen,
+ }
+ if self.add_decoder:
+ encoder_out = {
+ "encoder_out": [x.transpose(0, 1)], # T x B x C
+ "encoder_padding_mask": [padding_mask], # B x T
+ }
+ assert prev_output_tokens is not None
+ decoder_out = self.decoder(
+ prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
+ )
+ result['decoder_out'] = decoder_out
+ return result
+
+ def forward_torchscript(self, net_input: Dict[str, Tensor]):
+ """A TorchScript-compatible version of forward.
+ Encoders which use additional arguments may want to override
+ this method for TorchScript compatibility.
+ """
+ res = self.forward(
+ net_input["source"],
+ padding_mask=net_input["padding_mask"],
+ mask=False,
+ features_only=True
+ )
+
+ encoder_out = {
+ "encoder_out": [res["x"].transpose(0, 1)], # T x B x C
+ "encoder_padding_mask": [res["padding_mask"]], # B x T
+ }
+ return encoder_out
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ prev_output_tokens: Optional[torch.Tensor] = None,
+ ft: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ with torch.no_grad() if not ft else contextlib.ExitStack():
+ res = self.forward(
+ source,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ output_layer=output_layer,
+ )
+
+ feature = res["features"] if ret_conv else res["x"]
+ if self.add_decoder:
+ encoder_out = {
+ "encoder_out": [feature.transpose(0, 1)], # T x B x C
+ "encoder_padding_mask": [res["padding_mask"]], # B x T
+ }
+ assert prev_output_tokens is not None
+ decoder_out = self.decoder(
+ prev_output_tokens=prev_output_tokens,
+ encoder_out=encoder_out,
+ )
+ else:
+ decoder_out = None
+ return feature, res["padding_mask"], decoder_out
diff --git a/SpeechT5/Speech2C/speech2c/models/speech2c_asr.py b/SpeechT5/Speech2C/speech2c/models/speech2c_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bf8aed97d97f1fd352a884f10173c11043f6a92
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/speech2c_asr.py
@@ -0,0 +1,276 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+from argparse import Namespace
+from omegaconf import II
+
+import torch.nn as nn
+from dataclasses import dataclass, field
+from fairseq import checkpoint_utils, tasks, utils
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
+from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
+from fairseq.tasks import FairseqTask
+
+
+@dataclass
+class Speech2cAsrConfig(HubertAsrConfig):
+ # for decoder
+ decoder_layerdrop: float = field(
+ default=0.0,
+ metadata={"help": "probability of dropping a decoder layer in hubert"},
+ )
+
+ add_decoder: bool = II("task.add_decoder")
+
+@dataclass
+class Speech2cCtcConfig(Speech2cAsrConfig):
+ pass
+
+
+@register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
+class Speech2cCtc(BaseFairseqModel):
+ def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
+ super().__init__()
+ self.cfg = cfg
+ self.w2v_encoder = w2v_encoder
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ @classmethod
+ def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
+ """Build a new model instance."""
+ w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
+ return cls(cfg, w2v_encoder)
+
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
+ """Get normalized probabilities (or log probs) from a net's output."""
+ if "encoder_out" not in net_output:
+ return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
+
+ if "encoder_out_for_ctc" in net_output:
+ logits = net_output["encoder_out_for_ctc"]
+ else:
+ logits = net_output["encoder_out"]
+
+ if isinstance(logits, list):
+ logits = logits[0]
+
+ if log_probs:
+ return utils.log_softmax(logits.float(), dim=-1)
+ else:
+ return utils.softmax(logits.float(), dim=-1)
+
+ def get_logits(self, net_output):
+ logits = net_output["encoder_out"]
+ padding = net_output["encoder_padding_mask"]
+ if padding is not None and padding.any():
+            padding = padding.T
+            # For padded timesteps keep only the blank (index 0) logit and
+            # suppress the rest. Chained indexing such as
+            # ``logits[padding][..., 0] = 0`` writes into a copy, so the fill
+            # vector is assigned in a single indexing operation instead.
+            fill = logits.new_full((logits.size(-1),), float("-inf"))
+            fill[0] = 0
+            logits[padding] = fill
+
+ return logits
+
+ def forward(self, **kwargs):
+ x = self.w2v_encoder(**kwargs)
+ return x
+
+ @property
+ def encoder(self):
+ return self.w2v_encoder
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ return self.encoder.reorder_encoder_out(encoder_out, new_order)
+
+ @property
+ def decoder(self):
+ return self.w2v_encoder.w2v_model.decoder
+
+
+class Speech2cEncoder(FairseqEncoder):
+ def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
+ self.apply_mask = cfg.apply_mask
+
+ arg_overrides = {
+ "dropout": cfg.dropout,
+ "activation_dropout": cfg.activation_dropout,
+ "dropout_input": cfg.dropout_input,
+ "attention_dropout": cfg.attention_dropout,
+ "mask_length": cfg.mask_length,
+ "mask_prob": cfg.mask_prob,
+ "mask_selection": cfg.mask_selection,
+ "mask_other": cfg.mask_other,
+ "no_mask_overlap": cfg.no_mask_overlap,
+ "mask_channel_length": cfg.mask_channel_length,
+ "mask_channel_prob": cfg.mask_channel_prob,
+ "mask_channel_selection": cfg.mask_channel_selection,
+ "mask_channel_other": cfg.mask_channel_other,
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+ "encoder_layerdrop": cfg.layerdrop,
+ "decoder_layerdrop": cfg.decoder_layerdrop,
+ "feature_grad_mult": cfg.feature_grad_mult,
+ "decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
+ }
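+        # These overrides replace the corresponding pre-training hyper-parameters
+        # when the checkpoint config is loaded below, so fine-tuning can use its
+        # own dropout/masking/layerdrop settings.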
+
+ if cfg.w2v_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
+ w2v_args = state.get("cfg", None)
+ if w2v_args is None:
+ w2v_args = convert_namespace_to_omegaconf(state["args"])
+ cfg.w2v_args = w2v_args
+ else:
+ state = None
+ w2v_args = cfg.w2v_args
+ if isinstance(w2v_args, Namespace):
+ cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
+
+ assert cfg.normalize == w2v_args.task.normalize, (
+ "Fine-tuning works best when data normalization is the same. "
+ "Please check that --normalize is set or unset for "
+ "both pre-training and here"
+ )
+
+ w2v_args.task.data = cfg.data
+ w2v_args.task.add_decoder = cfg.add_decoder
+ task = tasks.setup_task(w2v_args.task)
+ if state is not None and "task_state" in state:
+ # This will load the stored "dictionaries" object
+ task.load_state_dict(state["task_state"])
+ model = task.build_model(w2v_args.model)
+
+ if state is not None and not cfg.no_pretrained_weights:
+ if "decoder.embed_tokens.weight" in state["model"]:
+ del state["model"]["decoder.embed_tokens.weight"]
+ if "decoder.output_projection.weight" in state["model"]:
+ del state["model"]["decoder.output_projection.weight"]
+ # set strict=False because we omit some modules
+ model.load_state_dict(state["model"], strict=False)
+
+ model.remove_pretraining_modules()
+
+ super().__init__(task.source_dictionary)
+
+ d = model.mask_emb.size(0)
+
+ self.w2v_model = model
+
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
+ self.num_updates = 0
+
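+        # Output projection: map the encoder dimension d (taken from mask_emb) to
+        # the target vocabulary for CTC when tgt_dict is given, or to
+        # decoder_embed_dim when the dimensions differ; otherwise no projection.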
+ if tgt_dict is not None:
+ self.proj = Linear(d, len(tgt_dict))
+ elif getattr(cfg, "decoder_embed_dim", d) != d:
+ self.proj = Linear(d, cfg.decoder_embed_dim)
+ else:
+ self.proj = None
+
+ def set_num_updates(self, num_updates):
+ """Set the number of parameters updates."""
+ super().set_num_updates(num_updates)
+ self.num_updates = num_updates
+
+ def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
+
+ ft = self.freeze_finetune_updates <= self.num_updates
+ w2v_args = {
+ "source": source,
+ "padding_mask": padding_mask,
+ "mask": self.apply_mask and self.training,
+ "prev_output_tokens": prev_output_tokens,
+ "ft": ft,
+ }
+
+ x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
+
+ if tbc:
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ x = self.final_dropout(x)
+
+ if self.proj:
+ x = self.proj(x)
+
+ return {
+ "encoder_out": x, # T x B x C
+ "encoder_padding_mask": padding_mask, # B x T
+ "padding_mask": padding_mask,
+ "decoder_out": decoder_out,
+ }
+
+ def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
+ # net_output['encoder_out'] is a (B, T, D) tensor
+ return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ if encoder_out["encoder_out"] is not None:
+ if isinstance(encoder_out["encoder_out"], list):
+ encoder_out["encoder_out"] = (
+ [] if len(encoder_out["encoder_out"]) == 0
+ else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
+ )
+ else:
+ encoder_out["encoder_out"] = encoder_out[
+ "encoder_out"
+ ].index_select(1, new_order)
+ if encoder_out["encoder_padding_mask"] is not None:
+ if isinstance(encoder_out["encoder_padding_mask"], list):
+ encoder_out["encoder_padding_mask"] = (
+ [] if len(encoder_out["encoder_padding_mask"]) == 0
+ else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
+ )
+ else:
+ encoder_out["encoder_padding_mask"] = encoder_out[
+ "encoder_padding_mask"
+ ].index_select(0, new_order)
+ if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
+ if isinstance(encoder_out["decoder_out"], list):
+ encoder_out["decoder_out"] = (
+ [] if len(encoder_out["decoder_out"]) == 0
+ else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
+ )
+ else:
+ encoder_out["decoder_out"] = encoder_out[
+ "decoder_out"
+ ].index_select(0, new_order)
+ if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
+ if isinstance(encoder_out["encoder_out_for_ctc"], list):
+ encoder_out["encoder_out_for_ctc"] = (
+ [] if len(encoder_out["encoder_out_for_ctc"]) == 0
+ else [x.index_select(1, new_order) for x in encoder_out["encoder_out_for_ctc"]]
+ )
+ else:
+ encoder_out["encoder_out_for_ctc"] = encoder_out[
+ "encoder_out_for_ctc"
+ ].index_select(1, new_order)
+
+ return encoder_out
+
+ def forward_torchscript(self, net_input):
+ """A TorchScript-compatible version of forward.
+ Encoders which use additional arguments may want to override
+ this method for TorchScript compatibility.
+ """
+ encoder_out = self.w2v_model.forward_torchscript(net_input)
+
+ assert self.proj is not None
+ encoder_out['encoder_out_for_ctc'] = [self.proj(encoder_out['encoder_out'][0])]
+
+ return encoder_out
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return None
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ return state_dict
+
diff --git a/SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py b/SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d16a2df00b692114f8d84d254cf486d09e1137b
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py
@@ -0,0 +1,25 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+from fairseq.models import (
+ register_model_architecture,
+)
+from fairseq.models.transformer_lm import base_lm_architecture
+
+
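+# Registering this architecture makes it selectable by name, e.g. (assuming a
+# standard fairseq setup) `fairseq-train ... --task language_modeling
+# --arch transformer_lm_t5`, which is how the external LM used for shallow
+# fusion in the sequence generator would typically be trained.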
+@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
+def transformer_lm_t5(args):
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
+ args.decoder_layers = getattr(args, "decoder_layers", 20)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ base_lm_architecture(args)
diff --git a/SpeechT5/Speech2C/speech2c/squence_generator.py b/SpeechT5/Speech2C/speech2c/squence_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51e8021fe9e4e48619340412df012937db54198
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/squence_generator.py
@@ -0,0 +1,1028 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+from typing import Dict, List, Optional
+import sys
+
+import torch
+import torch.nn as nn
+from fairseq import search, utils
+from fairseq.data import data_utils
+from fairseq.models import FairseqIncrementalDecoder
+from torch import Tensor
+from fairseq.ngram_repeat_block import NGramRepeatBlock
+from speech2c.models.modules.ctc_prefix_score import CTCPrefixScore
+import numpy
+
+
+CTC_SCORING_RATIO = 7.0
+
+class SequenceGenerator(nn.Module):
+ def __init__(
+ self,
+ models,
+ tgt_dict,
+ beam_size=1,
+ max_len_a=0,
+ max_len_b=200,
+ max_len=0,
+ min_len=1,
+ normalize_scores=True,
+ len_penalty=1.0,
+ unk_penalty=0.0,
+ temperature=1.0,
+ match_source_len=False,
+ no_repeat_ngram_size=0,
+ search_strategy=None,
+ eos=None,
+ symbols_to_strip_from_output=None,
+ lm_model=None,
+ lm_weight=1.0,
+ ctc_weight=0.0,
+ ):
+ """Generates translations of a given source sentence.
+ Args:
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
+                currently supports fairseq.models.TransformerModel for scripting
+ beam_size (int, optional): beam width (default: 1)
+ max_len_a/b (int, optional): generate sequences of maximum length
+ ax + b, where x is the source length
+ max_len (int, optional): the maximum length of the generated output
+ (not including end-of-sentence)
+ min_len (int, optional): the minimum length of the generated output
+ (not including end-of-sentence)
+ normalize_scores (bool, optional): normalize scores by the length
+ of the output (default: True)
+ len_penalty (float, optional): length penalty, where <1.0 favors
+ shorter, >1.0 favors longer sentences (default: 1.0)
+ unk_penalty (float, optional): unknown word penalty, where <0
+ produces more unks, >0 produces fewer (default: 0.0)
+ temperature (float, optional): temperature, where values
+ >1.0 produce more uniform samples and values <1.0 produce
+ sharper samples (default: 1.0)
+ match_source_len (bool, optional): outputs should match the source
+ length (default: False)
+ """
+ super().__init__()
+ if isinstance(models, EnsembleModel):
+ self.model = models
+ else:
+ self.model = EnsembleModel(models)
+ self.tgt_dict = tgt_dict
+ self.pad = tgt_dict.pad()
+ self.unk = tgt_dict.unk()
+ self.eos = tgt_dict.eos() if eos is None else eos
+ self.blank = self.tgt_dict.index("")
+ self.symbols_to_strip_from_output = (
+ symbols_to_strip_from_output.union({self.eos})
+ if symbols_to_strip_from_output is not None
+ else {self.eos}
+ )
+ self.vocab_size = len(tgt_dict)
+ self.beam_size = beam_size
+ # the max beam size is the dictionary size - 1, since we never select pad
+ self.beam_size = min(beam_size, self.vocab_size - 1)
+ self.max_len_a = max_len_a
+ self.max_len_b = max_len_b
+ self.min_len = min_len
+ self.max_len = max_len or self.model.max_decoder_positions()
+
+ self.normalize_scores = normalize_scores
+ self.len_penalty = len_penalty
+ self.unk_penalty = unk_penalty
+ self.temperature = temperature
+ self.match_source_len = match_source_len
+
+ if no_repeat_ngram_size > 0:
+ self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
+ else:
+ self.repeat_ngram_blocker = None
+
+ assert temperature > 0, "--temperature must be greater than 0"
+
+ self.search = (
+ search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
+ )
+ # We only need to set src_lengths in LengthConstrainedBeamSearch.
+ # As a module attribute, setting it would break in multithread
+ # settings when the model is shared.
+ self.should_set_src_lengths = (
+ hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
+ )
+
+ self.model.eval()
+
+ self.lm_model = lm_model
+ self.lm_weight = lm_weight
+ self.ctc_weight = ctc_weight
+ if self.lm_model is not None:
+ self.lm_model.eval()
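+        # lm_weight applies shallow fusion with an external language model, while
+        # ctc_weight mixes in CTC prefix scores from the encoder during beam
+        # search (hybrid CTC/attention decoding); both are used in _generate below.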
+
+ def cuda(self):
+ self.model.cuda()
+ return self
+
+ @torch.no_grad()
+ def forward(
+ self,
+ sample: Dict[str, Dict[str, Tensor]],
+ prefix_tokens: Optional[Tensor] = None,
+ bos_token: Optional[int] = None,
+ ):
+ """Generate a batch of translations.
+ Args:
+ sample (dict): batch
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
+ with these tokens
+ bos_token (int, optional): beginning of sentence token
+ (default: self.eos)
+ """
+ return self._generate(sample, prefix_tokens, bos_token=bos_token)
+
+ # TODO(myleott): unused, deprecate after pytorch-translate migration
+ def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
+ """Iterate over a batched dataset and yield individual translations.
+ Args:
+ cuda (bool, optional): use GPU for generation
+ timer (StopwatchMeter, optional): time generations
+ """
+ for sample in data_itr:
+ s = utils.move_to_cuda(sample) if cuda else sample
+ if "net_input" not in s:
+ continue
+ input = s["net_input"]
+ # model.forward normally channels prev_output_tokens into the decoder
+ # separately, but SequenceGenerator directly calls model.encoder
+ encoder_input = {
+ k: v for k, v in input.items() if k != "prev_output_tokens"
+ }
+ if timer is not None:
+ timer.start()
+ with torch.no_grad():
+ hypos = self.generate(encoder_input)
+ if timer is not None:
+ timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
+ for i, id in enumerate(s["id"].data):
+ # remove padding
+ src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
+ ref = (
+ utils.strip_pad(s["target"].data[i, :], self.pad)
+ if s["target"] is not None
+ else None
+ )
+ yield id, src, ref, hypos[i]
+
+ @torch.no_grad()
+ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
+        """Generate translations. Match the API of other fairseq generators.
+ Args:
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
+ sample (dict): batch
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
+ with these tokens
+ constraints (torch.LongTensor, optional): force decoder to include
+ the list of constraints
+ bos_token (int, optional): beginning of sentence token
+ (default: self.eos)
+ """
+ return self._generate(sample, **kwargs)
+
+ def _generate(
+ self,
+ sample: Dict[str, Dict[str, Tensor]],
+ prefix_tokens: Optional[Tensor] = None,
+ constraints: Optional[Tensor] = None,
+ bos_token: Optional[int] = None,
+ ):
+ incremental_states = torch.jit.annotate(
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
+ [
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
+ for i in range(self.model.models_size)
+ ],
+ )
+ net_input = sample["net_input"]
+
+ if "src_tokens" in net_input:
+ src_tokens = net_input["src_tokens"]
+            # length of the source text, excluding end-of-sentence and padding tokens
+ src_lengths = (
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
+ )
+ elif "source" in net_input:
+ src_tokens = net_input["source"]
+ src_lengths = (
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
+ if net_input["padding_mask"] is not None
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
+ )
+ elif "features" in net_input:
+ src_tokens = net_input["features"]
+ src_lengths = (
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
+ if net_input["padding_mask"] is not None
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
+ )
+ else:
+ raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
+
+ # bsz: total number of sentences in beam
+ # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
+ bsz, src_len = src_tokens.size()[:2]
+ beam_size = self.beam_size
+
+ if constraints is not None and not self.search.supports_constraints:
+ raise NotImplementedError(
+ "Target-side constraints were provided, but search method doesn't support them"
+ )
+
+ # Initialize constraints, when active
+ self.search.init_constraints(constraints, beam_size)
+
+ max_len: int = -1
+ if self.match_source_len:
+ max_len = src_lengths.max().item()
+ else:
+ max_len = min(
+ int(self.max_len_a * src_len + self.max_len_b),
+ self.max_len - 1,
+ )
+ assert (
+ self.min_len <= max_len
+ ), "min_len cannot be larger than max_len, please adjust these!"
+ # compute the encoder output for each beam
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
+ encoder_outs = self.model.forward_encoder(net_input)
+
+ # Get CTC lprobs and prep ctc_scorer
+ if self.ctc_weight > 0:
+ ctc_lprobs = self.model.models[0].get_normalized_probs(
+ encoder_outs[0], log_probs=True
+ ).contiguous().transpose(0, 1) # (B, T, C) from the encoder
+
+ hyp = {}
+ ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
+ hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
+ hyp["ctc_score_prev"] = 0.0
+ ctc_beam = min(ctc_lprobs.shape[-1], int(beam_size * CTC_SCORING_RATIO))
+ ctc_hyps = {str(self.eos): hyp}
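+            # The CTC prefix scorer is built from the encoder's CTC
+            # log-probabilities (note: from ctc_lprobs[0], i.e. the first
+            # utterance in the batch). At each decoding step only the ctc_beam
+            # best tokens (beam_size * CTC_SCORING_RATIO, capped at the
+            # vocabulary size) are rescored, and prefix states are cached in
+            # ctc_hyps keyed by the token history.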
+
+        # placeholder of indices for bsz * beam_size to hold tokens and cumulative scores
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+ new_order = new_order.to(src_tokens.device).long()
+ encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
+ # ensure encoder_outs is a List.
+ assert encoder_outs is not None
+
+ # initialize buffers
+ scores = (
+ torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
+ ) # +1 for eos; pad is never chosen for scoring
+ tokens = (
+ torch.zeros(bsz * beam_size, max_len + 2)
+ .to(src_tokens)
+ .long()
+ .fill_(self.pad)
+ ) # +2 for eos and pad
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
+ attn: Optional[Tensor] = None
+
+ # A list that indicates candidates that should be ignored.
+ # For example, suppose we're sampling and have already finalized 2/5
+ # samples. Then cands_to_ignore would mark 2 positions as being ignored,
+ # so that we only finalize the remaining 3 samples.
+ cands_to_ignore = (
+ torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
+ ) # forward and backward-compatible False mask
+
+ # list of completed sentences
+ finalized = torch.jit.annotate(
+ List[List[Dict[str, Tensor]]],
+ [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
+        )  # contains lists of dictionaries of information about the hypothesis being finalized at each step
+
+ # a boolean array indicating if the sentence at the index is finished or not
+ finished = [False for i in range(bsz)]
+ num_remaining_sent = bsz # number of sentences remaining
+
+ # number of candidate hypos per step
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
+
+ # offset arrays for converting between different indexing schemes
+ bbsz_offsets = (
+ (torch.arange(0, bsz) * beam_size)
+ .unsqueeze(1)
+ .type_as(tokens)
+ .to(src_tokens.device)
+ )
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
+
+ reorder_state: Optional[Tensor] = None
+ batch_idxs: Optional[Tensor] = None
+
+ original_batch_idxs: Optional[Tensor] = None
+ if "id" in sample and isinstance(sample["id"], Tensor):
+ original_batch_idxs = sample["id"]
+ else:
+ original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
+
+ for step in range(max_len + 1): # one extra step for EOS marker
+ # reorder decoder internal states based on the prev choice of beams
+ if reorder_state is not None:
+ if batch_idxs is not None:
+ # update beam indices to take into account removed sentences
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
+ batch_idxs
+ )
+ reorder_state.view(-1, beam_size).add_(
+ corr.unsqueeze(-1) * beam_size
+ )
+ original_batch_idxs = original_batch_idxs[batch_idxs]
+ self.model.reorder_incremental_state(incremental_states, reorder_state)
+ encoder_outs = self.model.reorder_encoder_out(
+ encoder_outs, reorder_state
+ )
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
+ lprobs, avg_attn_scores = self.model.forward_decoder(
+ tokens[:, : step + 1],
+ encoder_outs,
+ incremental_states,
+ self.temperature,
+ )
+
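+            # Joint scoring: for the ctc_beam best candidate tokens, the decoder
+            # log-probability is interpolated with the incremental CTC prefix
+            # score, i.e. (1 - ctc_weight) * lprobs + ctc_weight * (ctc_score - ctc_score_prev).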
+ if self.ctc_weight > 0 and step != 0:
+ # lprobs[:, self.blank] = -math.inf # never select blank
+ ctc_lprobs = lprobs.clone()
+ ctc_lprobs[:, self.blank] = -math.inf # never select blank
+ _, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
+ for b in range(tokens.size(0)):
+ hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
+ ctc_scores, ctc_states = ctc_prefix_score(
+ tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
+ )
+ lprobs[b] = lprobs[b]
+ lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
+ ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
+ ).to(device="cuda")
+ for j in range(len(local_best_ids[b])):
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
+
+ elif self.ctc_weight > 0 and step == 0:
+ ctc_lprobs = lprobs.clone()
+ ctc_lprobs[:, self.blank] = -math.inf # never select blank
+ _, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
+ for b in range(tokens.size(0)):
+ hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
+ ctc_scores, ctc_states = ctc_prefix_score(
+ tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
+ )
+ lprobs[b] = lprobs[b]
+ lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
+ ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
+ ).to(device="cuda")
+ for j in range(len(local_best_ids[b])):
+ if b == 0:
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
+
+ if self.lm_model is not None:
+ lm_out = self.lm_model(tokens[:, : step + 1])
+ probs = self.lm_model.get_normalized_probs(
+ lm_out, log_probs=True, sample=None
+ )
+ probs = probs[:, -1, :] * self.lm_weight
+ lprobs += probs
+ # handle prefix tokens (possibly with different lengths)
+ if (
+ prefix_tokens is not None
+ and step < prefix_tokens.size(1)
+ and step < max_len
+ ):
+ lprobs, tokens, scores = self._prefix_tokens(
+ step, lprobs, scores, tokens, prefix_tokens, beam_size
+ )
+ elif step < self.min_len:
+ # minimum length constraint (does not apply if using prefix_tokens)
+ lprobs[:, self.eos] = -math.inf
+
+ lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
+
+ lprobs[:, self.pad] = -math.inf # never select pad
+ lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
+ lprobs[:, self.blank] = -math.inf # never select blank
+
+ # handle max length constraint
+ if step >= max_len:
+ lprobs[:, : self.eos] = -math.inf
+ lprobs[:, self.eos + 1 :] = -math.inf
+
+ # Record attention scores, only support avg_attn_scores is a Tensor
+ if avg_attn_scores is not None:
+ if attn is None:
+ attn = torch.empty(
+ bsz * beam_size, avg_attn_scores.size(1), max_len + 2
+ ).to(scores)
+ attn[:, :, step + 1].copy_(avg_attn_scores)
+
+ scores = scores.type_as(lprobs)
+ eos_bbsz_idx = torch.empty(0).to(
+ tokens
+ ) # indices of hypothesis ending with eos (finished sentences)
+ eos_scores = torch.empty(0).to(
+ scores
+ ) # scores of hypothesis ending with eos (finished sentences)
+
+ if self.should_set_src_lengths:
+ self.search.set_src_lengths(src_lengths)
+
+ if self.repeat_ngram_blocker is not None:
+ lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
+
+ # Shape: (batch, cand_size)
+ cand_scores, cand_indices, cand_beams = self.search.step(
+ step,
+ lprobs.view(bsz, -1, self.vocab_size),
+ scores.view(bsz, beam_size, -1)[:, :, :step],
+ tokens[:, : step + 1],
+ original_batch_idxs,
+ )
+
+ # cand_bbsz_idx contains beam indices for the top candidate
+ # hypotheses, with a range of values: [0, bsz*beam_size),
+ # and dimensions: [bsz, cand_size]
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+
+ # finalize hypotheses that end in eos
+ # Shape of eos_mask: (batch size, beam size)
+ eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
+ eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
+
+ # only consider eos when it's among the top beam_size indices
+ # Now we know what beam item(s) to finish
+ # Shape: 1d list of absolute-numbered
+ eos_bbsz_idx = torch.masked_select(
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
+ )
+
+ finalized_sents: List[int] = []
+ if eos_bbsz_idx.numel() > 0:
+ eos_scores = torch.masked_select(
+ cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
+ )
+
+ finalized_sents = self.finalize_hypos(
+ step,
+ eos_bbsz_idx,
+ eos_scores,
+ tokens,
+ scores,
+ finalized,
+ finished,
+ beam_size,
+ attn,
+ src_lengths,
+ max_len,
+ )
+ num_remaining_sent -= len(finalized_sents)
+
+ assert num_remaining_sent >= 0
+ if num_remaining_sent == 0:
+ break
+ if self.search.stop_on_max_len and step >= max_len:
+ break
+            assert step < max_len, f"{step} < {max_len}"
+
+ # Remove finalized sentences (ones for which {beam_size}
+ # finished hypotheses have been generated) from the batch.
+ if len(finalized_sents) > 0:
+ new_bsz = bsz - len(finalized_sents)
+
+ # construct batch_idxs which holds indices of batches to keep for the next pass
+ batch_mask = torch.ones(
+ bsz, dtype=torch.bool, device=cand_indices.device
+ )
+ batch_mask[finalized_sents] = False
+ # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
+ batch_idxs = torch.arange(
+ bsz, device=cand_indices.device
+ ).masked_select(batch_mask)
+
+ # Choose the subset of the hypothesized constraints that will continue
+ self.search.prune_sentences(batch_idxs)
+
+ eos_mask = eos_mask[batch_idxs]
+ cand_beams = cand_beams[batch_idxs]
+ bbsz_offsets.resize_(new_bsz, 1)
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+ cand_scores = cand_scores[batch_idxs]
+ cand_indices = cand_indices[batch_idxs]
+
+ if prefix_tokens is not None:
+ prefix_tokens = prefix_tokens[batch_idxs]
+ src_lengths = src_lengths[batch_idxs]
+ cands_to_ignore = cands_to_ignore[batch_idxs]
+
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+ if attn is not None:
+ attn = attn.view(bsz, -1)[batch_idxs].view(
+ new_bsz * beam_size, attn.size(1), -1
+ )
+ bsz = new_bsz
+ else:
+ batch_idxs = None
+
+ # Set active_mask so that values > cand_size indicate eos hypos
+ # and values < cand_size indicate candidate active hypos.
+ # After, the min values per row are the top candidate active hypos
+
+            # Rewrite the operator since element-wise "or" is not supported in TorchScript.
+
+ eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
+ active_mask = torch.add(
+ eos_mask.type_as(cand_offsets) * cand_size,
+ cand_offsets[: eos_mask.size(1)],
+ )
+
+ # get the top beam_size active hypotheses, which are just
+ # the hypos with the smallest values in active_mask.
+ # {active_hypos} indicates which {beam_size} hypotheses
+ # from the list of {2 * beam_size} candidates were
+ # selected. Shapes: (batch size, beam size)
+ new_cands_to_ignore, active_hypos = torch.topk(
+ active_mask, k=beam_size, dim=1, largest=False
+ )
+
+ # update cands_to_ignore to ignore any finalized hypos.
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
+ # Make sure there is at least one active item for each sentence in the batch.
+ assert (~cands_to_ignore).any(dim=1).all()
+
+ # update cands_to_ignore to ignore any finalized hypos
+
+ # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
+ # can be selected more than once).
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
+ active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
+
+ active_bbsz_idx = active_bbsz_idx.view(-1)
+ active_scores = active_scores.view(-1)
+
+ # copy tokens and scores for active hypotheses
+
+ # Set the tokens for each beam (can select the same row more than once)
+ tokens[:, : step + 1] = torch.index_select(
+ tokens[:, : step + 1], dim=0, index=active_bbsz_idx
+ )
+ # Select the next token for each of them
+ tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
+ cand_indices, dim=1, index=active_hypos
+ )
+ if step > 0:
+ scores[:, :step] = torch.index_select(
+ scores[:, :step], dim=0, index=active_bbsz_idx
+ )
+ scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
+ cand_scores, dim=1, index=active_hypos
+ )
+
+ # Update constraints based on which candidates were selected for the next beam
+ self.search.update_constraints(active_hypos)
+
+ # copy attention for active hypotheses
+ if attn is not None:
+ attn[:, :, : step + 2] = torch.index_select(
+ attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
+ )
+
+ # reorder incremental state in decoder
+ reorder_state = active_bbsz_idx
+
+ # sort by score descending
+ for sent in range(len(finalized)):
+ scores = torch.tensor(
+ [float(elem["score"].item()) for elem in finalized[sent]]
+ )
+ _, sorted_scores_indices = torch.sort(scores, descending=True)
+ finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
+ finalized[sent] = torch.jit.annotate(
+ List[Dict[str, Tensor]], finalized[sent]
+ )
+ return finalized
+
+ def _prefix_tokens(
+ self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
+ ):
+ """Handle prefix tokens"""
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
+ prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
+ prefix_mask = prefix_toks.ne(self.pad)
+ lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
+ lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
+ )
+ # if prefix includes eos, then we should make sure tokens and
+ # scores are the same across all beams
+ eos_mask = prefix_toks.eq(self.eos)
+ if eos_mask.any():
+ # validate that the first beam matches the prefix
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
+ :, 0, 1 : step + 1
+ ]
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
+ assert (first_beam == target_prefix).all()
+
+ # copy tokens, scores and lprobs from the first beam to all beams
+ tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
+ scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
+ lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
+ return lprobs, tokens, scores
+
+ def replicate_first_beam(self, tensor, mask, beam_size: int):
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
+ tensor[mask] = tensor[mask][:, :1, :]
+ return tensor.view(-1, tensor.size(-1))
+
+ def finalize_hypos(
+ self,
+ step: int,
+ bbsz_idx,
+ eos_scores,
+ tokens,
+ scores,
+ finalized: List[List[Dict[str, Tensor]]],
+ finished: List[bool],
+ beam_size: int,
+ attn: Optional[Tensor],
+ src_lengths,
+ max_len: int,
+ ):
+ """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
+ A sentence is finalized when {beam_size} finished items have been collected for it.
+ Returns number of sentences (not beam items) being finalized.
+ These will be removed from the batch and not processed further.
+ Args:
+ bbsz_idx (Tensor):
+ """
+ assert bbsz_idx.numel() == eos_scores.numel()
+
+ # clone relevant token and attention tensors.
+ # tokens is (batch * beam, max_len). So the index_select
+ # gets the newly EOS rows, then selects cols 1..{step + 2}
+ tokens_clone = tokens.index_select(0, bbsz_idx)[
+ :, 1 : step + 2
+ ] # skip the first index, which is EOS
+
+ tokens_clone[:, step] = self.eos
+ attn_clone = (
+ attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
+ if attn is not None
+ else None
+ )
+
+ # compute scores per token position
+ pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
+ pos_scores[:, step] = eos_scores
+ # convert from cumulative to per-position scores
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
+
+ # normalize sentence-level scores
+ if self.normalize_scores:
+ eos_scores /= (step + 1) ** self.len_penalty
+
+ # cum_unfin records which sentences in the batch are finished.
+ # It helps match indexing between (a) the original sentences
+ # in the batch and (b) the current, possibly-reduced set of
+ # sentences.
+ cum_unfin: List[int] = []
+ prev = 0
+ for f in finished:
+ if f:
+ prev += 1
+ else:
+ cum_unfin.append(prev)
+ cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
+
+ unfin_idx = bbsz_idx // beam_size
+ sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
+
+        # Build a unique key (sent << 32) + unfin_idx for every finished beam
+        # item, where "unfin_idx" is the sentence index in the current
+        # (possibly reduced) batch and "sent" is the index in the original,
+        # unreduced batch.
+ seen = (sent << 32) + unfin_idx
+ unique_seen: List[int] = torch.unique(seen).tolist()
+
+ if self.match_source_len:
+ condition = step > torch.index_select(src_lengths, 0, unfin_idx)
+ eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
+ sent_list: List[int] = sent.tolist()
+ for i in range(bbsz_idx.size()[0]):
+ # An input sentence (among those in a batch) is finished when
+ # beam_size hypotheses have been collected for it
+ if len(finalized[sent_list[i]]) < beam_size:
+ if attn_clone is not None:
+ # remove padding tokens from attn scores
+ hypo_attn = attn_clone[i]
+ else:
+ hypo_attn = torch.empty(0)
+
+ finalized[sent_list[i]].append(
+ {
+ "tokens": tokens_clone[i],
+ "score": eos_scores[i],
+ "attention": hypo_attn, # src_len x tgt_len
+ "alignment": torch.empty(0),
+ "positional_scores": pos_scores[i],
+ }
+ )
+
+ newly_finished: List[int] = []
+ for unique_s in unique_seen:
+ # check termination conditions for this sentence
+ unique_sent: int = unique_s >> 32
+ unique_unfin_idx: int = unique_s - (unique_sent << 32)
+
+ if not finished[unique_sent] and self.is_finished(
+ step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
+ ):
+ finished[unique_sent] = True
+ newly_finished.append(unique_unfin_idx)
+
+ return newly_finished
+
+ def is_finished(
+ self,
+ step: int,
+ unfin_idx: int,
+ max_len: int,
+ finalized_sent_len: int,
+ beam_size: int,
+ ):
+ """
+ Check whether decoding for a sentence is finished, which
+ occurs when the list of finalized sentences has reached the
+ beam size, or when we reach the maximum length.
+ """
+ assert finalized_sent_len <= beam_size
+ if finalized_sent_len == beam_size or step == max_len:
+ return True
+ return False
+
+
+class EnsembleModel(nn.Module):
+ """A wrapper around an ensemble of models."""
+
+ def __init__(self, models):
+ super().__init__()
+ self.models_size = len(models)
+ # method '__len__' is not supported in ModuleList for torch script
+ self.single_model = models[0]
+ self.models = nn.ModuleList(models)
+
+ self.has_incremental: bool = False
+ if all(
+ hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
+ for m in models
+ ):
+ self.has_incremental = True
+
+ def forward(self):
+ pass
+
+ def has_encoder(self):
+ return hasattr(self.single_model, "encoder")
+
+ def has_incremental_states(self):
+ return self.has_incremental
+
+ def max_decoder_positions(self):
+ return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
+
+ @torch.jit.export
+ def forward_encoder(self, net_input: Dict[str, Tensor]):
+ if not self.has_encoder():
+ return None
+ return [model.encoder.forward_torchscript(net_input) for model in self.models]
+
+ @torch.jit.export
+ def forward_decoder(
+ self,
+ tokens,
+ encoder_outs: List[Dict[str, List[Tensor]]],
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
+ temperature: float = 1.0,
+ ):
+ log_probs = []
+ avg_attn: Optional[Tensor] = None
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None
+ for i, model in enumerate(self.models):
+ if self.has_encoder():
+ encoder_out = encoder_outs[i]
+ # decode each model
+ if self.has_incremental_states():
+ decoder_out = model.decoder.forward(
+ tokens,
+ encoder_out=encoder_out,
+ incremental_state=incremental_states[i],
+ )
+ else:
+ if hasattr(model, "decoder"):
+ decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
+ else:
+ decoder_out = model.forward(tokens)
+
+ attn: Optional[Tensor] = None
+ decoder_len = len(decoder_out)
+ if decoder_len > 1 and decoder_out[1] is not None:
+ if isinstance(decoder_out[1], Tensor):
+ attn = decoder_out[1]
+ else:
+ attn_holder = decoder_out[1]["attn"]
+ if isinstance(attn_holder, Tensor):
+ attn = attn_holder
+ elif attn_holder is not None:
+ attn = attn_holder[0]
+ if attn is not None:
+ attn = attn[:, -1, :]
+
+ decoder_out_tuple = (
+ decoder_out[0][:, -1:, :].div_(temperature),
+ None if decoder_len <= 1 else decoder_out[1],
+ )
+ probs = model.get_normalized_probs(
+ decoder_out_tuple, log_probs=True, sample=None
+ )
+ probs = probs[:, -1, :]
+ if self.models_size == 1:
+ return probs, attn
+
+ log_probs.append(probs)
+ if attn is not None:
+ if avg_attn is None:
+ avg_attn = attn
+ else:
+ avg_attn.add_(attn)
+
+ avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
+ self.models_size
+ )
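+        # Ensemble average in probability space: logsumexp over models minus
+        # log(N) equals log((1/N) * sum_i p_i).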
+
+ if avg_attn is not None:
+ avg_attn.div_(self.models_size)
+ return avg_probs, avg_attn
+
+ @torch.jit.export
+ def reorder_encoder_out(
+ self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
+ ):
+ """
+ Reorder encoder output according to *new_order*.
+ Args:
+ encoder_out: output from the ``forward()`` method
+ new_order (LongTensor): desired order
+ Returns:
+ *encoder_out* rearranged according to *new_order*
+ """
+ new_outs: List[Dict[str, List[Tensor]]] = []
+ if not self.has_encoder():
+ return new_outs
+ for i, model in enumerate(self.models):
+ assert encoder_outs is not None
+ new_outs.append(
+ model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
+ )
+ return new_outs
+
+ @torch.jit.export
+ def reorder_incremental_state(
+ self,
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
+ new_order,
+ ):
+ if not self.has_incremental_states():
+ return
+ for i, model in enumerate(self.models):
+ model.decoder.reorder_incremental_state_scripting(
+ incremental_states[i], new_order
+ )
+
+
+class SequenceGeneratorWithAlignment(SequenceGenerator):
+ def __init__(
+ self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
+ ):
+ """Generates translations of a given source sentence.
+ Produces alignments following "Jointly Learning to Align and
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
+ Args:
+ left_pad_target (bool, optional): Whether or not the
+ hypothesis should be left padded or not when they are
+ teacher forced for generating alignments.
+ """
+ super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
+ self.left_pad_target = left_pad_target
+
+ if print_alignment == "hard":
+ self.extract_alignment = utils.extract_hard_alignment
+ elif print_alignment == "soft":
+ self.extract_alignment = utils.extract_soft_alignment
+
+ @torch.no_grad()
+ def generate(self, models, sample, **kwargs):
+ finalized = super()._generate(sample, **kwargs)
+
+ src_tokens = sample["net_input"]["src_tokens"]
+ bsz = src_tokens.shape[0]
+ beam_size = self.beam_size
+ (
+ src_tokens,
+ src_lengths,
+ prev_output_tokens,
+ tgt_tokens,
+ ) = self._prepare_batch_for_alignment(sample, finalized)
+ if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
+ attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
+ else:
+ attn = [
+ finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
+ for i in range(bsz * beam_size)
+ ]
+
+ if src_tokens.device != "cpu":
+ src_tokens = src_tokens.to("cpu")
+ tgt_tokens = tgt_tokens.to("cpu")
+ attn = [i.to("cpu") for i in attn]
+
+ # Process the attn matrix to extract hard alignments.
+ for i in range(bsz * beam_size):
+ alignment = self.extract_alignment(
+ attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
+ )
+ finalized[i // beam_size][i % beam_size]["alignment"] = alignment
+ return finalized
+
+ def _prepare_batch_for_alignment(self, sample, hypothesis):
+ src_tokens = sample["net_input"]["src_tokens"]
+ bsz = src_tokens.shape[0]
+ src_tokens = (
+ src_tokens[:, None, :]
+ .expand(-1, self.beam_size, -1)
+ .contiguous()
+ .view(bsz * self.beam_size, -1)
+ )
+ src_lengths = sample["net_input"]["src_lengths"]
+ src_lengths = (
+ src_lengths[:, None]
+ .expand(-1, self.beam_size)
+ .contiguous()
+ .view(bsz * self.beam_size)
+ )
+ prev_output_tokens = data_utils.collate_tokens(
+ [beam["tokens"] for example in hypothesis for beam in example],
+ self.pad,
+ self.eos,
+ self.left_pad_target,
+ move_eos_to_beginning=True,
+ )
+ tgt_tokens = data_utils.collate_tokens(
+ [beam["tokens"] for example in hypothesis for beam in example],
+ self.pad,
+ self.eos,
+ self.left_pad_target,
+ move_eos_to_beginning=False,
+ )
+ return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
+
+
+class EnsembleModelWithAlignment(EnsembleModel):
+ """A wrapper around an ensemble of models."""
+
+ def __init__(self, models):
+ super().__init__(models)
+
+ def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
+ avg_attn = None
+ for model in self.models:
+ decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
+ attn = decoder_out[1]["attn"][0]
+ if avg_attn is None:
+ avg_attn = attn
+ else:
+ avg_attn.add_(attn)
+ if len(self.models) > 1:
+ avg_attn.div_(len(self.models))
+ return avg_attn
diff --git a/SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py b/SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..de275630bb08ad3ffae5120eee93d0c75d9ed8b0
--- /dev/null
+++ b/SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py
@@ -0,0 +1,91 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import logging
+
+from dataclasses import dataclass, field
+from fairseq.data import Dictionary
+from fairseq.tasks import register_task
+from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig, HubertPretrainingTask, LabelEncoder
+from speech2c.data.speech2c_dataset import Speech2cDataset
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Speech2cPretrainingConfig(HubertPretrainingConfig):
+ add_decoder: bool = field(
+ default=False,
+ metadata={"help": "whether to add decoder for CE Loss on code"},
+ )
+
+ # For inference
+ ctc_weight: float = field(
+ default=0.0,
+ metadata={"help": "ctc weight during inference"},
+ )
+
+
+@register_task("speech2c_pretraining", dataclass=Speech2cPretrainingConfig)
+class Speech2cPretrainingTask(HubertPretrainingTask):
+
+ cfg: Speech2cPretrainingConfig
+
+ def load_dictionaries(self):
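+        # In fine-tuning mode a single target dictionary is returned;
+        # during pre-training one dictionary per label type is kept.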
+ label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
+ dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels]
+ return dictionaries[0] if self.cfg.fine_tuning else dictionaries
+
+ def load_dataset(self, split: str, **kwargs) -> None:
+ manifest = f"{self.cfg.data}/{split}.tsv"
+ dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
+        pad_list = [d.pad() for d in dicts]
+        eos_list = [d.eos() for d in dicts]
+        procs = [LabelEncoder(d) for d in dicts]
+ paths = [
+ f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
+ ]
+
+ # hubert v1: pad_audio=True, random_crop=False;
+ self.datasets[split] = Speech2cDataset(
+ manifest,
+ sample_rate=self.cfg.sample_rate,
+ label_paths=paths,
+ label_rates=self.cfg.label_rate,
+ pad_list=pad_list,
+ eos_list=eos_list,
+ label_processors=procs,
+ max_keep_sample_size=self.cfg.max_keep_size,
+ min_keep_sample_size=self.cfg.min_sample_size,
+ max_sample_size=self.cfg.max_sample_size,
+ pad_audio=self.cfg.pad_audio,
+ normalize=self.cfg.normalize,
+ store_labels=False,
+ random_crop=self.cfg.random_crop,
+ single_target=self.cfg.single_target,
+ tgt_dict=dicts[0],
+ add_decoder=self.cfg.add_decoder,
+ fine_tuning=self.cfg.fine_tuning,
+ )
+
+ def build_generator(
+ self,
+ models,
+ args,
+ seq_gen_cls=None,
+ extra_gen_cls_kwargs=None,
+ ):
+ from speech2c.squence_generator import SequenceGenerator
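+        # Pass ctc_weight through to Speech2C's SequenceGenerator, which uses
+        # it to weight CTC scores against decoder scores during inference
+        # (0.0 disables CTC scoring).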
+        extra_gen_cls_kwargs = {
+            "ctc_weight": self.cfg.ctc_weight,
+            **(extra_gen_cls_kwargs or {}),
+        }
+ return super().build_generator(
+ models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
+ )
diff --git a/SpeechT5/Speech2S/README.md b/SpeechT5/Speech2S/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fc827e237111d872dac19ce407b8d11e52a5ee44
--- /dev/null
+++ b/SpeechT5/Speech2S/README.md
@@ -0,0 +1,64 @@
+# Speech2S
+
+
+ [**Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**](https://arxiv.org/abs/2210.17027)
+
+
+- (Updating) Nov. 2022: release the code and models
+- Nov. 2022: release preprint in [arXiv](https://arxiv.org/abs/2210.17027)
+
+## Pre-Trained and Fine-tuned Models
+
+| Model | Pre-training Dataset | Fine-tuning Dataset | Checkpoint |
+| :------: | :----------------------------------------------: | :-----------------: | :-----: |
+| Speech2S_enes | Voxpopuli_en_v2 | - | [Google Drive](https://drive.google.com/file/d/1TYypFiEKoCixUro8FTTG23bRZYwAxhkX/view?usp=share_link) |
+| Speech2S_enes | Voxpopuli_en_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/11RxeKznSrHcoP_KK9A1VgwRt3fNh_U_C/view?usp=share_link) |
+| Speech2S_esen | Voxpopuli_es_v2 | - | [Google Drive](https://drive.google.com/file/d/1NoC7W-UtQZ-ugIptF1ex0ZlGJncsT1S4/view?usp=share_link) |
+| Speech2S_esen | Voxpopuli_es_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/1eNcKw4ZWGmcABWXJxlf6MKocmiPrKSkH/view?usp=share_link) |
+
+
+## Setup
+```
+cd Speech2S/speech2s
+pip install --editable fairseq/
+```
+
+## Data Preparation
+Please follow the S2ST data preparation steps described [here](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md).
+
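+As a reference, the audio manifests consumed by the dataloader in `speech2s/data/hubert_dataset.py` are TSV files whose first line is the audio root directory, followed by one `<relative path>\t<number of samples>` entry per utterance (the paths below are purely illustrative):
+
+```
+/path/to/audio_root
+train/utt_000001.flac	1234560
+train/utt_000002.flac	985600
+```
+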
+## Pre-Training
+```
+cd speech2s/stpretrain_scripts
+bash base_sc2c_enes.sh
+```
+## Finetune
+```
+cd speech2s/stpretrain_scripts
+bash finetune_enes.sh
+```
+## Inference
+```
+cd speech2s/stpretrain_scripts
+bash inference_ed.sh
+```
+## Results on VoxPopuli and CoVoST
+
+
+## License
+
+This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
+Portions of the source code are based on [FAIRSEQ](https://github.com/pytorch/fairseq).
+
+[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
+
+## Reference
+
+If you find our work useful in your research, please cite the following paper:
+```bibtex
+@article{wei2022joint,
+ title={Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation},
+ author={Wei, Kun and Zhou, Long and Zhang, Ziqiang and Chen, Liping and Liu, Shujie and He, Lei and Li, Jinyu and Wei, Furu},
+ journal={arXiv preprint arXiv:2210.17027},
+ year={2022}
+}
+```
diff --git a/SpeechT5/Speech2S/speech2s/__init__.py b/SpeechT5/Speech2S/speech2s/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97327d269e93a13cd135f6c1a187fd820a8decb8
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/__init__.py
@@ -0,0 +1 @@
+from . import data, tasks, criterions, models
diff --git a/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml b/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..736c3c72b9a7ba85eacaf44e1952fa7f0fc15a4f
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml
@@ -0,0 +1,101 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 100
+ tensorboard_logdir: tblog
+ seed: 1337
+
+checkpoint:
+ save_interval: 1
+ keep_last_epochs: 1
+ keep_best_checkpoints: 5
+ best_checkpoint_metric: dec_accuracy
+ maximize_best_checkpoint_metric: true
+ restore_file: checkpoint_last.pt
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ find_unused_parameters: true
+ distributed_world_size: 1
+ distributed_port: -1
+ nprocs_per_node: 8
+
+task:
+ _name: joint_sc2t_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: false # must be consistent with pre-training
+ labels: ["ltr"]
+ store_labels: true
+ single_target: true
+ add_decoder_target: true
+ pad_audio: false
+ random_crop: true
+ hubert_tokenizer: "none"
+ sp_path: None
+
+dataset:
+ num_workers: 0
+ max_tokens: 1300000
+ skip_invalid_size_inputs_valid_test: true
+ train_subset: train_100
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: ctc_ce
+ zero_infinity: true
+
+optimization:
+ max_update: 40000
+ lr: [0.00001]
+ sentence_avg: true
+ update_freq: [2]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+ weight_decay: 0.0
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: speechut_asr
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
+ add_decoder: true
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml b/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7cbc59e61f10ab00b997286d6355f22ce1008677
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml
@@ -0,0 +1,102 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 100
+ tensorboard_logdir: tblog
+ seed: 1337
+
+checkpoint:
+ save_interval: 1
+ keep_last_epochs: 5
+ keep_best_checkpoints: 5
+ best_checkpoint_metric: dec_accuracy
+ maximize_best_checkpoint_metric: true
+ restore_file: checkpoint_last.pt
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ find_unused_parameters: true
+ distributed_world_size: 16
+ distributed_port: -1
+ nprocs_per_node: 8
+
+task:
+ _name: joint_sc2t_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: true # must be consistent with pre-training
+ labels: ["ltr"]
+ store_labels: true
+ single_target: true
+ add_decoder_target: true
+ pad_audio: false
+ random_crop: true
+ hubert_tokenizer: "none"
+ sp_path: None
+
+dataset:
+ num_workers: 0
+ max_tokens: 1300000
+ skip_invalid_size_inputs_valid_test: true
+ train_subset: train_100
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: ctc_ce
+ zero_infinity: true
+
+optimization:
+ max_update: 40000
+ lr: [0.00001]
+ sentence_avg: true
+ update_freq: [2]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+ weight_decay: 0.0
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: speechut_asr
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.0
+ activation_dropout: 0.1
+ attention_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
+ add_decoder: true
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml b/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f10d6002555e5cbcfbf31035d8258e77abc26050
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml
@@ -0,0 +1,100 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 100
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_interval: 1
+ keep_last_epochs: 5
+ keep_best_checkpoints: 5
+ best_checkpoint_metric: dec_accuracy
+ maximize_best_checkpoint_metric: true
+ restore_file: checkpoint_last.pt
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ find_unused_parameters: true
+ distributed_world_size: 24
+ distributed_port: -1
+ nprocs_per_node: 8
+
+task:
+ _name: joint_sc2t_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: true # must be consistent with pre-training
+ labels: ["ltr"]
+ store_labels: true
+ single_target: true
+ add_decoder_target: true
+ pad_audio: false
+ random_crop: true
+ hubert_tokenizer: "none"
+ sp_path: None
+
+dataset:
+ num_workers: 0
+ max_tokens: 1300000
+ skip_invalid_size_inputs_valid_test: true
+ train_subset: train_960
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: ctc_ce
+ zero_infinity: true
+
+optimization:
+ max_update: 40000
+ lr: [0.00001]
+ sentence_avg: true
+ update_freq: [2]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+ weight_decay: 0.0
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: speechut_asr
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.0
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
+ add_decoder: true
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml b/SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6a3751febf2efc3cbf7a91e3a75f05b570559f2c
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml
@@ -0,0 +1,153 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_dir: ???
+ save_interval: 4
+ keep_last_epochs: 4
+ save_interval_updates: 50000
+ keep_interval_updates: -1
+ keep_interval_updates_pattern: 50000
+ # no_epoch_checkpoints: true
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_port: -1
+ distributed_world_size: 32
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: joint_sc2t_pretraining
+ data: ???
+ label_dir: ???
+ labels: ???
+ label_rate: ${model.label_rate}
+ store_labels: true
+ sample_rate: 16000
+ max_sample_size: 250000
+ min_sample_size: 32000
+ pad_audio: false
+ random_crop: true
+ normalize: false # must be consistent with extractor
+ add_decoder_target: true
+ text_cfg:
+ seed: ${common.seed}
+ text_data: ???
+ data_config: config.yaml
+ sample_break_mode: eos
+ tokens_per_sample: 1024
+ shorten_method: "random_crop"
+ text_maxtokens_ratio: 1.5
+
+dataset:
+ num_workers: 6
+ max_tokens: 1400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: ${checkpoint.save_interval}
+ validate_interval_updates: ${checkpoint.save_interval_updates}
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: speechut_criterion
+ pred_masked_weight: 1.0
+ pred_nomask_weight: 0.0
+ loss_weights: [10,]
+ label_smoothing: 0.1
+ u2t_ed_weight: 0.1
+ u2t_ctc_weight: 0.1
+ text_mum_weight: 0.5
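+  # Overall loss = hubert_loss + text_mum_weight * MUM_loss
+  #              + u2t_ed_weight * ED_loss + u2t_ctc_weight * CTC_loss,
+  # with loss_weights applied to the model's extra losses
+  # (see criterions/speechut_criterion.py).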
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: speechut
+ label_rate: ???
+ skip_masked: false
+ skip_nomask: false
+ mask_prob: 0.80
+ extractor_mode: default
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+ final_dim: 256
+ activation_fn: "gelu"
+ encoder_layers: 6
+ encoder_attention_heads: 8
+ encoder_layerdrop: 0.0
+ dropout_input: 0.1
+ dropout_features: 0.1
+ dropout: 0.1
+ attention_dropout: 0.1
+ feature_grad_mult: 0.1
+ untie_final_proj: true
+ activation_dropout: 0.0
+ use_rel_pos_enc: true
+ add_unit_encoder: true
+ add_text_ctc: true
+ mask_u2t: false
+ mix_with_unit: true
+ add_decoder: true
+ reset_decoder_embedding_config: true
+ text_transformer:
+ activation_fn: ${model.activation_fn}
+ dropout: ${model.dropout}
+ attention_dropout: ${model.attention_dropout}
+ activation_dropout: ${model.activation_dropout}
+ max_source_positions: 3000
+ max_target_positions: 3000
+ no_scale_embedding: true
+ layernorm_embedding: true
+ no_token_positional_embeddings: false
+ share_decoder_input_output_embed: false
+ encoder:
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 6
+ attention_heads: 8
+ normalize_before: false
+ learned_pos: true
+ layerdrop: ${model.encoder_layerdrop}
+ decoder:
+ layerdrop: 0.1
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 6
+ attention_heads: 12
+ normalize_before: false
+ learned_pos: false
+ output_dim: 768
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml b/SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..849c1d986126f6e26f3e10feb14fae0a299be4b4
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml
@@ -0,0 +1,159 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_scale_tolerance: 0.1 # alleviate fp16 overflow issue
+ log_format: json
+ log_interval: 200
+ seed: 1234
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_dir: ???
+ save_interval: 1
+ keep_last_epochs: 4
+ save_interval_updates: 10000
+ keep_interval_updates: -1
+ keep_interval_updates_pattern: 10000
+ # no_epoch_checkpoints: true
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_port: -1
+ distributed_world_size: 128
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: joint_sc2t_pretraining
+ data: ???
+ label_dir: ???
+ labels: ???
+ label_rate: ${model.label_rate}
+ store_labels: true
+ sample_rate: 16000
+ max_sample_size: 250000
+ min_sample_size: 32000
+ pad_audio: false
+ random_crop: true
+ normalize: true # must be consistent with extractor
+ add_decoder_target: true
+ text_cfg:
+ seed: ${common.seed}
+ text_data: ???
+ data_config: config.yaml
+ sample_break_mode: eos
+ tokens_per_sample: 1024
+ shorten_method: "random_crop"
+ text_maxtokens_ratio: 1.4
+
+dataset:
+ num_workers: 6
+ max_tokens: 900000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: ${checkpoint.save_interval}
+ validate_interval_updates: ${checkpoint.save_interval_updates}
+ required_batch_size_multiple: 2
+
+criterion:
+ _name: speechut_criterion
+ pred_masked_weight: 1.0
+ pred_nomask_weight: 0.0
+ loss_weights: [10,]
+ label_smoothing: 0.1
+ u2t_ed_weight: 0.1
+ u2t_ctc_weight: 0.1
+ text_mum_weight: 0.5
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 1.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+  end_learning_rate: 0.00015 # for future longer pre-training, e.g. 600K steps
+
+model:
+ _name: speechut
+ label_rate: ???
+ encoder_embed_dim: 1024
+ encoder_ffn_embed_dim: 4096
+ skip_masked: false
+ skip_nomask: false
+ mask_prob: 0.80
+ extractor_mode: layer_norm
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+ final_dim: 768
+ activation_fn: "gelu"
+ encoder_layers: 12
+ encoder_attention_heads: 16
+ encoder_layerdrop: 0.0
+ dropout_input: 0.0
+ dropout_features: 0.0
+ dropout: 0.0
+ attention_dropout: 0.0
+ layer_norm_first: true
+ feature_grad_mult: 1.0
+ untie_final_proj: true
+ activation_dropout: 0.0
+ use_rel_pos_enc: true
+ add_unit_encoder: true
+ add_text_ctc: true
+ mask_u2t: false
+ mix_with_unit: true
+ add_decoder: true
+ reset_decoder_embedding_config: true
+ scaling_for_att: 32 # alleviate fp16 overflow issue
+ text_transformer:
+ activation_fn: ${model.activation_fn}
+ dropout: ${model.dropout}
+ attention_dropout: ${model.attention_dropout}
+ activation_dropout: ${model.activation_dropout}
+ max_source_positions: 3000
+ max_target_positions: 3000
+ no_scale_embedding: true
+ layernorm_embedding: true
+ no_token_positional_embeddings: true
+ share_decoder_input_output_embed: false
+ encoder:
+ embed_dim: 1024
+ ffn_embed_dim: 4096
+ layers: 12
+ attention_heads: 16
+ normalize_before: false
+ learned_pos: true
+ layerdrop: ${model.encoder_layerdrop}
+ decoder:
+ layerdrop: 0.1
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 6
+ attention_heads: 12
+ normalize_before: false
+ learned_pos: false
+ output_dim: 768
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/criterions/__init__.py b/SpeechT5/Speech2S/speech2s/criterions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf9fac9a8c00d76decd07417d86a2625c4c851c
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/criterions/__init__.py
@@ -0,0 +1,9 @@
+import importlib
+import os
+
+for file in os.listdir(os.path.dirname(__file__)):
+ if file.endswith(".py") and not file.startswith("_"):
+ criterion_name = file[: file.find(".py")]
+ importlib.import_module(
+ "speechut.criterions." + criterion_name
+ )
diff --git a/SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py b/SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py
new file mode 100644
index 0000000000000000000000000000000000000000..aab6c9d23ac3b7dc410704bcba8982a697a57656
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py
@@ -0,0 +1,414 @@
+# ----------------------------------------------------------------------------
+# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
+# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import math
+from argparse import Namespace
+from dataclasses import dataclass, field
+from omegaconf import II
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
+from fairseq.dataclass import FairseqDataclass
+from fairseq.data.data_utils import post_process
+from fairseq.tasks import FairseqTask
+from fairseq.logging.meters import safe_round
+
+
+@dataclass
+class CtcCeCriterionConfig(FairseqDataclass):
+ zero_infinity: bool = field(
+ default=False,
+ metadata={"help": "zero inf loss when source length <= target length"},
+ )
+ sentence_avg: bool = II("optimization.sentence_avg")
+ post_process: str = field(
+ default="letter",
+ metadata={
+ "help": "how to post process predictions into words. can be letter, "
+ "wordpiece, BPE symbols, etc. "
+ "See fairseq.data.data_utils.post_process() for full list of options"
+ },
+ )
+ wer_kenlm_model: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
+ },
+ )
+ wer_lexicon: Optional[str] = field(
+ default=None,
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
+ )
+ wer_lm_weight: float = field(
+ default=2.0,
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
+ )
+ wer_word_score: float = field(
+ default=-1.0,
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
+ )
+
+ wer_args: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
+ },
+ )
+
+ dec_weight: float = field(
+ default=0.5,
+ metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
+ )
+ report_accuracy: bool = field(
+ default=True,
+ metadata={"help": "report decoder accuracy metric"},
+ )
+ ignore_prefix_size: int = field(
+ default=0,
+ metadata={"help": "Ignore first N tokens"},
+ )
+ label_smoothing: float = field(
+ default=0.1,
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
+ )
+
+
+@register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
+class CtcCeCriterion(FairseqCriterion):
+ def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
+ super().__init__(task)
+ self.blank_idx = (
+ task.target_dictionary.index(task.blank_symbol)
+ if hasattr(task, "blank_symbol")
+ else 0
+ )
+ self.pad_idx = task.target_dictionary.pad()
+ self.eos_idx = task.target_dictionary.eos()
+ self.post_process = cfg.post_process
+
+ if cfg.wer_args is not None:
+ (
+ cfg.wer_kenlm_model,
+ cfg.wer_lexicon,
+ cfg.wer_lm_weight,
+ cfg.wer_word_score,
+ ) = eval(cfg.wer_args)
+
+ if cfg.wer_kenlm_model is not None:
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
+
+ dec_args = Namespace()
+ dec_args.nbest = 1
+ dec_args.criterion = "ctc"
+ dec_args.kenlm_model = cfg.wer_kenlm_model
+ dec_args.lexicon = cfg.wer_lexicon
+ dec_args.beam = 50
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
+ dec_args.lm_weight = cfg.wer_lm_weight
+ dec_args.word_score = cfg.wer_word_score
+ dec_args.unk_weight = -math.inf
+ dec_args.sil_weight = 0
+
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
+ else:
+ self.w2l_decoder = None
+
+ self.zero_infinity = cfg.zero_infinity
+ self.sentence_avg = cfg.sentence_avg
+
+ self.dec_weight = cfg.dec_weight
+ self.report_accuracy = cfg.report_accuracy
+ self.ignore_prefix_size = cfg.ignore_prefix_size
+ self.eps = cfg.label_smoothing
+
+ def forward(self, model, sample, reduce=True):
+ net_output = model(**sample["net_input"])
+ lprobs = model.get_normalized_probs(
+ net_output, log_probs=True
+ ).contiguous() # (T, B, C) from the encoder
+
+ if "src_lengths" in sample["net_input"]:
+ input_lengths = sample["net_input"]["src_lengths"]
+ else:
+ if net_output["padding_mask"] is not None:
+ non_padding_mask = ~net_output["padding_mask"]
+ input_lengths = non_padding_mask.long().sum(-1)
+ else:
+ input_lengths = lprobs.new_full(
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
+ )
+
+ pad_mask = (sample["target"] != self.pad_idx) & (
+ sample["target"] != self.eos_idx
+ )
+ targets_flat = sample["target"].masked_select(pad_mask)
+ if "target_lengths" in sample:
+ target_lengths = sample["target_lengths"]
+ else:
+ target_lengths = pad_mask.sum(-1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = F.ctc_loss(
+ lprobs,
+ targets_flat,
+ input_lengths,
+ target_lengths,
+ blank=self.blank_idx,
+ reduction="sum",
+ zero_infinity=self.zero_infinity,
+ )
+
+ ntokens = (
+ sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
+ )
+
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
+
+ logging_output = {}
+ if "decoder_target" in sample:
+ if net_output["decoder_out"] is not None:
+ dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
+ dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
+ logging_output["ctc_loss"] = loss.item()
+ loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
+ logging_output["dec_loss"] = dec_loss.item()
+ logging_output["dec_nll_loss"] = dec_nll_loss.item()
+ logging_output["dec_sample_size"] = dec_sample_size
+
+ if self.report_accuracy:
+ n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
+ logging_output["dec_n_correct"] = utils.item(n_correct.data)
+ logging_output["total"] = utils.item(total.data)
+ else:
+ logging_output["ctc_loss"] = loss.item()
+ loss = (1 - self.dec_weight) * loss
+ logging_output["dec_loss"] = 0
+ logging_output["dec_nll_loss"] = 0
+ logging_output["dec_sample_size"] = 1
+ if self.report_accuracy:
+ logging_output["dec_n_correct"] = 0
+ logging_output["total"] = 1
+
+ logging_output = {
+ "loss": utils.item(loss.data), # * sample['ntokens'],
+ "ntokens": ntokens,
+ "nsentences": sample["id"].numel(),
+ "sample_size": sample_size,
+ **logging_output,
+ }
+
+ if not model.training and self.dec_weight < 1.0:
+ import editdistance
+
+ with torch.no_grad():
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
+
+ c_err = 0
+ c_len = 0
+ w_errs = 0
+ w_len = 0
+ wv_errs = 0
+ for lp, t, inp_l in zip(
+ lprobs_t,
+ sample["target_label"]
+ if "target_label" in sample
+ else sample["target"],
+ input_lengths,
+ ):
+ lp = lp[:inp_l].unsqueeze(0)
+
+ decoded = None
+ if self.w2l_decoder is not None:
+ decoded = self.w2l_decoder.decode(lp)
+ if len(decoded) < 1:
+ decoded = None
+ else:
+ decoded = decoded[0]
+ if len(decoded) < 1:
+ decoded = None
+ else:
+ decoded = decoded[0]
+
+ p = (t != self.task.target_dictionary.pad()) & (
+ t != self.task.target_dictionary.eos()
+ )
+ targ = t[p]
+ targ_units = self.task.target_dictionary.string(targ)
+ targ_units_arr = targ.tolist()
+
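+                    # Greedy CTC decoding: take the frame-wise argmax, collapse
+                    # repeated tokens, and drop blanks to get the unit prediction.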
+ toks = lp.argmax(dim=-1).unique_consecutive()
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
+
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
+ c_len += len(targ_units_arr)
+
+ targ_words = post_process(targ_units, self.post_process).split()
+
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
+ pred_words_raw = post_process(pred_units, self.post_process).split()
+
+ if decoded is not None and "words" in decoded:
+ pred_words = decoded["words"]
+ w_errs += editdistance.eval(pred_words, targ_words)
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
+ else:
+ dist = editdistance.eval(pred_words_raw, targ_words)
+ w_errs += dist
+ wv_errs += dist
+
+ w_len += len(targ_words)
+
+ logging_output["wv_errors"] = wv_errs
+ logging_output["w_errors"] = w_errs
+ logging_output["w_total"] = w_len
+ logging_output["c_errors"] = c_err
+ logging_output["c_total"] = c_len
+
+ return loss, sample_size, logging_output
+
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ loss, nll_loss = label_smoothed_nll_loss(
+ lprobs,
+ target,
+ self.eps,
+ ignore_index=self.pad_idx,
+ reduce=reduce,
+ )
+ return loss, nll_loss
+
+ def compute_accuracy(self, model, net_output, sample):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ mask = target.ne(self.pad_idx)
+ n_correct = torch.sum(
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+ )
+ total = torch.sum(mask)
+ return n_correct, total
+
+ def get_lprobs_and_target(self, model, net_output, sample):
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
+ target = sample["decoder_target"]
+ if self.ignore_prefix_size > 0:
+ if getattr(lprobs, "batch_first", False):
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+ target = target[:, self.ignore_prefix_size :].contiguous()
+ else:
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+ target = target[self.ignore_prefix_size :, :].contiguous()
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
+ nsentences = utils.item(
+ sum(log.get("nsentences", 0) for log in logging_outputs)
+ )
+ sample_size = utils.item(
+ sum(log.get("sample_size", 0) for log in logging_outputs)
+ )
+
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar("ntokens", ntokens)
+ metrics.log_scalar("nsentences", nsentences)
+ if sample_size != ntokens:
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
+ metrics.log_scalar("_c_errors", c_errors)
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
+ metrics.log_scalar("_c_total", c_total)
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
+ metrics.log_scalar("_w_errors", w_errors)
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
+ metrics.log_scalar("_wv_errors", wv_errors)
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
+ metrics.log_scalar("_w_total", w_total)
+
+ if c_total > 0:
+ metrics.log_derived(
+ "uer",
+ lambda meters: safe_round(
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
+ )
+ if meters["_c_total"].sum > 0
+ else float("nan"),
+ )
+ if w_total > 0:
+ metrics.log_derived(
+ "wer",
+ lambda meters: safe_round(
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+ )
+ if meters["_w_total"].sum > 0
+ else float("nan"),
+ )
+ metrics.log_derived(
+ "raw_wer",
+ lambda meters: safe_round(
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+ )
+ if meters["_w_total"].sum > 0
+ else float("nan"),
+ )
+
+ if "dec_loss" in logging_outputs[0]:
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
+ dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
+ dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
+ )
+ metrics.log_scalar(
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar(
+ "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
+ )
+ metrics.log_derived(
+ "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
+ )
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
+ if total > 0:
+ metrics.log_scalar("total", total)
+ n_correct = utils.item(
+ sum(log.get("dec_n_correct", 0) for log in logging_outputs)
+ )
+ metrics.log_scalar("dec_n_correct", n_correct)
+ metrics.log_derived(
+ "dec_accuracy",
+ lambda meters: round(
+ meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
+ )
+ if meters["total"].sum > 0
+ else float("nan"),
+ )
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return True
diff --git a/SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py b/SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d735f1efd16aebf4146e26d5a5ebaeca2516ad7
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py
@@ -0,0 +1,384 @@
+# ----------------------------------------------------------------------------
+# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
+# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import logging
+import math
+import re
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
+from fairseq.dataclass import FairseqDataclass
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class SpeechUTCriterionConfig(FairseqDataclass):
+ pred_masked_weight: float = field(
+ default=1.0,
+ metadata={"help": "weight for predictive loss for masked frames"},
+ )
+ pred_nomask_weight: float = field(
+ default=0.0,
+ metadata={"help": "weight for predictive loss for unmasked frames"},
+ )
+ loss_weights: Optional[List[float]] = field(
+ default=None,
+ metadata={"help": "weights for additional loss terms (not first one)"},
+ )
+ log_keys: List[str] = field(
+ default_factory=lambda: [],
+ metadata={"help": "output keys to log"},
+ )
+ u2t_ed_weight: float = field(
+ default=0.1,
+ metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
+ )
+ u2t_ctc_weight: float = field(
+ default=0.0,
+ metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
+ )
+ text_mum_weight: float = field(
+ default=0.0,
+ metadata={"help": "masked unit modeling weight from the text end"},
+ )
+ report_accuracy: bool = field(
+ default=True,
+ metadata={"help": "report decoder accuracy metric"},
+ )
+ ignore_prefix_size: int = field(
+ default=0,
+ metadata={"help": "Ignore first N tokens"},
+ )
+ label_smoothing: float = field(
+ default=0.0,
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
+ )
+ no_ctc_blank: bool = field(
+ default=False,
+ metadata={"help": "mask out the blank of ctc, only when dec_loss_type=ctc"},
+ )
+
+@register_criterion("speechut_criterion", dataclass=SpeechUTCriterionConfig)
+class SpeechUTCriterion(FairseqCriterion):
+ def __init__(
+ self,
+ task,
+ pred_masked_weight,
+ pred_nomask_weight,
+ loss_weights=None,
+ log_keys=None,
+ u2t_ed_weight=0.1,
+ u2t_ctc_weight=0,
+ text_mum_weight=0,
+ report_accuracy=False,
+ ignore_prefix_size=0,
+ label_smoothing=0,
+ no_ctc_blank=False,
+ ):
+ super().__init__(task)
+ self.pred_masked_weight = pred_masked_weight
+ self.pred_nomask_weight = pred_nomask_weight
+ self.loss_weights = loss_weights
+ self.log_keys = [] if log_keys is None else log_keys
+ self.u2t_ed_weight = u2t_ed_weight
+ self.u2t_ctc_weight = u2t_ctc_weight
+ self.text_mum_weight = text_mum_weight
+ self.report_accuracy = report_accuracy
+ self.ignore_prefix_size = ignore_prefix_size
+ self.eps = label_smoothing
+ self.no_ctc_blank = no_ctc_blank
+ self.padding_idx = task.dictionaries[0].pad()
+ self.eos_idx = task.dictionaries[0].eos()
+ self.blank_idx = task.dictionaries[0].bos()
+
+ def compute_hubert_loss(self, model, net_output, reduction, preffix='', suffix=''):
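+        # HuBERT-style masked unit prediction: cross-entropy over the model's
+        # logits for masked frames (and optionally unmasked frames), plus
+        # per-target correct/count statistics for logging.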
+ loss = 0
+ sample_size = []
+ logging_output = {}
+ loss_m_list = []
+ logp_m_list = model.get_logits(net_output, True)
+ targ_m_list = model.get_targets(net_output, True)
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
+ loss_m_list.append(loss_m)
+ logging_output[f"{preffix}loss_m_{i}"] = loss_m.detach().item()
+ if self.pred_masked_weight > 0:
+ loss += self.pred_masked_weight * sum(loss_m_list)
+ sample_size.append(targ_m_list[0].numel())
+
+ loss_u_list = []
+ logp_u_list = model.get_logits(net_output, False)
+ targ_u_list = model.get_targets(net_output, False)
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
+ loss_u_list.append(loss_u)
+ logging_output[f"{preffix}loss_u_{i}"] = loss_u.detach().item()
+ if self.pred_nomask_weight > 0:
+ loss += self.pred_nomask_weight * sum(loss_u_list)
+ sample_size.append(targ_u_list[0].numel())
+
+ sample_size = np.mean(sample_size)
+
+ def compute_correct(logits, targets):
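+            # Count a frame as correct only if the argmax matches the target
+            # and the logits are not degenerate (argmax == argmin would mean
+            # every class received the same score).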
+ if logits.numel() == 0:
+ return 0, 0
+ else:
+ assert logits.dim() > 1, logits.shape
+ max = logits.argmax(-1) == targets
+ min = logits.argmin(-1) == targets
+ both = max & min
+ corr = max.long().sum().item() - both.long().sum().item()
+ count = max.numel()
+ return corr, count
+
+ with torch.no_grad():
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
+ corr_m, count_m = compute_correct(logp_m, targ_m)
+ logging_output[f"correct_m_{i}{suffix}"] = corr_m
+ logging_output[f"count_m_{i}{suffix}"] = count_m
+
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
+ corr_u, count_u = compute_correct(logp_u, targ_u)
+ logging_output[f"correct_u_{i}{suffix}"] = corr_u
+ logging_output[f"count_u_{i}{suffix}"] = count_u
+
+ return loss, sample_size, logging_output
+
+
+ def forward(self, model, sample, reduce=True, log_pred=False):
+ """Compute the loss for the given sample.
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ reduction = "sum" if reduce else "none"
+
+ if "net_input" in sample:
+ unit_sample = text_sample = None
+ else:
+ unit_sample = sample.get("text_mono", None)
+ text_sample = sample.get("text_paired", None)
+ assert unit_sample is not None or text_sample is not None
+ sample = sample.get("speech")
+
+ ### 1. S2U: do hubert forward and loss computation
+ sample["modality"] = "speech"
+ net_output = model(target_list=sample["target_list"], **sample["net_input"])
+ loss, sample_size, logging_output = self.compute_hubert_loss(
+ model,
+ net_output,
+ reduction,
+ )
+ if self.loss_weights is not None:
+ assert hasattr(model, "get_extra_losses")
+ extra_losses, names = model.get_extra_losses(net_output)
+ if torch.is_tensor(extra_losses):
+ extra_losses = [extra_losses]
+ names = [names]
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+ assert len(extra_losses) == len(
+ self.loss_weights
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
+ if coef != 0 and p is not None:
+ p = coef * p.float() * sample_size
+ loss += p
+ logging_output[f"loss_{n}"] = p.item()
+ for lk in self.log_keys:
+ if lk in net_output:
+ logging_output[lk] = float((net_output[lk]))
+
+ ### 2. do text U2T forward and loss computation
+ if text_sample is not None and (self.u2t_ctc_weight + self.u2t_ed_weight) > 0:
+            ## 2.1 reload "target_list": by default target_list = [src_tokens];
+            ## with the "unit-phone-char" structure it becomes [ref_tokens]
+ text_sample["net_input"]["target_list"] = [
+ text_sample.get("ref_tokens", text_sample["net_input"]["src_tokens"].clone()),
+ ]
+ text_net_output = model(**text_sample["net_input"])
+ text_sample_size = text_sample["ntokens"]
+
+ ### 2.1 U2T_UCTC
+ if self.u2t_ctc_weight > 0:
+ text_ctc_loss = self.compute_ctc_loss(model, text_net_output, text_sample["target"], reduction=reduction)
+ loss += self.u2t_ctc_weight * text_ctc_loss * sample_size / text_sample_size
+ logging_output["text_ctc_loss"] = utils.item(text_ctc_loss)
+ logging_output["text_sample_size"] = text_sample_size
+
+ ### 2.2 U2T_ED
+ if self.u2t_ed_weight > 0:
+ text_dec_loss, text_dec_nll_loss = self.compute_ce_loss(model, text_net_output["decoder_out"], text_sample, reduce=reduce)
+ loss += self.u2t_ed_weight * text_dec_loss * sample_size / text_sample_size
+ logging_output["text_dec_loss"] = utils.item(text_dec_loss)
+ logging_output["text_dec_nll_loss"] = utils.item(text_dec_nll_loss)
+ logging_output["text_sample_size"] = text_sample_size
+ if self.report_accuracy:
+ n_correct, total = self.compute_accuracy(model, text_net_output["decoder_out"], text_sample)
+ logging_output["correct_text_dec"] = utils.item(n_correct.data)
+ logging_output["count_text_dec"] = utils.item(total.data)
+
+ ### 3. do unit MUM forward and loss computation
+ if unit_sample is not None and self.text_mum_weight > 0:
+ src_tokens = unit_sample["net_input"]["src_tokens"]
+ target = unit_sample.get("target", None)
+ target = src_tokens.clone() if target is None else target
+ unit_net_output = model.forward_mum(src_tokens, target)
+ loss_num, sample_size_mum, logging_output_mum = self.compute_hubert_loss(
+ model,
+ unit_net_output,
+ reduction,
+ preffix="mum_",
+ suffix="_mum",
+ )
+ loss += self.text_mum_weight * loss_num * sample_size / sample_size_mum
+ logging_output["unit_sample_size"] = sample_size_mum
+ logging_output.update(logging_output_mum)
+
+ logging_output = {
+ "loss": utils.item(loss) if reduce else loss,
+ "ntokens": sample_size,
+ "nsentences": sample["id"].numel() + (text_sample["id"].numel() if text_sample is not None else 0),
+ "sample_size": sample_size,
+ **logging_output,
+ }
+
+ return loss, sample_size, logging_output
+
+ def compute_ctc_loss(self, model, net_output, target, reduction):
+ logits = net_output["encoder_out_ctc"][0] # (T, B, C) from the code-encoder
+ if self.no_ctc_blank:
+            ## set the logit of the blank token to -inf so it is never predicted
+ logits = logits.float()
+ logits[:, :, self.blank_idx] = -1000000.0
+
+ lprobs = F.log_softmax(logits.float(), dim=-1)
+
+ encoder_padding_mask = net_output["encoder_padding_mask"][0]
+ non_padding_mask = ~encoder_padding_mask
+ input_lengths = non_padding_mask.long().sum(-1)
+ pad_mask = (target != self.padding_idx) & (target != self.eos_idx)
+ targets_flat = target.masked_select(pad_mask)
+ target_lengths = pad_mask.sum(-1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = F.ctc_loss(
+ lprobs,
+ targets_flat,
+ input_lengths,
+ target_lengths,
+ blank=self.blank_idx,
+ reduction=reduction,
+ zero_infinity=True,
+ )
+ return loss
+
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ loss, nll_loss = label_smoothed_nll_loss(
+ lprobs,
+ target,
+ self.eps,
+ ignore_index=self.padding_idx,
+ reduce=reduce,
+ )
+ return loss, nll_loss
+
+ def compute_accuracy(self, model, net_output, sample):
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+ mask = target.ne(self.padding_idx)
+ n_correct = torch.sum(
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+ )
+ total = torch.sum(mask)
+ return n_correct, total
+
+ def get_lprobs_and_target(self, model, net_output, sample):
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
+ target = sample["target"]
+
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ if sample_size != ntokens:
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+ )
+ else:
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+ )
+
+ counts = {}
+ for lk in logging_outputs[0].keys():
+ if lk.startswith("count_"):
+ val = sum(log.get(lk, 0) for log in logging_outputs)
+ metrics.log_scalar(lk, val)
+ counts[lk] = val
+
+ for lk in logging_outputs[0].keys():
+ if lk.startswith("loss_"):
+ val = sum(log.get(lk, 0) for log in logging_outputs)
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
+ elif lk.startswith("correct_"):
+ val = sum(log.get(lk, 0) for log in logging_outputs)
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
+
+ if "text_sample_size" in logging_outputs[0]:
+ text_sample_size = sum(log.get("text_sample_size", 0) for log in logging_outputs)
+ for lk in logging_outputs[0].keys():
+ if lk.startswith("text_") and lk.endswith("_loss"):
+ val = sum(log.get(lk, 0) for log in logging_outputs)
+ metrics.log_scalar(lk, val / text_sample_size / math.log(2), round=3)
+
+ if "unit_sample_size" in logging_outputs[0]:
+ unit_sample_size = sum(log.get("unit_sample_size", 0) for log in logging_outputs)
+ for lk in logging_outputs[0].keys():
+ if lk.startswith("mum_loss_"):
+ val = sum(log.get(lk, 0) for log in logging_outputs)
+ metrics.log_scalar(lk, val / unit_sample_size / math.log(2), round=3)
+
+ @staticmethod
+ def aggregate_logging_outputs(logging_outputs):
+ """Aggregate logging outputs from data parallel training."""
+ raise NotImplementedError()
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return False
diff --git a/SpeechT5/Speech2S/speech2s/data/concat_dataset.py b/SpeechT5/Speech2S/speech2s/data/concat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5766921ac39b571010b318e0d4b6f967cd21d96e
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/data/concat_dataset.py
@@ -0,0 +1,129 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+import bisect
+
+import numpy as np
+from torch.utils.data.dataloader import default_collate
+
+from fairseq.data import FairseqDataset
+
+
+class ConcatDataset(FairseqDataset):
+ @staticmethod
+ def cumsum(sequence, sample_ratios):
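+        # Cumulative (up-sampled) sizes: each dataset contributes
+        # int(ratio * len(dataset)) virtual examples to the concatenation.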
+ r, s = [], 0
+ for e, ratio in zip(sequence, sample_ratios):
+ curr_len = int(ratio * len(e))
+ r.append(curr_len + s)
+ s += curr_len
+ return r
+
+ def __init__(self, datasets, sample_ratios=1):
+ super(ConcatDataset, self).__init__()
+ assert len(datasets) > 0, "datasets should not be an empty iterable"
+ self.datasets = list(datasets)
+ if isinstance(sample_ratios, int):
+ sample_ratios = [sample_ratios] * len(self.datasets)
+ self.sample_ratios = sample_ratios
+ self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
+ self.real_sizes = [len(d) for d in self.datasets]
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+ return self.datasets[dataset_idx][sample_idx]
+
+ def _get_dataset_and_sample_index(self, idx: int):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ sample_idx = sample_idx % self.real_sizes[dataset_idx]
+ return dataset_idx, sample_idx
+
+ def collater(self, samples, **extra_args):
+        # For now this only supports datasets with the same underlying collater implementation
+ if hasattr(self.datasets[0], "collater"):
+ return self.datasets[0].collater(samples, **extra_args)
+ else:
+ return default_collate(samples, **extra_args)
+
+ def size(self, idx: int):
+ """
+ Return an example's size as a float or tuple.
+ """
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+ return self.datasets[dataset_idx].size(sample_idx)
+
+ def num_tokens(self, index: int):
+ return np.max(self.size(index))
+
+ def attr(self, attr: str, index: int):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
+ return getattr(self.datasets[dataset_idx], attr, None)
+
+ @property
+ def sizes(self):
+ _dataset_sizes = []
+ for ds, sr in zip(self.datasets, self.sample_ratios):
+ if isinstance(ds.sizes, np.ndarray):
+ _dataset_sizes.append(np.tile(ds.sizes, sr))
+ else:
+ # Only support underlying dataset with single size array.
+ assert isinstance(ds.sizes, list)
+ _dataset_sizes.append(np.tile(ds.sizes[0], sr))
+ return np.concatenate(_dataset_sizes)
+
+ @property
+ def supports_prefetch(self):
+ return all(d.supports_prefetch for d in self.datasets)
+
+ def ordered_indices(self):
+ """
+        Return indices sorted by length so that less padding is needed.
+ """
+ if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
+ # special handling for concatenating lang_pair_datasets
+ if getattr(self.datasets[0], "shuffle", False):
+ indices = np.random.permutation(len(self)).astype(np.int64)
+ else:
+ indices = np.arange(len(self), dtype=np.int64)
+ sizes = self.sizes
+ tgt_sizes = (
+ sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
+ )
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
+ # sort by target length, then source length
+ if tgt_sizes is not None:
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
+ return indices[np.argsort(src_sizes[indices], kind="mergesort")]
+ else:
+ return np.argsort(self.sizes)
+
+ def prefetch(self, indices):
+ frm = 0
+ for to, ds in zip(self.cumulative_sizes, self.datasets):
+ real_size = len(ds)
+ if getattr(ds, "supports_prefetch", False):
+ ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
+ frm = to
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
+
+ def set_epoch(self, epoch):
+ super().set_epoch(epoch)
+ for ds in self.datasets:
+ if hasattr(ds, "set_epoch"):
+ ds.set_epoch(epoch)
diff --git a/SpeechT5/Speech2S/speech2s/data/hubert_dataset.py b/SpeechT5/Speech2S/speech2s/data/hubert_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..64965dea445a0a5afc63c887b1bc89cece0b203b
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/data/hubert_dataset.py
@@ -0,0 +1,597 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+import itertools
+import logging
+import io
+import os
+import sys
+import time
+from pathlib import Path
+from typing import Any, List, Optional, Union, Tuple
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from fairseq.data import data_utils, Dictionary
+from fairseq.data.fairseq_dataset import FairseqDataset
+from fairseq.data.audio.audio_utils import (
+ read_from_stored_zip,
+ is_sf_audio_data,
+)
+
+FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
+
+logger = logging.getLogger(__name__)
+
+def parse_path(path: str) -> Tuple[str, List[int]]:
+ """Parse data path which is either a path to
+ 1. a .npy/.wav/.flac/.ogg file
+ 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
+
+ Args:
+ path (str): the data path to parse
+
+ Returns:
+ file_path (str): the file path
+ slice_ptr (list of int): empty in case 1;
+ byte offset and length for the slice in case 2
+ """
+
+ if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
+ _path, slice_ptr = path, []
+ else:
+ _path, *slice_ptr = path.split(":")
+ if not Path(_path).is_file():
+ raise FileNotFoundError(f"File not found: {_path}")
+ assert len(slice_ptr) in {0, 1, 2}, f"Invalid path: {path}"
+ slice_ptr = [int(i) for i in slice_ptr]
+ return _path, slice_ptr
+
+def load_audio(manifest_path, max_keep, min_keep, retry_times=5):
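+    # The manifest is a TSV whose first line is the audio root directory and
+    # whose remaining lines are "<path>\t<num_samples>". Entries outside the
+    # [min_keep, max_keep] range are skipped; chunk_names/chunk_indices track
+    # the boundaries of zip-chunked paths of the form "zip_path:offset:length".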
+ n_long, n_short = 0, 0
+ names, inds, sizes, chunk_names, chunk_indices = [], [], [], [], []
+ for i in range(retry_times):
+ with open(manifest_path) as f:
+ root = f.readline().strip()
+ for ind, line in enumerate(f):
+ items = line.strip().split("\t")
+ assert len(items) == 2, line
+ sz = int(items[1])
+ if min_keep is not None and sz < min_keep:
+ n_short += 1
+ elif max_keep is not None and sz > max_keep:
+ n_long += 1
+ else:
+ fname = items[0].split(":")
+ if len(fname) > 2:
+ if len(chunk_names) == 0 or fname[0] != chunk_names[-1]:
+ chunk_names.append(fname[0])
+ chunk_indices.append(len(names))
+ names.append(items[0])
+ inds.append(ind)
+ sizes.append(sz)
+ if len(names) == 0:
+ logger.warn(f"Fail to load manifest for the {i} time")
+ time.sleep(1)
+ continue
+ else:
+ break
+ tot = ind + 1
+ logger.info(
+ (
+ f"max_keep={max_keep}, min_keep={min_keep}, "
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
+ )
+ )
+ return root, names, inds, tot, sizes, chunk_names, chunk_indices
+
+
+def load_label(label_path, inds, tot, retry_times=5):
+ for i in range(retry_times):
+ with open(label_path) as f:
+ labels = [line.rstrip() for line in f]
+ if len(labels) == 0:
+ logger.warn(f"Fail to load label for the {i} time")
+ time.sleep(1)
+ continue
+ else:
+ break
+ assert (
+ len(labels) == tot
+ ), f"number of labels does not match ({len(labels)} != {tot})"
+ labels = [labels[i] for i in inds]
+ return labels
+
+
+def load_label_offset(label_path, inds, tot, retry_times=5):
+ for i in range(retry_times):
+ with open(label_path) as f:
+ code_lengths = [len(line.encode("utf-8")) for line in f]
+ if len(code_lengths) == 0:
+ logger.warn(f"Fail to load label for the {i} time")
+ time.sleep(1)
+ continue
+ else:
+ break
+ assert (
+ len(code_lengths) == tot
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
+ offsets = list(itertools.accumulate([0] + code_lengths))
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
+ return offsets
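+    # Example (illustrative): for a label file with the two lines "a b\n" and
+    # "c d e\n", code_lengths = [4, 6] and the cumulative offsets are [0, 4, 10],
+    # so label i can later be read lazily via f.seek(start); f.read(end - start).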
+
+
+def verify_label_lengths(
+ audio_sizes,
+ audio_rate,
+ label_path,
+ label_rate,
+ inds,
+ tot,
+ tol=0.1, # tolerance in seconds
+):
+ if label_rate < 0:
+ logger.info(f"{label_path} is sequence label. skipped")
+ return
+
+ with open(label_path) as f:
+ lengths = [len(line.rstrip().split()) for line in f]
+ assert len(lengths) == tot
+ lengths = [lengths[i] for i in inds]
+ num_invalid = 0
+ for i, ind in enumerate(inds):
+ dur_from_audio = audio_sizes[i] / audio_rate
+ dur_from_label = lengths[i] / label_rate
+ if abs(dur_from_audio - dur_from_label) > tol:
+ logger.warning(
+ (
+ f"audio and label duration differ too much "
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
+ f"is correctly set (currently {label_rate}). "
+ f"num. of samples = {audio_sizes[i]}; "
+ f"label length = {lengths[i]}"
+ )
+ )
+ num_invalid += 1
+ if num_invalid > 0:
+ logger.warning(
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
+ )
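+    # Example of the check above (illustrative numbers): with audio_rate=16000 and
+    # label_rate=50, a 160000-sample clip lasts 10.0 s and should come with about
+    # 500 labels; a 525-label sequence (10.5 s) differs by 0.5 s > tol and is flagged.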
+
+
+class HubertDataset(FairseqDataset):
+ def __init__(
+ self,
+ manifest_path: str,
+ sample_rate: float,
+ label_paths: List[str],
+ label_rates: Union[List[float], float], # -1 for sequence labels
+ pad_list: List[str],
+ eos_list: List[str],
+ label_processors: Optional[List[Any]] = None,
+ max_keep_sample_size: Optional[int] = None,
+ min_keep_sample_size: Optional[int] = None,
+ max_sample_size: Optional[int] = None,
+ shuffle: bool = True,
+ pad_audio: bool = False,
+ normalize: bool = False,
+ store_labels: bool = True,
+ random_crop: bool = False,
+ single_target: bool = False,
+ tgt_dict: Optional[Dictionary] = None,
+ add_decoder_target: bool = False,
+ fine_tuning: bool = False,
+ tgt_lang_idx: int = None,
+        tokenizer=None,
+ mbart_style_lang_id: bool = False,
+ retry_times: int = 5,
+ reduce_label_for_dec: bool = True,
+ ):
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.chunk_names, self.chunk_indices = load_audio(
+ manifest_path, max_keep_sample_size, min_keep_sample_size, retry_times
+ )
+ self.sample_rate = sample_rate
+ self.shuffle = shuffle
+ self.random_crop = random_crop
+ self.tgt_dict = tgt_dict
+ self.add_decoder_target = add_decoder_target
+ self.fine_tuning = fine_tuning
+
+ self.num_labels = len(label_paths)
+ self.pad_list = pad_list
+ self.eos_list = eos_list
+ self.label_processors = label_processors
+ self.single_target = single_target
+ self.epoch = 0
+
+ self.label_rates = (
+ [label_rates for _ in range(len(label_paths))]
+            if isinstance(label_rates, (int, float))
+ else label_rates
+ )
+ self.store_labels = store_labels
+ if store_labels:
+ self.label_list = [load_label(p, inds, tot, retry_times) for p in label_paths]
+ else:
+ self.label_paths = label_paths
+ self.label_offsets_list = [
+ load_label_offset(p, inds, tot, retry_times) for p in label_paths
+ ]
+ assert label_processors is None or len(label_processors) == self.num_labels
+ for label_path, label_rate in zip(label_paths, self.label_rates):
+ verify_label_lengths(
+ self.wav_sizes, sample_rate, label_path, label_rate, inds, tot
+ )
+
+ self.max_sample_size = (
+ max_sample_size if max_sample_size is not None else sys.maxsize
+ )
+ self.pad_audio = pad_audio
+ self.normalize = normalize
+ self.tgt_lang_idx = tgt_lang_idx
+ self.tokenizer = tokenizer
+ self.mbart_style_lang_id = mbart_style_lang_id
+ self.retry_times = retry_times
+ self.reduce_label_for_dec = reduce_label_for_dec
+ logger.info(
+ f"pad_audio={pad_audio}, random_crop={random_crop}, tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
+ f"mbart_style_lang_id={mbart_style_lang_id}, normalize={normalize}, max_sample_size={self.max_sample_size}"
+ )
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1):
+ self.max_tokens = max_tokens
+ self.max_sentences = max_sentences
+ self.required_batch_size_multiple = required_batch_size_multiple
+ if isinstance(indices[0], np.ndarray):
+ batch_list = []
+ for indice in indices:
+ batch = super(HubertDataset, self).batch_by_size(indice, max_tokens, max_sentences, required_batch_size_multiple)
+ batch_list.append(batch)
+ return batch_list
+ else:
+            return super(HubertDataset, self).batch_by_size(indices, max_tokens, max_sentences, required_batch_size_multiple)
+
+    def shuffle_batches(self, batches, seed):
+ if isinstance(batches[0], list):
+ new_batches = []
+ with data_utils.numpy_seed(seed):
+ np.random.shuffle(batches)
+ for batch in batches:
+ np.random.shuffle(batch)
+ new_batches.extend(batch)
+ return new_batches
+ else:
+ with data_utils.numpy_seed(seed):
+ np.random.shuffle(batches)
+ return batches
+
+ def get_audio(self, index):
+ import soundfile as sf
+
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
+ _path, slice_ptr = parse_path(wav_path)
+ if len(slice_ptr) == 1:
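+            # A single offset (e.g. a Kaldi specifier like "feats.ark:12345") is
+            # assumed to point at a stored feature matrix and is loaded via kaldiio.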
+ import kaldiio
+ feat = kaldiio.load_mat(wav_path)
+ feat = torch.from_numpy(feat).float()
+ if self.normalize:
+ with torch.no_grad():
+                    feat = F.layer_norm(feat, feat.shape[-1:])
+ return feat
+ else:
+ if len(slice_ptr) == 2:
+ byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
+ assert is_sf_audio_data(byte_data)
+ wav_path = io.BytesIO(byte_data)
+ for i in range(self.retry_times):
+ if i < self.retry_times - 1:
+ try:
+ wav, cur_sample_rate = sf.read(wav_path)
+ break
+ except Exception as e:
+ logger.warn(f"Fail to load wav for the {i} time")
+ logger.warn(e)
+ time.sleep(1)
+ continue
+ else:
+ wav, cur_sample_rate = sf.read(wav_path)
+
+ wav = torch.from_numpy(wav).float()
+ wav = self.postprocess(wav, cur_sample_rate)
+ return wav
+
+ def get_label(self, index, label_idx):
+ if self.store_labels:
+ label = self.label_list[label_idx][index]
+ else:
+ with open(self.label_paths[label_idx]) as f:
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
+ f.seek(offset_s)
+ label = f.read(offset_e - offset_s)
+
+ if self.tokenizer is not None and self.fine_tuning:
+ label = self.tokenizer.encode(label)
+
+ if self.label_processors is not None:
+ label = self.label_processors[label_idx](label)
+ return label
+
+ def get_labels(self, index):
+ return [self.get_label(index, i) for i in range(self.num_labels)]
+
+ def __getitem__(self, index):
+ wav = self.get_audio(index)
+ labels = self.get_labels(index)
+ return {"id": index, "source": wav, "label_list": labels}
+
+ def __len__(self):
+ return len(self.wav_sizes)
+
+ def crop_to_max_size(self, wav, target_size):
+ size = len(wav)
+ diff = size - target_size
+ if diff <= 0:
+ return wav, 0
+
+ start, end = 0, target_size
+ if self.random_crop:
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return wav[start:end], start
+
+ def collater(self, samples):
+ # target = max(sizes) -> random_crop not used
+ # target = max_sample_size -> random_crop used for long
+ samples = [s for s in samples if s["source"] is not None]
+ if len(samples) == 0:
+ return {}
+
+ audios = [s["source"] for s in samples]
+ audio_sizes = [len(s) for s in audios]
+ if self.pad_audio:
+ audio_size = min(max(audio_sizes), self.max_sample_size)
+ else:
+ audio_size = min(min(audio_sizes), self.max_sample_size)
+ feat_dim = audios[0].size(-1) if audios[0].dim() > 1 else 1
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
+ audios, audio_size, feat_dim,
+ )
+
+ targets_by_label = [
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
+ ]
+ targets_list, lengths_list, ntokens_list = self.collater_label(
+ targets_by_label, audio_size, audio_starts
+ )
+
+ if self.add_decoder_target:
+ if self.fine_tuning:
+ decoder_label = [
+ torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
+ for i in range(targets_list[0].size(0))
+ ]
+ else:
+ if self.tokenizer is not None:
+ decoder_label = [
+                            # Offset unit ids by 48 when mapping them to chars (e.g. id 5 -> '5') to avoid control characters such as '\n'
+ torch.cat(
+ (
+ torch.tensor(
+ self.tokenizer.sp.Encode(
+ "".join(
+ [chr(j + 48) for j in (
+ targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]]
+ ).tolist()]
+ ), out_type=int
+ )
+ ),
+ torch.tensor([self.tgt_dict.eos()])
+ ), dim=0
+ ).long()
+ for i in range(targets_list[0].size(0))
+ ]
+ else:
+ decoder_label = [
+ torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
+ for i in range(targets_list[0].size(0))
+ ]
+
+ if self.mbart_style_lang_id:
+ decoder_label = [
+ torch.cat((decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0).long()
+ for i in range(targets_list[0].size(0))
+ ]
+
+ dec_ntokens = sum(x.size(0) for x in decoder_label)
+ decoder_target = data_utils.collate_tokens(
+ decoder_label,
+ self.tgt_dict.pad(),
+ self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
+ left_pad=False,
+ move_eos_to_beginning=False,
+ )
+ decoder_target_lengths = torch.tensor(
+ [x.size(0) for x in decoder_label], dtype=torch.long
+ )
+ prev_output_tokens = data_utils.collate_tokens(
+ decoder_label,
+ self.tgt_dict.pad(),
+ self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
+ left_pad=False,
+ move_eos_to_beginning=True,
+ )
+
+ if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
+ assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
+ prev_output_tokens[:, 0] = self.tgt_lang_idx
+
+ net_input = {
+ "source": collated_audios,
+ "padding_mask": padding_mask,
+ "prev_output_tokens": prev_output_tokens,
+ }
+ batch = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": net_input,
+ "decoder_target": decoder_target,
+ "decoder_target_lengths": decoder_target_lengths,
+ "dec_ntokens": dec_ntokens,
+ "lang_idx": self.tgt_lang_idx,
+ }
+ else:
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
+ batch = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": net_input,
+ }
+
+ if self.single_target:
+ batch["target_lengths"] = lengths_list[0]
+ batch["ntokens"] = ntokens_list[0]
+ batch["target"] = targets_list[0]
+ else:
+ batch["target_lengths_list"] = lengths_list
+ batch["ntokens_list"] = ntokens_list
+ batch["target_list"] = targets_list
+ return batch
+
+ def collater_audio(self, audios, audio_size, feat_dim=1):
+ collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim)
+ padding_mask = (
+ torch.BoolTensor(collated_audios.shape[0:2]).fill_(False)
+ # if self.pad_audio else None
+ )
+ audio_starts = [0 for _ in audios]
+ for i, audio in enumerate(audios):
+ audio = audio.view(-1, feat_dim)
+ diff = len(audio) - audio_size
+ if diff == 0:
+ collated_audios[i] = audio
+ elif diff < 0:
+ assert self.pad_audio
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff, feat_dim), 0.0)])
+ padding_mask[i, diff:] = True
+ else:
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
+ audio, audio_size
+ )
+ return collated_audios.squeeze(-1), padding_mask, audio_starts
+
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
+ assert label_rate > 0
+ s2f = label_rate / self.sample_rate
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
+ frm_size = int(round(audio_size * s2f))
+ if not self.pad_audio:
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
+ frm_size = min(frm_size, *rem_size)
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
+ logger.debug(f"audio_starts={audio_starts}")
+ logger.debug(f"frame_starts={frm_starts}")
+ logger.debug(f"frame_size={frm_size}")
+
+ lengths = torch.LongTensor([len(t) for t in targets])
+ ntokens = lengths.sum().item()
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
+ return targets, lengths, ntokens
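+        # Illustrative numbers (not from the source): with sample_rate=16000 and
+        # label_rate=50, s2f = 0.003125, so an audio crop starting at sample 32000
+        # with audio_size=80000 maps to frame_start=100 and frame_size=250.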
+
+ def collater_seq_label(self, targets, pad):
+ lengths = torch.LongTensor([len(t) for t in targets])
+ ntokens = lengths.sum().item()
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
+ return targets, lengths, ntokens
+
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
+ targets_list, lengths_list, ntokens_list = [], [], []
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
+ for targets, label_rate, pad in itr:
+ if label_rate == -1:
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
+ else:
+ targets, lengths, ntokens = self.collater_frm_label(
+ targets, audio_size, audio_starts, label_rate, pad
+ )
+ targets_list.append(targets)
+ lengths_list.append(lengths)
+ ntokens_list.append(ntokens)
+ return targets_list, lengths_list, ntokens_list
+
+ def num_tokens(self, index):
+ return self.size(index)
+
+ def size(self, index):
+ if self.pad_audio:
+ return self.wav_sizes[index]
+ return min(self.wav_sizes[index], self.max_sample_size)
+
+ @property
+ def sizes(self):
+ return np.array(self.wav_sizes)
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+
+ if self.shuffle:
+ if len(self.chunk_names) > 0:
+ logger.info(f"ordered indices for epoch {self.epoch}")
+ with data_utils.numpy_seed(self.epoch):
+ self.chunk_order = np.random.permutation(len(self.chunk_names))
+ chunk_count = 0
+ tmp_sizes = []
+ tmp_indices = []
+ indice = []
+ for i in self.chunk_order:
+ chunk_count += 1
+ start = self.chunk_indices[i]
+ end = self.chunk_indices[i+1] if i < len(self.chunk_names) - 1 else len(self)
+ size = list(self.sizes[start:end])
+ tmp_indices.extend(list(np.arange(start, end)))
+ tmp_sizes.extend(size)
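+                    # Every 10 chunks (and right after the first chunk), sort the
+                    # accumulated indices by length and emit them as one index group.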
+ if chunk_count % 10 == 0 or i == self.chunk_order[0]:
+ order = [np.random.permutation(len(tmp_indices))]
+ order.append(
+ np.minimum(
+ np.array(tmp_sizes),
+ self.max_sample_size,
+ )
+ )
+ sort_idx = np.lexsort(order)[::-1]
+ indice.append(np.array([tmp_indices[k] for k in sort_idx]))
+ tmp_indices = []
+                        tmp_sizes = []
+ return indice
+ else:
+ order = [np.random.permutation(len(self))]
+ order.append(
+ np.minimum(
+ np.array(self.sizes),
+ self.max_sample_size,
+ )
+ )
+ return np.lexsort(order)[::-1]
+ else:
+ return np.arange(len(self))
+
+ def postprocess(self, wav, cur_sample_rate):
+ if wav.dim() == 2:
+ wav = wav.mean(-1)
+ assert wav.dim() == 1, wav.dim()
+
+ if cur_sample_rate != self.sample_rate:
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
+
+ if self.normalize:
+ with torch.no_grad():
+ wav = F.layer_norm(wav, wav.shape)
+ return wav
diff --git a/SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py b/SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6494127d6bb5d993d557f9f534f7cca83b0f7fa1
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py
@@ -0,0 +1,669 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+import logging
+import numpy as np
+import torch
+import os
+import itertools
+
+from fairseq.data import FairseqDataset, data_utils
+from fairseq.data import (
+ AppendTokenDataset,
+ ConcatDataset,
+ PrependTokenDataset,
+ data_utils,
+ indexed_dataset,
+)
+
+logger = logging.getLogger(__name__)
+
+def load_langtriple_dataset(
+ data_path,
+ split,
+ src,
+ src_dict,
+ ref,
+ ref_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ left_pad_source,
+ left_pad_target,
+ max_source_positions,
+ max_target_positions,
+ prepend_bos=False,
+ load_alignments=False,
+ truncate_source=False,
+ append_source_id=False,
+ num_buckets=0,
+ shuffle=True,
+ pad_to_multiple=1,
+ prepend_bos_src=None,
+ lang_format="[{}]",
+):
+ assert not truncate_source
+ def split_exists(split, src, ref, tgt, lang, data_path):
+ filename = os.path.join(data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang))
+ return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
+
+ src_datasets = []
+ ref_datasets = []
+ tgt_datasets = []
+
+ for k in itertools.count():
+ split_k = split + (str(k) if k > 0 else "")
+
+ # infer langcode
+ if split_exists(split_k, src, ref, tgt, src, data_path):
+ prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt))
+ elif split_exists(split_k, tgt, ref, src, src, data_path):
+ prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src))
+ else:
+ if k > 0:
+ break
+ else:
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
+
+ src_dataset = data_utils.load_indexed_dataset(
+ prefix + src, src_dict, dataset_impl
+ )
+ src_datasets.append(src_dataset)
+
+ ref_dataset = data_utils.load_indexed_dataset(
+ prefix + ref, ref_dict, dataset_impl
+ )
+ ref_datasets.append(ref_dataset)
+
+ tgt_dataset = data_utils.load_indexed_dataset(
+ prefix + tgt, tgt_dict, dataset_impl
+ )
+ if tgt_dataset is not None:
+ tgt_datasets.append(tgt_dataset)
+
+ logger.info(
+ "{} {} {}-{}-{} {} examples".format(
+ data_path, split_k, src, ref, tgt, len(src_datasets[-1])
+ )
+ )
+
+ if not combine:
+ break
+
+ assert len(src_datasets) == len(ref_datasets)
+ assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
+
+ if len(src_datasets) == 1:
+ src_dataset = src_datasets[0]
+ ref_dataset = ref_datasets[0]
+ tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
+ else:
+ sample_ratios = [1] * len(src_datasets)
+ sample_ratios[0] = upsample_primary
+ src_dataset = ConcatDataset(src_datasets, sample_ratios)
+ ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
+ if len(tgt_datasets) > 0:
+ tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
+ else:
+ tgt_dataset = None
+
+ if prepend_bos:
+ assert hasattr(src_dict, "bos_index") and hasattr(ref_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
+ src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
+ ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
+ if tgt_dataset is not None:
+ tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
+ elif prepend_bos_src is not None:
+ logger.info(f"prepending src bos: {prepend_bos_src}")
+ src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
+ ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)
+
+ eos = None
+ if append_source_id:
+ src_dataset = AppendTokenDataset(
+ src_dataset, src_dict.index(lang_format.format(src))
+ )
+ ref_dataset = AppendTokenDataset(
+ ref_dataset, ref_dict.index(lang_format.format(ref))
+ )
+ if tgt_dataset is not None:
+ tgt_dataset = AppendTokenDataset(
+ tgt_dataset, tgt_dict.index(lang_format.format(tgt))
+ )
+ eos = tgt_dict.index(lang_format.format(tgt))
+
+ align_dataset = None
+ if load_alignments:
+ align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
+ if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
+ align_dataset = data_utils.load_indexed_dataset(
+ align_path, None, dataset_impl
+ )
+
+ tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
+ return LanguageTripleDataset(
+ src_dataset,
+ src_dataset.sizes,
+ src_dict,
+ ref_dataset,
+ ref_dataset.sizes,
+ ref_dict,
+ tgt_dataset,
+ tgt_dataset_sizes,
+ tgt_dict,
+ left_pad_source=left_pad_source,
+ left_pad_target=left_pad_target,
+ align_dataset=align_dataset,
+ eos=eos,
+ num_buckets=num_buckets,
+ shuffle=shuffle,
+ pad_to_multiple=pad_to_multiple,
+ )
+
+
+def collate(
+ samples,
+ pad_idx,
+ eos_idx,
+ left_pad_source=True,
+ left_pad_target=False,
+ input_feeding=True,
+ pad_to_length=None,
+ pad_to_multiple=1,
+):
+ if len(samples) == 0:
+ return {}
+
+ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
+ return data_utils.collate_tokens(
+ [s[key] for s in samples],
+ pad_idx,
+ None,
+ left_pad,
+ move_eos_to_beginning,
+ pad_to_length=pad_to_length,
+ pad_to_multiple=pad_to_multiple,
+ )
+
+ def check_alignment(alignment, src_len, tgt_len):
+ if alignment is None or len(alignment) == 0:
+ return False
+ if (
+ alignment[:, 0].max().item() >= src_len - 1
+ or alignment[:, 1].max().item() >= tgt_len - 1
+ ):
+ logger.warning("alignment size mismatch found, skipping alignment!")
+ return False
+ return True
+
+ def compute_alignment_weights(alignments):
+ """
+ Given a tensor of shape [:, 2] containing the source-target indices
+ corresponding to the alignments, a weight vector containing the
+ inverse frequency of each target index is computed.
+        E.g., if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
+        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
+        index 3 is repeated twice).
+ """
+ align_tgt = alignments[:, 1]
+ _, align_tgt_i, align_tgt_c = torch.unique(
+ align_tgt, return_inverse=True, return_counts=True
+ )
+ align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
+ return 1.0 / align_weights.float()
+
+ id = torch.LongTensor([s["id"] for s in samples])
+ src_tokens = merge(
+ "source",
+ left_pad=left_pad_source,
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+ )
+ ref_tokens = merge(
+ "reference",
+ left_pad=left_pad_source,
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+ )
+ # sort by descending source length
+ src_lengths = torch.LongTensor(
+ [s["source"].ne(pad_idx).long().sum() for s in samples]
+ )
+ ref_lengths = torch.LongTensor(
+ [s["reference"].ne(pad_idx).long().sum() for s in samples]
+ )
+ src_lengths, sort_order = src_lengths.sort(descending=True)
+ id = id.index_select(0, sort_order)
+ src_tokens = src_tokens.index_select(0, sort_order)
+ ref_lengths = ref_lengths.index_select(0, sort_order)
+ ref_tokens = ref_tokens.index_select(0, sort_order)
+
+ prev_output_tokens = None
+ target = None
+ if samples[0].get("target", None) is not None:
+ target = merge(
+ "target",
+ left_pad=left_pad_target,
+ pad_to_length=pad_to_length["target"]
+ if pad_to_length is not None
+ else None,
+ )
+ target = target.index_select(0, sort_order)
+ tgt_lengths = torch.LongTensor(
+ [s["target"].ne(pad_idx).long().sum() for s in samples]
+ ).index_select(0, sort_order)
+ ntokens = tgt_lengths.sum().item()
+
+ if samples[0].get("prev_output_tokens", None) is not None:
+ prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
+ elif input_feeding:
+ # we create a shifted version of targets for feeding the
+ # previous output token(s) into the next decoder step
+ prev_output_tokens = merge(
+ "target",
+ left_pad=left_pad_target,
+ move_eos_to_beginning=True,
+ pad_to_length=pad_to_length["target"]
+ if pad_to_length is not None
+ else None,
+ )
+ else:
+ ntokens = src_lengths.sum().item()
+
+ batch = {
+ "id": id,
+ "nsentences": len(samples),
+ "ntokens": ntokens,
+ "net_input": {
+ "src_tokens": src_tokens,
+ "src_lengths": src_lengths,
+ },
+ "target": target,
+ "ref_tokens": ref_tokens,
+ "ref_lengths": ref_lengths,
+ }
+ if prev_output_tokens is not None:
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
+ 0, sort_order
+ )
+
+ if samples[0].get("alignment", None) is not None:
+ bsz, tgt_sz = batch["target"].shape
+ src_sz = batch["net_input"]["src_tokens"].shape[1]
+
+ offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
+ offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
+ if left_pad_source:
+ offsets[:, 0] += src_sz - src_lengths
+ if left_pad_target:
+ offsets[:, 1] += tgt_sz - tgt_lengths
+
+ alignments = [
+ alignment + offset
+ for align_idx, offset, src_len, tgt_len in zip(
+ sort_order, offsets, src_lengths, tgt_lengths
+ )
+ for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
+ if check_alignment(alignment, src_len, tgt_len)
+ ]
+
+ if len(alignments) > 0:
+ alignments = torch.cat(alignments, dim=0)
+ align_weights = compute_alignment_weights(alignments)
+
+ batch["alignments"] = alignments
+ batch["align_weights"] = align_weights
+
+ if samples[0].get("constraints", None) is not None:
+ # Collate the packed constraints across the samples, padding to
+ # the length of the longest sample.
+ lens = [sample.get("constraints").size(0) for sample in samples]
+ max_len = max(lens)
+ constraints = torch.zeros((len(samples), max(lens))).long()
+ for i, sample in enumerate(samples):
+ constraints[i, 0 : lens[i]] = samples[i].get("constraints")
+ batch["constraints"] = constraints.index_select(0, sort_order)
+
+ return batch
+
+
+class LanguageTripleDataset(FairseqDataset):
+ """
+    A triple of torch.utils.data.Datasets (source, reference and, optionally, target).
+
+    Args:
+        src (torch.utils.data.Dataset): source dataset to wrap
+        src_sizes (List[int]): source sentence lengths
+        src_dict (~fairseq.data.Dictionary): source vocabulary
+        ref (torch.utils.data.Dataset): reference dataset to wrap
+        ref_sizes (List[int]): reference sentence lengths
+        ref_dict (~fairseq.data.Dictionary): reference vocabulary
+ tgt (torch.utils.data.Dataset, optional): target dataset to wrap
+ tgt_sizes (List[int], optional): target sentence lengths
+ tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
+ left_pad_source (bool, optional): pad source tensors on the left side
+ (default: True).
+ left_pad_target (bool, optional): pad target tensors on the left side
+ (default: False).
+ shuffle (bool, optional): shuffle dataset elements before batching
+ (default: True).
+ input_feeding (bool, optional): create a shifted version of the targets
+ to be passed into the model for teacher forcing (default: True).
+ remove_eos_from_source (bool, optional): if set, removes eos from end
+ of source if it's present (default: False).
+ append_eos_to_target (bool, optional): if set, appends eos to end of
+ target if it's absent (default: False).
+ align_dataset (torch.utils.data.Dataset, optional): dataset
+ containing alignments.
+ constraints (Tensor, optional): 2d tensor with a concatenated, zero-
+ delimited list of constraints for each sentence.
+ append_bos (bool, optional): if set, appends bos to the beginning of
+ source/target sentence.
+ num_buckets (int, optional): if set to a value greater than 0, then
+ batches will be bucketed into the given number of batch shapes.
+ src_lang_id (int, optional): source language ID, if set, the collated batch
+ will contain a field 'src_lang_id' in 'net_input' which indicates the
+ source language of the samples.
+ tgt_lang_id (int, optional): target language ID, if set, the collated batch
+ will contain a field 'tgt_lang_id' which indicates the target language
+ of the samples.
+ """
+
+ def __init__(
+ self,
+ src,
+ src_sizes,
+ src_dict,
+ ref,
+ ref_sizes,
+ ref_dict,
+ tgt=None,
+ tgt_sizes=None,
+ tgt_dict=None,
+ left_pad_source=True,
+ left_pad_target=False,
+ shuffle=True,
+ input_feeding=True,
+ remove_eos_from_source=False,
+ append_eos_to_target=False,
+ align_dataset=None,
+ constraints=None,
+ append_bos=False,
+ eos=None,
+ num_buckets=0,
+ src_lang_id=None,
+ tgt_lang_id=None,
+ pad_to_multiple=1,
+ ):
+ if tgt_dict is not None:
+ assert src_dict.pad() == tgt_dict.pad()
+ assert src_dict.eos() == tgt_dict.eos()
+ assert src_dict.unk() == tgt_dict.unk()
+ if tgt is not None:
+ assert len(src) == len(
+ tgt
+ ), "Source and target must contain the same number of examples"
+ assert len(src) == len(
+ ref
+ ), "Source and reference must contain the same number of examples"
+ self.src = src
+ self.ref = ref
+ self.tgt = tgt
+ self.src_sizes = np.array(src_sizes)
+ self.ref_sizes = np.array(ref_sizes)
+ self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
+ self.sizes = (
+ np.vstack((self.src_sizes, self.tgt_sizes)).T
+ if self.tgt_sizes is not None
+ else self.src_sizes
+ )
+ self.src_dict = src_dict
+ self.ref_dict = ref_dict
+ self.tgt_dict = tgt_dict
+ self.left_pad_source = left_pad_source
+ self.left_pad_target = left_pad_target
+ self.shuffle = shuffle
+ self.input_feeding = input_feeding
+ self.remove_eos_from_source = remove_eos_from_source
+ self.append_eos_to_target = append_eos_to_target
+ self.align_dataset = align_dataset
+ if self.align_dataset is not None:
+ assert (
+ self.tgt_sizes is not None
+ ), "Both source and target needed when alignments are provided"
+ self.constraints = constraints
+ self.append_bos = append_bos
+ self.eos = eos if eos is not None else src_dict.eos()
+ self.src_lang_id = src_lang_id
+ self.tgt_lang_id = tgt_lang_id
+ if num_buckets > 0:
+ from fairseq.data import BucketPadLengthDataset
+
+ self.src = BucketPadLengthDataset(
+ self.src,
+ sizes=self.src_sizes,
+ num_buckets=num_buckets,
+ pad_idx=self.src_dict.pad(),
+ left_pad=self.left_pad_source,
+ )
+ self.src_sizes = self.src.sizes
+ logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
+ self.ref = BucketPadLengthDataset(
+ self.ref,
+ sizes=self.ref_sizes,
+ num_buckets=num_buckets,
+ pad_idx=self.ref_dict.pad(),
+ left_pad=self.left_pad_source,
+ )
+ self.ref_sizes = self.ref.sizes
+ logger.info("bucketing reference lengths: {}".format(list(self.src.buckets)))
+ if self.tgt is not None:
+ self.tgt = BucketPadLengthDataset(
+ self.tgt,
+ sizes=self.tgt_sizes,
+ num_buckets=num_buckets,
+ pad_idx=self.tgt_dict.pad(),
+ left_pad=self.left_pad_target,
+ )
+ self.tgt_sizes = self.tgt.sizes
+ logger.info(
+ "bucketing target lengths: {}".format(list(self.tgt.buckets))
+ )
+
+ # determine bucket sizes using self.num_tokens, which will return
+ # the padded lengths (thanks to BucketPadLengthDataset)
+ num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
+ self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
+ self.buckets = [
+ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
+ ]
+ else:
+ self.buckets = None
+ self.pad_to_multiple = pad_to_multiple
+
+ def get_batch_shapes(self):
+ return self.buckets
+
+ def __getitem__(self, index):
+ tgt_item = self.tgt[index] if self.tgt is not None else None
+ src_item = self.src[index]
+ ref_item = self.ref[index]
+ # Append EOS to end of tgt sentence if it does not have an EOS and remove
+ # EOS from end of src sentence if it exists. This is useful when we use
+ # use existing datasets for opposite directions i.e., when we want to
+ # use tgt_dataset as src_dataset and vice versa
+ if self.append_eos_to_target:
+ eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
+ if self.tgt and self.tgt[index][-1] != eos:
+ tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
+
+ if self.append_bos:
+ bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
+ if self.tgt and self.tgt[index][0] != bos:
+ tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
+
+ bos = self.src_dict.bos()
+ if self.src[index][0] != bos:
+ src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
+ if self.ref[index][0] != bos:
+ ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])
+
+ if self.remove_eos_from_source:
+ eos = self.src_dict.eos()
+ if self.src[index][-1] == eos:
+ src_item = self.src[index][:-1]
+ if self.ref[index][-1] == eos:
+ ref_item = self.ref[index][:-1]
+
+ example = {
+ "id": index,
+ "source": src_item,
+ "reference": ref_item,
+ "target": tgt_item,
+ }
+ if self.align_dataset is not None:
+ example["alignment"] = self.align_dataset[index]
+ if self.constraints is not None:
+ example["constraints"] = self.constraints[index]
+ return example
+
+ def __len__(self):
+ return len(self.src)
+
+ def collater(self, samples, pad_to_length=None):
+ """Merge a list of samples to form a mini-batch.
+
+ Args:
+ samples (List[dict]): samples to collate
+ pad_to_length (dict, optional): a dictionary of
+ {'source': source_pad_to_length, 'target': target_pad_to_length}
+ to indicate the max length to pad to in source and target respectively.
+
+ Returns:
+ dict: a mini-batch with the following keys:
+
+ - `id` (LongTensor): example IDs in the original input order
+ - `ntokens` (int): total number of tokens in the batch
+ - `net_input` (dict): the input to the Model, containing keys:
+
+ - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
+ the source sentence of shape `(bsz, src_len)`. Padding will
+ appear on the left if *left_pad_source* is ``True``.
+ - `src_lengths` (LongTensor): 1D Tensor of the unpadded
+ lengths of each source sentence of shape `(bsz)`
+ - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
+ tokens in the target sentence, shifted right by one
+ position for teacher forcing, of shape `(bsz, tgt_len)`.
+ This key will not be present if *input_feeding* is
+ ``False``. Padding will appear on the left if
+ *left_pad_target* is ``True``.
+ - `src_lang_id` (LongTensor): a long Tensor which contains source
+ language IDs of each sample in the batch
+
+ - `target` (LongTensor): a padded 2D Tensor of tokens in the
+ target sentence of shape `(bsz, tgt_len)`. Padding will appear
+ on the left if *left_pad_target* is ``True``.
+ - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
+ IDs of each sample in the batch
+ """
+ res = collate(
+ samples,
+ pad_idx=self.src_dict.pad(),
+ eos_idx=self.eos,
+ left_pad_source=self.left_pad_source,
+ left_pad_target=self.left_pad_target,
+ input_feeding=self.input_feeding,
+ pad_to_length=pad_to_length,
+ pad_to_multiple=self.pad_to_multiple,
+ )
+ if self.src_lang_id is not None or self.tgt_lang_id is not None:
+ src_tokens = res["net_input"]["src_tokens"]
+ bsz = src_tokens.size(0)
+ if self.src_lang_id is not None:
+ res["net_input"]["src_lang_id"] = (
+ torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
+ )
+ if self.tgt_lang_id is not None:
+ res["tgt_lang_id"] = (
+ torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
+ )
+ return res
+
+ def num_tokens(self, index):
+ """Return the number of tokens in a sample. This value is used to
+ enforce ``--max-tokens`` during batching."""
+ return max(
+ self.src_sizes[index],
+ self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+ )
+
+ def num_tokens_vec(self, indices):
+ """Return the number of tokens for a set of positions defined by indices.
+ This value is used to enforce ``--max-tokens`` during batching."""
+ sizes = self.src_sizes[indices]
+ if self.tgt_sizes is not None:
+ sizes = np.maximum(sizes, self.tgt_sizes[indices])
+ return sizes
+
+ def size(self, index):
+ """Return an example's size as a float or tuple. This value is used when
+ filtering a dataset with ``--max-positions``."""
+ return (
+ self.src_sizes[index],
+ self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+ )
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ indices = np.random.permutation(len(self)).astype(np.int64)
+ else:
+ indices = np.arange(len(self), dtype=np.int64)
+ if self.buckets is None:
+ # sort by target length, then source length
+ if self.tgt_sizes is not None:
+ indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
+ return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
+ else:
+ # sort by bucketed_num_tokens, which is:
+ # max(padded_src_len, padded_tgt_len)
+ return indices[
+ np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
+ ]
+
+ @property
+ def supports_prefetch(self):
+ return getattr(self.src, "supports_prefetch", False) and (
+ getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
+ )
+
+ def prefetch(self, indices):
+ self.src.prefetch(indices)
+ if self.tgt is not None:
+ self.tgt.prefetch(indices)
+ if self.align_dataset is not None:
+ self.align_dataset.prefetch(indices)
+
+ def filter_indices_by_size(self, indices, max_sizes):
+ """Filter a list of sample indices. Remove those that are longer
+ than specified in max_sizes.
+
+ Args:
+ indices (np.array): original array of sample indices
+ max_sizes (int or list[int] or tuple[int]): max sample size,
+ can be defined separately for src and tgt (then list or tuple)
+
+ Returns:
+ np.array: filtered sample array
+ list: list of removed indices
+ """
+ return data_utils.filter_paired_dataset_indices_by_size(
+ self.src_sizes,
+ self.tgt_sizes,
+ indices,
+ max_sizes,
+ )
diff --git a/SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py b/SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfd204598e67d41a5688e16b0835f96fd40cf384
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py
@@ -0,0 +1,172 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+"""
+ Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
+    1. Add a custom lang_format argument to load_langpair_dataset
+    2. If truncate_source is set (default: False), use RandomCropDataset instead of TruncateDataset
+"""
+
+import itertools
+import logging
+import os
+
+from fairseq.data import (
+ AppendTokenDataset,
+ LanguagePairDataset,
+ PrependTokenDataset,
+ StripTokenDataset,
+ TruncateDataset,
+ RandomCropDataset,
+ data_utils,
+ indexed_dataset,
+)
+
+from speech2s.data.concat_dataset import ConcatDataset
+
+
+EVAL_BLEU_ORDER = 4
+
+
+logger = logging.getLogger(__name__)
+
+
+def load_langpair_dataset(
+ data_path,
+ split,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ left_pad_source,
+ left_pad_target,
+ max_source_positions,
+ max_target_positions,
+ prepend_bos=False,
+ load_alignments=False,
+ truncate_source=False,
+ append_source_id=False,
+ num_buckets=0,
+ shuffle=True,
+ pad_to_multiple=1,
+ prepend_bos_src=None,
+ lang_format="[{}]",
+ input_feeding=True,
+):
+ def split_exists(split, src, tgt, lang, data_path):
+ filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
+ return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
+
+ src_datasets = []
+ tgt_datasets = []
+
+ for k in itertools.count():
+ split_k = split + (str(k) if k > 0 else "")
+
+ # infer langcode
+ if split_exists(split_k, src, tgt, src, data_path):
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
+ elif split_exists(split_k, tgt, src, src, data_path):
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
+ else:
+ if k > 0:
+ break
+ else:
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
+
+ src_dataset = data_utils.load_indexed_dataset(
+ prefix + src, src_dict, dataset_impl
+ )
+ if truncate_source:
+ src_dataset = AppendTokenDataset(
+ RandomCropDataset(
+ StripTokenDataset(src_dataset, src_dict.eos()),
+ max_source_positions - 1,
+ ),
+ src_dict.eos(),
+ )
+ src_datasets.append(src_dataset)
+
+ tgt_dataset = data_utils.load_indexed_dataset(
+ prefix + tgt, tgt_dict, dataset_impl
+ )
+ if tgt_dataset is not None:
+ tgt_datasets.append(tgt_dataset)
+
+ logger.info(
+ "{} {} {}-{} {} examples".format(
+ data_path, split_k, src, tgt, len(src_datasets[-1])
+ )
+ )
+
+ if not combine:
+ break
+
+ assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
+
+ if len(src_datasets) == 1:
+ src_dataset = src_datasets[0]
+ tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
+ else:
+ sample_ratios = [1] * len(src_datasets)
+ sample_ratios[0] = upsample_primary
+ src_dataset = ConcatDataset(src_datasets, sample_ratios)
+ if len(tgt_datasets) > 0:
+ tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
+ else:
+ tgt_dataset = None
+
+ if prepend_bos:
+ assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
+ src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
+ if tgt_dataset is not None:
+ tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
+ elif prepend_bos_src is not None:
+ logger.info(f"prepending src bos: {prepend_bos_src}")
+ src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
+
+ eos = None
+ if append_source_id:
+ src_dataset = AppendTokenDataset(
+ src_dataset, src_dict.index(lang_format.format(src))
+ )
+ if tgt_dataset is not None:
+ tgt_dataset = AppendTokenDataset(
+ tgt_dataset, tgt_dict.index(lang_format.format(tgt))
+ )
+ eos = tgt_dict.index(lang_format.format(tgt))
+
+ align_dataset = None
+ if load_alignments:
+ align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
+ if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
+ align_dataset = data_utils.load_indexed_dataset(
+ align_path, None, dataset_impl
+ )
+
+ tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
+ return LanguagePairDataset(
+ src_dataset,
+ src_dataset.sizes,
+ src_dict,
+ tgt_dataset,
+ tgt_dataset_sizes,
+ tgt_dict,
+ left_pad_source=left_pad_source,
+ left_pad_target=left_pad_target,
+ align_dataset=align_dataset,
+ eos=eos,
+ num_buckets=num_buckets,
+ shuffle=shuffle,
+ pad_to_multiple=pad_to_multiple,
+ input_feeding=input_feeding,
+ )
diff --git a/SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py b/SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..19a6f8962757dec9b32430a98cd6e850d1f30d19
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py
@@ -0,0 +1,368 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+import logging
+from os import replace
+import time
+from collections import OrderedDict
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from fairseq.data import data_utils
+
+from fairseq.data import FairseqDataset
+
+logger = logging.getLogger(__name__)
+
+
+class MultiCorpusDataset(FairseqDataset):
+ """
+ see fairseq/fairseq/data/multi_corpus_dataset.__doc__
+
+ Args:
+        datasets: an OrderedDict of FairseqDataset instances.
+        max_positions: a Dict mapping each dataset key to its max size, used to
+            filter over-long samples when check_length is True
+        distribution: a List containing the probability of sampling an utterance
+            from the corresponding dataset
+        max_tokens_ratio: a List of per-dataset scaling factors applied to
+            max_tokens when batching
+        seed: random seed for sampling the datasets
+        sort_indices: if true, will sort the ordered indices by size
+        check_length: if true, filter indices by max_positions (slow on large datasets)
+ """
+
+ def __init__(
+ self,
+ datasets: Dict[str, FairseqDataset],
+ max_positions: Dict,
+ distribution: List[float],
+ max_tokens_ratio: List[float],
+ seed: int = 1234,
+ sort_indices: bool = False,
+ check_length: bool = False,
+ ):
+ super().__init__()
+ assert isinstance(datasets, OrderedDict)
+ assert len(datasets) == len(distribution)
+ # assert sum(distribution) == 1
+ self.datasets = datasets
+ self.distribution = distribution
+ self.max_tokens_ratio = max_tokens_ratio
+ self.seed = seed
+ self.sort_indices = sort_indices
+ self.max_positions = max_positions
+ self.check_length = check_length
+
+ # Avoid repeated conversions to list later
+ self.dataset_list = list(datasets.values())
+ self.total_num_instances = 0
+
+ # first_dataset = self.dataset_list[0]
+
+ self.num_instances_per_dataset = []
+ self.dataset_offsets = []
+ for i, dataset in enumerate(self.dataset_list):
+ assert isinstance(dataset, FairseqDataset)
+ # assert type(dataset) is type(first_dataset)
+ self.num_instances_per_dataset.append(
+ 0 if self.distribution[i] == 0 else len(dataset)
+ )
+ self.dataset_offsets.append(self.total_num_instances)
+ self.total_num_instances += self.num_instances_per_dataset[i]
+
+ def ordered_indices(self):
+ start = time.time()
+ with data_utils.numpy_seed(self.seed, self.epoch):
+ logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
+ sampled_indices = {}
+
+ # For each dataset i, sample self.distribution[i] * self.total_num_instances
+ for i, key in enumerate(self.datasets):
+ tp = time.time()
+ if self.distribution[i] == 0:
+ # skip dataset if sampling probability is 0
+ continue
+
+ if i < len(self.datasets) - 1:
+ num_instances = int(self.distribution[i] * self.total_num_instances)
+ high = self.dataset_offsets[i + 1]
+ else:
+ num_instances = int(self.distribution[i] * self.total_num_instances)
+ high = self.total_num_instances
+
+ logger.info(f"sampling {num_instances} from {key} dataset")
+
+ # First, add k copies of the dataset where k = num_instances // len(dataset).
+ # This ensures an equal distribution of the data points as much as possible.
+ # For the remaining entries randomly sample them
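+                # Illustrative: num_instances=250 with len(dataset)=100 gives
+                # num_copies=2 full passes plus 50 extra randomly drawn indices.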
+ dataset_size = len(self.datasets[key])
+ num_copies = num_instances // dataset_size
+ dataset_indices = np.random.permutation(high - self.dataset_offsets[i])[: num_instances - num_copies * dataset_size]
+ if num_copies > 0:
+ dataset_indices = np.concatenate(
+ (
+ np.repeat(
+ np.arange(high - self.dataset_offsets[i]), num_copies
+ ),
+ dataset_indices,
+ )
+ )
+                # filter by size; usually skipped by setting check_length=False,
+                # as it is very time-consuming on large datasets
+ if self.max_positions[key] is not None and self.check_length:
+ dataset_indices, ignored = self.datasets[key].filter_indices_by_size(
+ dataset_indices,
+ self.max_positions[key],
+ )
+ if len(ignored) > 0:
+ logger.warning(
+ (
+ "{:,} samples have invalid sizes and will be skipped, "
+ "max_positions={}, first few sample ids={}"
+ ).format(len(ignored), self.max_positions[key], ignored[:10])
+ )
+
+ if self.sort_indices:
+ logger.info(" - sampled indices took {}s".format(time.time() - tp))
+ tp = time.time()
+ dataset_indices = np.sort(dataset_indices)
+ ordered_indices = self.datasets[key].ordered_indices()
+ if isinstance(ordered_indices[0], np.ndarray): # chunked audio data
+ dataset_indices = [order_idx + self.dataset_offsets[i] for order_idx in ordered_indices]
+ assert self.dataset_offsets[i] == 0
+                        # TODO: for chunked audio data we currently assume len(dataset_indices) == len(dataset), i.e. no filtering is applied.
+ else:
+ dataset_indices = ordered_indices[dataset_indices] + self.dataset_offsets[i]
+ logger.info(" - ordered_indices took {}s".format(time.time() - tp))
+ else:
+ np.random.shuffle(dataset_indices)
+
+ sampled_indices[key] = dataset_indices
+
+ logger.info(
+ "multi_corpus_dataset ordered_indices took {}s".format(
+ time.time() - start
+ )
+ )
+ return sampled_indices
+
+ def _map_index(self, index: int):
+ """
+ If dataset A has length N and dataset B has length M
+ then index 1 maps to index 1 of dataset A, and index N + 1
+ maps to index 1 of B.
+ """
+ counter = 0
+ for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
+ if index < counter + num_instances:
+ return index - counter, key
+ counter += num_instances
+ raise ValueError(
+ "Invalid index: {}, max: {}".format(index, self.total_num_instances)
+ )
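+    # Illustrative: with datasets A (len 100) and B (len 50), index 30 maps to
+    # (30, "A") and index 120 maps to (20, "B").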
+
+ def __len__(self):
+ """
+ Length of this dataset is the sum of individual datasets
+ """
+ return self.total_num_instances
+
+ def __getitem__(self, index):
+ new_index, key = self._map_index(index)
+ try:
+ item = self.datasets[key][new_index]
+ item["full_id"] = index
+ return item
+ except Exception as e:
+ e.args = (f"Error from {key} dataset", *e.args)
+ raise
+
+ def collater(self, samples):
+ """
+ If we are doing batch sampling, then pick the right collater to use.
+
+ Otherwise we assume all collaters are the same.
+ """
+ if len(samples) == 0:
+ return None
+
+ samples_dict = {key: [] for key in self.datasets}
+ for s in samples:
+ _, key = self._map_index(s["full_id"])
+ samples_dict[key].append(s)
+
+ batch = {}
+ for key in samples_dict:
+ if len(samples_dict[key]) == 0:
+ continue
+ batch[key] = self.datasets[key].collater(samples_dict[key])
+
+ return batch
+
+
+ def num_tokens(self, index: int):
+ index, key = self._map_index(index)
+ return self.datasets[key].num_tokens(index)
+
+ def size(self, index: int):
+ index, key = self._map_index(index)
+ return self.datasets[key].size(index)
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ return False
+
+ def set_epoch(self, epoch, **unused):
+ super().set_epoch(epoch)
+ logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
+ for ds in self.dataset_list:
+ if hasattr(ds, "set_epoch"):
+ ds.set_epoch(epoch)
+ self.epoch = epoch
+
+ @property
+ def supports_prefetch(self):
+ return False
+
+ @property
+ def supports_fetch_outside_dataloader(self):
+ return all(
+ self.datasets[key].supports_fetch_outside_dataloader
+ for key in self.datasets
+ )
+
+
+ def batch_by_size(
+ self,
+ indices,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+ ):
+ dataset_indices = indices
+ batches_dict = {}
+ for n, key in enumerate(dataset_indices):
+ max_tokens_ratio = self.max_tokens_ratio[n]
+ if isinstance(dataset_indices[key][0], np.ndarray): # chunked audio data
+ cur_batches = self.datasets[key].batch_by_size(
+ dataset_indices[key],
+ round(max_tokens * max_tokens_ratio),
+ max_sentences,
+ required_batch_size_multiple,
+ )
+ logger.info(f"Created {sum([len(b) for b in cur_batches])} [{len(cur_batches)}] batches for dataset {key}")
+ else:
+ cur_batches = super().batch_by_size(
+ np.array(dataset_indices[key], dtype=np.int64),
+ round(max_tokens * max_tokens_ratio),
+ max_sentences,
+ required_batch_size_multiple,
+ )
+ logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
+ batches_dict[key] = cur_batches
+
+ return batches_dict
+
+
+ def get_batch_sampler(
+ self,
+ indices,
+ num_shards,
+ seed,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+ split_modality_batch=False,
+ ):
+
+ def batch_sampler(dataset, epoch):
+ start = time.time()
+ batches_dict = dataset.batch_by_size(
+ indices,
+ max_tokens=max_tokens,
+ max_sentences=max_sentences,
+ required_batch_size_multiple=required_batch_size_multiple,
+ )
+ logger.info(f"multi_corpus_dataset, batch_by_size took {time.time() - start}s")
+ start = time.time()
+ new_batches = []
+
+            ### shuffle within size buckets, then split batches into speech/text groups
+            shuffled_batches_list = []
+            speech_batches = []
+            ### speech batches are collected into a single group because different speech
+            # datasets (e.g. ltr or km labels) are concatenated rather than loaded in parallel.
+ for name, batches in batches_dict.items():
+ if name.startswith("speech"):
+ if isinstance(batches[0], list): # chunked audio data
+ batches = self.datasets[name].shuffle_batches(list(batches), seed + epoch)
+ shuffled_batches_list.append(batches)
+ else:
+ batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
+ batches = batches[: (len(batches) // num_shards) * num_shards]
+ if len(batches) == 0:
+ logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
+ else:
+ speech_batches += batches
+ else:
+ batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
+ batches = batches[: (len(batches) // num_shards) * num_shards]
+ if len(batches) == 0:
+ logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
+ else:
+ batches = shuffle_buckets(batches, seed=seed+epoch, inner_shuf=False)
+ shuffled_batches_list.append(batches)
+ if len(speech_batches) > 0:
+ speech_batches = shuffle_buckets(speech_batches, seed=seed+epoch, inner_shuf=False)
+ shuffled_batches_list.append(speech_batches)
+
+ ### create the final new_batches
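+            # With split_modality_batch=True each emitted batch contains a single
+            # modality (interleaved in groups of num_shards); otherwise the
+            # per-position batches of all modalities are concatenated into one.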
+ num_batch = min(len(batches) for batches in shuffled_batches_list)
+ if split_modality_batch:
+ for i in range(0, num_batch, num_shards):
+ for batches in shuffled_batches_list:
+ new_batches += batches[i: i + num_shards]
+ else:
+ for i in range(num_batch):
+ new_batches.append(np.concatenate([batches[i] for batches in shuffled_batches_list]))
+
+ logger.info(f"multi_corpus_dataset sample {len(new_batches)} batches, took {time.time() - start}s")
+ return new_batches
+
+ def inner_bucket_shuffle(batches, seed, bucket_size=10, thr=0):
+ """we assert batches is sorted form long to short.
+ shuffle samples in a buctet(e.g. 10 batches).
+ batches: a list of numpy array"""
+ num_batch = len(batches)
+ new_batches = []
+ num_buckets = len(batches) // bucket_size
+ i = 0
+ while i < num_batch:
+ if (i < bucket_size * thr or
+ i >= bucket_size * (num_buckets - thr)
+ ):
+ new_batches.append(batches[i])
+ i += 1
+ else:
+ group = np.concatenate(batches[i: i+bucket_size])
+ with data_utils.numpy_seed(seed):
+ np.random.shuffle(group)
+ new_batches += np.array_split(group, bucket_size)
+ i += bucket_size
+ assert all([len(batch) > 0 for batch in new_batches])
+ return new_batches
+
+ def shuffle_buckets(batches, seed, inner_shuf=True):
+ if inner_shuf:
+ batches = inner_bucket_shuffle(batches, seed, num_shards*10)
+ batches = [batches[i: i + num_shards] for i in range(0, len(batches)-num_shards+1, num_shards)]
+ assert len(batches[-1]) == num_shards
+ new_batches = []
+ with data_utils.numpy_seed(seed):
+ np.random.shuffle(batches)
+ for group in batches:
+ new_batches += group
+ return new_batches
+
+ return batch_sampler
diff --git a/SpeechT5/Speech2S/speech2s/models/__init__.py b/SpeechT5/Speech2S/speech2s/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SpeechT5/Speech2S/speech2s/models/speechut.py b/SpeechT5/Speech2S/speech2s/models/speechut.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb668286c1c1c420d0c7d7b9e74a3bca17c6c871
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/models/speechut.py
@@ -0,0 +1,785 @@
+# ----------------------------------------------------------------------------
+# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
+# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import logging
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import utils, checkpoint_utils
+from fairseq.data.data_utils import compute_mask_indices
+from fairseq.data.dictionary import Dictionary
+from fairseq.dataclass import ChoiceEnum
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.transformer import Embedding
+from fairseq.file_io import PathManager
+from torch import Tensor
+from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
+from fairseq.modules import GradMultiply, LayerNorm
+from fairseq.tasks.hubert_pretraining import (
+ HubertPretrainingConfig,
+ HubertPretrainingTask,
+)
+from fairseq.models.hubert import HubertConfig
+from fairseq.models.transformer import TransformerConfig
+from speech2s.modules import TransformerEncoder
+from speech2s.modules import TransformerEncoderBase
+from speech2s.modules import TransformerDecoderBaseScriptable
+
+logger = logging.getLogger(__name__)
+
+EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
+MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
+
+
+@dataclass
+class SpeechutConfig(HubertConfig):
+ use_rel_pos_enc: bool = field(
+ default=False,
+ metadata={"help": "whether to use relative positional encoding"},
+ )
+ scaling_for_att: float = field(
+ default=1.0,
+ metadata={"help": "scaling for attention weights to prevent overflow issue (for large model)"},
+ )
+
+ # unit encoder-decoder
+ text_transformer: TransformerConfig = TransformerConfig()
+ reset_decoder_embedding_config: bool = field(
+ default=False,
+ metadata={"help": "reset the no_scale_embedding/layernorm_embedding to default for the decoder"},
+ )
+ add_unit_encoder: bool = field(
+ default=False,
+ metadata={"help": "add unit encoder"},
+ )
+ add_decoder: bool = field(
+ default=True,
+ metadata={"help": "add decoder"},
+ )
+ add_text_ctc: bool = field(
+ default=False,
+ metadata={"help": "add_text_ctc head"},
+ )
+ text_ctc_conv_kernel: int = field(
+ default=2,
+ metadata={"help": "text_ctc_conv kernel size"},
+ )
+ mask_u2t: bool = field(
+ default=True,
+ metadata={"help": "mask the unit input in unit-to-text task"},
+ )
+
+ # embedding mixing
+ mix_with_unit: bool = field(
+ default=True,
+ metadata={"help": "mix with the unit embeddings"},
+ )
+ use_pred_unit: bool = field(
+ default=False,
+ metadata={"help": "use the embeddings of predicted units"},
+ )
+ l2_embedding: bool = field(
+ default=False,
+ metadata={"help": "compute l2 loss between unit embedding and unit hidden state"},
+ )
+
+ # Finetune related
+ encoder_dict_size: int = field(
+ default=-1,
+ metadata={"help": "text encoder dictionary dimension"},
+ )
+
+ decoder_dict_size: int = field(
+ default=-1,
+ metadata={"help": "decoder dictionary dimension"},
+ )
+
+
+@register_model("speechut", dataclass=SpeechutConfig)
+class SpeechutModel(BaseFairseqModel):
+ def __init__(
+ self,
+ cfg: SpeechutConfig,
+ task_cfg: HubertPretrainingConfig,
+ dictionaries: List[Dictionary],
+ unit_dictionary: Dictionary = None,
+ text_tgt_dictionary: Dictionary = None,
+ ) -> None:
+ super().__init__()
+ logger.info(f"SpeechutModel Config: {cfg}")
+
+ feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
+
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+ self.logit_temp = cfg.logit_temp
+ self.skip_masked = cfg.skip_masked
+ self.skip_nomask = cfg.skip_nomask
+
+ final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+
+ self.mask_emb = nn.Parameter(
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+ )
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ self.target_glu = None
+ if cfg.target_glu:
+ self.target_glu = nn.Sequential(
+ nn.Linear(final_dim, final_dim * 2), nn.GLU()
+ )
+
+ self.final_dim = final_dim
+        assert len(dictionaries) <= 2, f"Only support <= 2 kinds of targets, got {len(dictionaries)} dictionaries"
+ if len(dictionaries) == 1:
+ dictionaries = [dictionaries[0], dictionaries[0]]
+ self.num_classes = [len(d) for d in dictionaries]
+
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+ self.code_encoder_proj = nn.Linear(cfg.text_transformer.encoder.embed_dim, self.num_classes[-1])
+ self.final_proj_list = [self.final_proj, self.code_encoder_proj]
+
+ self.label_embs_concat = nn.Parameter(torch.FloatTensor(self.num_classes[0], final_dim))
+ self.label_embs_list = [self.label_embs_concat]
+ for p in self.label_embs_list:
+ nn.init.uniform_(p)
+
+ ### build unit encoder:
+ self.mask_u2t = cfg.mask_u2t
+ self.add_text_ctc = cfg.add_text_ctc
+ self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
+ self.padding_idx = unit_dictionary.pad()
+ self.unit_mask_idx = unit_dictionary.index("")
+
+ self.add_unit_encoder = cfg.add_unit_encoder
+ self.mix_with_unit = cfg.mix_with_unit
+ self.use_pred_unit = cfg.use_pred_unit
+ self.l2_embedding = cfg.l2_embedding
+ if self.add_unit_encoder:
+ assert len(unit_dictionary) == self.num_classes[0], f"unit_dictionary: {len(unit_dictionary)}, self.num_classes[0]: {self.num_classes[0]}"
+            ### build the unit pre-net; optionally share it with the hubert label_embs (default: False)
+ self.unit_embed_tokens = self.build_embedding(
+ unit_dictionary,
+ cfg.text_transformer.encoder.embed_dim,
+ )
+ if self.final_dim == cfg.text_transformer.encoder.embed_dim:
+ logger.info("Share label_embs[0] with unit_embed_tokens ...")
+ nn.init.uniform_(self.unit_embed_tokens.weight)
+ self.label_embs_list[0] = self.unit_embed_tokens.weight
+
+ ### build unit encoder
+ self.unit_encoder = TransformerEncoderBase(
+ cfg.text_transformer,
+ unit_dictionary,
+ self.unit_embed_tokens,
+ use_rel_pos_enc=cfg.use_rel_pos_enc,
+ scaling_for_att=cfg.scaling_for_att,
+ )
+
+ ### build text ctc head
+ if self.add_text_ctc:
+ conv = nn.Conv1d(
+ cfg.text_transformer.encoder.embed_dim, cfg.text_transformer.encoder.embed_dim,
+ self.text_ctc_conv_kernel,
+ stride=self.text_ctc_conv_kernel // 2,
+ bias=False,
+ padding=self.text_ctc_conv_kernel // 2,
+ )
+ nn.init.kaiming_normal_(conv.weight)
+ self.unit_encoder_ctc_head = nn.Sequential(
+ Rotate3D(),
+ conv,
+ nn.Dropout(p=0.1),
+ nn.Sequential(
+ Rotate3D(),
+ Rotate3D(),
+ LayerNorm(cfg.text_transformer.encoder.embed_dim),
+ ),
+ nn.GELU(),
+ nn.Linear(cfg.text_transformer.encoder.embed_dim, len(text_tgt_dictionary)),
+ )
+
+ ### build unit2text decoder, not available for now
+ self.add_decoder = cfg.add_decoder
+ self.text_transformer_cfg = cfg.text_transformer
+ if self.add_decoder:
+ # To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size or bpe code dict size
+ dec_dictionary = self.cutting_dictionary(text_tgt_dictionary, cfg.decoder_dict_size)
+ decoder_embed_tokens = self.build_embedding(
+ dec_dictionary, cfg.text_transformer.decoder.embed_dim
+ )
+ if cfg.reset_decoder_embedding_config:
+ cfg.text_transformer.no_scale_embedding = False
+ cfg.text_transformer.layernorm_embedding = False
+ cfg.text_transformer.no_token_positional_embeddings = False
+ self.decoder = TransformerDecoderBaseScriptable(cfg.text_transformer, dec_dictionary, decoder_embed_tokens, use_rel_pos_enc=cfg.use_rel_pos_enc)
+
+
+ def cutting_dictionary(self, dictionary, dict_size):
+ if dictionary is None or dict_size <= 0:
+ return dictionary
+ else:
+ import copy
+ cut_dictionary = copy.deepcopy(dictionary)
+ if dict_size > len(cut_dictionary):
+ for i in range(dict_size - len(cut_dictionary)):
+ cut_dictionary.symbols.append(f'_{i}_')
+ else:
+ cut_dictionary.symbols = cut_dictionary.symbols[:dict_size]
+ return cut_dictionary
+
+ def build_embedding(self, dictionary, embed_dim):
+ num_embeddings = len(dictionary)
+ padding_idx = dictionary.pad()
+ return Embedding(num_embeddings, embed_dim, padding_idx)
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ @classmethod
+ def build_model(cls, cfg: SpeechutConfig, task: HubertPretrainingTask):
+ """Build a new model instance."""
+ unit_dictionary = getattr(task, "text_src_dictionary", None)
+ text_tgt_dictionary = getattr(task, "text_dictionary", None)
+ model = SpeechutModel(cfg, task.cfg, task.dictionaries, unit_dictionary, text_tgt_dictionary)
+ return model
+
+ def apply_mask(self, x, padding_mask, target_list):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x[mask_indices] = self.mask_emb
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def forward_features(self, source: torch.Tensor) -> torch.Tensor:
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+ return features
+
+ def forward_targets(
+ self,
+ features: torch.Tensor,
+ target_list: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Trim features to ensure labels exist and then get aligned labels
+ feat_tsz = features.size(2)
+ targ_tsz = min([t.size(1) for t in target_list])
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+ features = features[..., :feat_tsz]
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
+ target_inds += np.random.choice(int(self.feat2tar_ratio))
+ target_list = [t[:, target_inds.long()] for t in target_list]
+ return features, target_list
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def get_normalized_probs(
+ self,
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+ log_probs: bool,
+ sample: Optional[Dict[str, Tensor]] = None,
+ ):
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+ lprobs.batch_first = True
+ return lprobs
+
+ def downsample_ctc_padding_mask(self, padding_mask):
+ """
+ padding_mask: (B, T)
+ """
+ stride = self.text_ctc_conv_kernel // 2
+ return padding_mask[:, ::stride]
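+
+    # Note (illustrative): the CTC head's Conv1d above the unit encoder uses
+    # stride = text_ctc_conv_kernel // 2, so the padding mask is subsampled by
+    # the same stride here; e.g. with text_ctc_conv_kernel = 4 a (B, T) mask
+    # becomes roughly (B, T // 2), matching the convolved feature length.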
+
+ def compute_pred(self, proj_x, label_embs):
+ if self.target_glu:
+ label_embs = self.target_glu(label_embs)
+ x = F.normalize(proj_x.float(), dim=-1) # (S, D)
+ label_embs = F.normalize(label_embs.float(), dim=-1) # (C, D)
+ logits = torch.matmul(x, label_embs.T).type_as(proj_x) # (S, C)
+ logits /= self.logit_temp
+ return logits
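+
+    # A minimal worked example of compute_pred (illustrative only; `model` is an
+    # instantiated SpeechutModel): rows of proj_x and label_embs are L2-normalized,
+    # so logits[s, c] is the cosine similarity between frame s and unit c,
+    # scaled by 1 / logit_temp.
+    #
+    #   proj_x = torch.randn(100, 256)      # S = 100 masked frames, D = 256
+    #   label_embs = torch.randn(500, 256)  # C = 500 hidden units
+    #   logits = model.compute_pred(proj_x, label_embs)  # -> (100, 500)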
+
+ def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = proj(x[masked_indices])
+ logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
+ else:
+ logit_m_list = [None]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = proj(x[nomask_indices])
+ logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
+ else:
+ logit_u_list = [None]
+
+ return logit_m_list, logit_u_list
+
+ def compute_ce_logits(self, x, target, proj, padding_mask, mask_indices):
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ logit_m_list = [(proj(x[masked_indices]), target[masked_indices])]
+ else:
+ logit_m_list = [None]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ logit_u_list = [(proj(x[nomask_indices]), target[nomask_indices])]
+ else:
+ logit_u_list = [None]
+
+ return logit_m_list, logit_u_list
+
+ def convert_embeddings(self,
+ x,
+ padding_mask,
+ target=None,
+ mask_indices=None,
+ mix_with_unit=False,
+ use_pred_unit=False,
+ l2_embedding=False,
+ remask=False
+ ):
+ """
+ 1. Mix with units if needed (default: True)
+ 2. Prepare for unit_encoder inputs
+ Inputs:
+ x, (B, T, D)
+ Return:
+ src_tokens, (B, T)
+ soft_embeddings, (B, T, D)
+ l2_loss, a loss
+ """
+ soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
+ if padding_mask is None:
+ padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
+ if use_pred_unit:
+ src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
+ src_tokens[padding_mask] = self.padding_idx
+ elif target is not None:
+ src_tokens = target
+ else:
+ src_tokens = padding_mask.long()
+
+        if l2_embedding or mix_with_unit:
+ unit_embeddings = self.unit_embed_tokens(src_tokens) # (B, T, D)
+
+ l2_loss = 0
+ if l2_embedding:
+ if mask_indices is not None:
+ l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
+ scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
+ else:
+ l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
+ scale = unit_embeddings.float().pow(2).sum(dim=-1)
+ l2_loss = (l2_loss / scale).mean()
+
+ if mix_with_unit:
+ B, T, D = x.shape
+ selected_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob / 2,
+ self.mask_length // 2,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ selected_indices = torch.from_numpy(selected_indices).to(x.device)
+ if mask_indices is not None:
+ if remask:
+ remask_indices = torch.logical_and(selected_indices, mask_indices)
+ soft_embeddings[remask_indices] = self.mask_emb
+ swap_indices = torch.logical_and(selected_indices, ~mask_indices)
+ else:
+ swap_indices = selected_indices
+ soft_embeddings[swap_indices] = unit_embeddings[swap_indices]
+
+ soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
+ return src_tokens, soft_embeddings, l2_loss
+
+ def forward(
+ self,
+ source: torch.Tensor = None,
+ src_tokens: torch.Tensor = None,
+ src_lengths: torch.Tensor = None,
+ prev_output_tokens: torch.Tensor = None,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+ assert source is not None or src_tokens is not None
+ if source is not None:
+ return self.forward_speech(
+ source=source,
+ target_list=target_list,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=features_only,
+ output_layer=output_layer,
+ )
+ else:
+ return self.forward_text(
+ src_tokens=src_tokens,
+ src_lengths=src_lengths,
+ prev_output_tokens=prev_output_tokens,
+ mask=self.mask_u2t,
+ features_only=features_only,
+ output_layer=output_layer,
+ )
+
+ def forward_speech(
+ self,
+ source: torch.Tensor = None,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+ """output layer is 1-based"""
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+
+ features_pen = features.float().pow(2).mean()
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ unmasked_features = features.clone()
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x = features
+ mask_indices = None
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x, _ = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1,
+ )
+
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features}
+
+ logit_m_list, logit_u_list = self.compute_hubert_logits(
+ x,
+ target_list[0],
+ self.final_proj_list[0],
+ self.label_embs_list[0],
+ padding_mask,
+ mask_indices,
+ )
+
+ result = {
+ "logit_m_list": logit_m_list,
+ "logit_u_list": logit_u_list,
+ "padding_mask": padding_mask,
+ "features_pen": features_pen,
+ }
+
+ if self.add_unit_encoder:
+ src_tokens, x_emb, l2_loss = self.convert_embeddings(
+ x,
+ padding_mask, target_list[0],
+ mask_indices=mask_indices,
+ mix_with_unit=self.mix_with_unit,
+ use_pred_unit=self.use_pred_unit,
+ l2_embedding=self.l2_embedding,
+ )
+ encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)
+
+ result['encoder_out'] = encoder_out['encoder_out'] # [(T, B, D)]
+ result['encoder_padding_mask'] = encoder_out['encoder_padding_mask'] # [(B, T)]
+ if self.l2_embedding:
+ result['embedding_l2_loss'] = l2_loss
+
+ code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
+ encoder_out['encoder_out'][0].transpose(0, 1), # -> (B, T, C)
+ target_list[-1],
+ self.final_proj_list[1],
+ padding_mask,
+ mask_indices,
+ )
+ result['logit_m_list'] += code_logit_m_list
+ result['logit_u_list'] += code_logit_u_list
+ return result
+
+ def forward_text(
+ self,
+ src_tokens: torch.Tensor = None,
+ src_lengths: torch.Tensor = None,
+ prev_output_tokens: torch.Tensor = None,
+ target_list: Optional[List[torch.Tensor]] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+        assert self.add_unit_encoder, "Cannot forward the unit-text branch without unit_encoder!"
+
+ padding_mask = src_tokens == self.padding_idx
+ unit_embeddings = self.unit_embed_tokens(src_tokens)
+ if mask:
+ unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])
+
+ encoder_out = self.unit_encoder(
+ src_tokens,
+ token_embeddings=unit_embeddings,
+ return_all_hiddens=output_layer is not None,
+ )
+
+ result = {}
+ result["encoder_out"] = encoder_out["encoder_out"]
+ result["encoder_states"] = encoder_out["encoder_states"]
+ result["padding_mask"] = padding_mask
+
+ if self.add_text_ctc:
+ result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
+ result["encoder_padding_mask"] = [
+ self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
+ ]
+
+ if features_only:
+ return result
+ if self.add_decoder:
+ assert prev_output_tokens is not None
+ decoder_out = self.decoder(
+ prev_output_tokens=prev_output_tokens, encoder_out=encoder_out,
+ )
+ result['decoder_out'] = decoder_out
+ return result
+
+ def forward_mum(self, src_tokens, target, mask=True):
+ target_list = [target]
+ padding_mask = src_tokens.eq(self.unit_encoder.padding_idx)
+ unit_embeddings = self.unit_embed_tokens(src_tokens)
+ if mask:
+ unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, target_list)
+ else:
+            ### If the mask has already been applied to src_tokens, then target_list should contain many padding_idx entries
+ mask_indices = target_list[-1] != self.padding_idx
+ unit_embeddings[mask_indices] = self.mask_emb
+
+ encoder_out = self.unit_encoder(
+ src_tokens,
+ token_embeddings=unit_embeddings,
+ )
+ code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
+ encoder_out["encoder_out"][0].transpose(0, 1),
+ target_list[-1],
+ self.final_proj_list[1],
+ padding_mask,
+ mask_indices,
+ )
+ result = {}
+ result["logit_m_list"] = code_logit_m_list
+ result["logit_u_list"] = code_logit_u_list
+ result["padding_mask"] = padding_mask
+ return result
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Extract encoder features for only speech input"""
+ res = self.forward(
+ source,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ output_layer=output_layer,
+ )
+ x = res["x"] # B x T x D
+ padding_mask = res["padding_mask"]
+
+ if self.add_unit_encoder:
+ src_tokens, x, _ = self.convert_embeddings(
+ x,
+ padding_mask,
+ mix_with_unit=False,
+ use_pred_unit=False,
+ )
+ encoder_out = self.unit_encoder(
+ src_tokens,
+ token_embeddings=x,
+ return_all_hiddens=output_layer is not None
+ )
+ res["x"] = encoder_out['encoder_out'][0].transpose(0, 1) # (B, T, D)
+
+ feature = res["features"] if ret_conv else res["x"]
+ if output_layer is not None:
+ feature = encoder_out['encoder_states']
+
+ return feature, padding_mask
+
+ def get_logits(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ logits_list = [x[0].float() for x in logits_list if x is not None]
+ return logits_list
+
+ def get_targets(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ targets_list = [x[1].long() for x in logits_list if x is not None]
+ return targets_list
+
+ def get_extra_losses(self, net_output):
+ extra_losses = []
+ names = []
+
+ if "features_pen" in net_output:
+ extra_losses.append(net_output["features_pen"])
+ names.append("features_pen")
+
+ if "embedding_l2_loss" in net_output:
+ extra_losses.append(net_output["embedding_l2_loss"])
+ names.append("embedding_l2_loss")
+
+ return extra_losses, names
+
+ def remove_pretraining_modules(self, step2=False):
+ self.target_glu = None
+
+ def load_checkpoint(self, checkpoint: str):
+ if not PathManager.exists(checkpoint):
+ raise IOError("Model file not found: {}".format(checkpoint))
+ state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
+ return state
+
+class Rotate3D(nn.Module):
+ """
+ (T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
+ """
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x.permute(1, 2, 0)
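+
+# Example (illustrative): Rotate3D permutes the axes one step per call, so the
+# CTC head's (Rotate3D, Conv1d, ..., Rotate3D, Rotate3D) pattern moves the time
+# axis to the last position for Conv1d and then restores (T, B, D):
+#
+#   x = torch.randn(50, 4, 768)   # (T, B, D)
+#   r = Rotate3D()
+#   r(x).shape                    # (4, 768, 50) == (B, D, T)
+#   r(r(r(x))).shape              # (50, 4, 768) back to (T, B, D)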
diff --git a/SpeechT5/Speech2S/speech2s/models/speechut_asr.py b/SpeechT5/Speech2S/speech2s/models/speechut_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9ec9d8488b4f7e552804d355de000c80fb35b78
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/models/speechut_asr.py
@@ -0,0 +1,165 @@
+# ----------------------------------------------------------------------------
+# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
+# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import contextlib
+import torch
+from dataclasses import dataclass, field
+from fairseq import utils
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.fairseq_encoder import FairseqEncoder
+from fairseq.models.hubert import HubertAsrConfig, HubertEncoder
+from fairseq.tasks import FairseqTask
+
+@dataclass
+class SpeechUTASRConfig(HubertAsrConfig):
+ add_decoder: bool = field(
+ default=True,
+ metadata={"help": "add decoder for fine-tune"},
+ )
+
+@register_model("speechut_asr", dataclass=SpeechUTASRConfig)
+class SpeechUTASR(BaseFairseqModel):
+ """
+    An encoder-CTC-decoder model if cfg.add_decoder is True, otherwise an encoder-CTC model.
+ """
+ def __init__(self, cfg: SpeechUTASRConfig, encoder: FairseqEncoder):
+ super().__init__()
+ self.cfg = cfg
+ self.encoder = encoder
+ if not cfg.add_decoder:
+ self.encoder.w2v_model.decoder = None
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ @classmethod
+ def build_model(cls, cfg: SpeechUTASRConfig, task: FairseqTask):
+ """Build a new model instance."""
+ encoder = SpeechUTEncoder(cfg, task)
+ return cls(cfg, encoder)
+
+ def forward(self, source, padding_mask, prev_output_tokens, **kwargs):
+ encoder_out = self.encoder(source, padding_mask, **kwargs)
+
+ x = self.encoder.final_dropout(encoder_out['encoder_out'][0]) # (T, B, C)
+ if self.encoder.proj:
+ x = self.encoder.proj(x)
+ if self.encoder.conv_ctc_proj:
+ padding_mask = self.encoder.w2v_model.downsample_ctc_padding_mask(encoder_out["encoder_padding_mask"][0])
+ else:
+ padding_mask = encoder_out["encoder_padding_mask"]
+
+ decoder_out = self.decoder(
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
+ ) if self.cfg.add_decoder else None
+
+ return {
+ "encoder_out_ctc": x, # (T, B, C), for CTC loss
+ "padding_mask": padding_mask, # (B, T), for CTC loss
+ "decoder_out": decoder_out, # for ED loss
+ }
+
+ def forward_decoder(self, prev_output_tokens, **kwargs):
+ return self.decoder(prev_output_tokens, **kwargs)
+
+ def get_logits(self, net_output):
+ """For CTC decoding"""
+ logits = net_output["encoder_out"]
+ padding = net_output["encoder_padding_mask"]
+ if padding is not None and padding.any():
+ padding = padding.T
+ logits[padding][..., 0] = 0
+ logits[padding][..., 1:] = float("-inf")
+
+ return logits
+
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
+ """For 1) computing CTC loss, 2) decoder decoding."""
+
+ if "encoder_out_ctc" in net_output:
+ logits = net_output["encoder_out_ctc"]
+ else:
+ return self.decoder.get_normalized_probs(net_output, log_probs, sample)
+
+ if isinstance(logits, list):
+ logits = logits[0]
+
+ if log_probs:
+ return utils.log_softmax(logits.float(), dim=-1)
+ else:
+ return utils.softmax(logits.float(), dim=-1)
+
+ @property
+ def decoder(self):
+ return self.encoder.w2v_model.decoder
+
+
+class SpeechUTEncoder(HubertEncoder):
+ """
+ Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
+    1. Make it compatible with an encoder-decoder model.
+ """
+ def __init__(self, cfg: HubertAsrConfig, task):
+ super().__init__(cfg, task)
+
+ if (task.target_dictionary is not None) and (
+ hasattr(self.w2v_model, "unit_encoder_ctc_head")
+ ):
+ self.proj = self.w2v_model.unit_encoder_ctc_head
+ self.conv_ctc_proj = True
+ else:
+ self.conv_ctc_proj = False
+
+ def forward(self, source, padding_mask, tbc=True, **kwargs):
+ w2v_args = {
+ "source": source,
+ "padding_mask": padding_mask,
+ "mask": self.apply_mask and self.training,
+ }
+ ft = self.freeze_finetune_updates <= self.num_updates
+ with torch.no_grad() if not ft else contextlib.ExitStack():
+ x, padding_mask = self.w2v_model.extract_features(**w2v_args)
+ if tbc:
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ return {
+ "encoder_out": [x], # T x B x C
+ "encoder_padding_mask": [padding_mask], # B x T
+ }
+
+ def forward_torchscript(self, net_input):
+ """A TorchScript-compatible version of forward.
+
+ Forward the encoder out.
+ """
+ x, padding_mask = self.w2v_model.extract_features(**net_input, mask=False)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ encoder_out = {
+ "encoder_out" : [x],
+ "encoder_padding_mask" : [padding_mask],
+ }
+ if self.proj:
+ x = self.proj(x)
+ encoder_out["encoder_out_ctc"] = x
+
+ return encoder_out
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ if encoder_out["encoder_out"] is not None:
+ encoder_out["encoder_out"] = [
+ x.index_select(1, new_order) for x in encoder_out["encoder_out"]
+ ]
+ if encoder_out["encoder_padding_mask"] is not None:
+ encoder_out["encoder_padding_mask"] = [
+ x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
+ ]
+ return encoder_out
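+
+# Illustrative sketch of consuming SpeechUTASR.forward() output, assuming a joint
+# CTC + cross-entropy criterion (variable names are hypothetical):
+#
+#   out = model(source, padding_mask, prev_output_tokens)
+#   ctc_lprobs = model.get_normalized_probs(out, log_probs=True)  # from encoder_out_ctc, (T, B, V)
+#   dec_out = out["decoder_out"]   # standard fairseq decoder output, for the CE loss
+#   ctc_pad = out["padding_mask"]  # frame-level padding mask aligned with the CTC logits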
diff --git a/SpeechT5/Speech2S/speech2s/models/speechut_st.py b/SpeechT5/Speech2S/speech2s/models/speechut_st.py
new file mode 100644
index 0000000000000000000000000000000000000000..6faaccfc89748a2692bd1eaec200588449d10423
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/models/speechut_st.py
@@ -0,0 +1,221 @@
+# ----------------------------------------------------------------------------
+# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
+# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import logging
+import contextlib
+import torch
+import torch.nn as nn
+from argparse import Namespace
+from dataclasses import dataclass
+from typing import Any
+from fairseq import checkpoint_utils, tasks
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.fairseq_encoder import FairseqEncoder
+from fairseq.tasks import FairseqTask
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.data.data_utils import lengths_to_padding_mask
+
+from fairseq.models.hubert import HubertAsrConfig
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class SpeechUTS2TConfig(HubertAsrConfig):
+    ### the following config is only for compatibility with the fairseq speech_to_text task
+ input_feat_per_channel: Any = None
+ input_channels: Any = None
+ speaker_to_id: Any = None
+
+@register_model("speechut_st_legacy", dataclass=SpeechUTS2TConfig)
+class SpeechUTS2T(BaseFairseqModel):
+ """An encoder-decoder model."""
+ def __init__(self, cfg: SpeechUTS2TConfig, encoder: FairseqEncoder):
+ super().__init__()
+ self.cfg = cfg
+ self.encoder = encoder
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ @classmethod
+ def build_model(cls, cfg: SpeechUTS2TConfig, task: FairseqTask):
+ """Build a new model instance."""
+ encoder = SpeechUTEncoder(cfg, task)
+ return cls(cfg, encoder)
+
+ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
+ encoder_out = self.encoder(src_tokens, src_lengths, **kwargs)
+ decoder_out = self.encoder.w2v_model.decoder(
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
+ )
+ return decoder_out
+
+ def forward_decoder(self, prev_output_tokens, **kwargs):
+ return self.encoder.w2v_model.decoder(prev_output_tokens, **kwargs)
+
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
+ """For decoder decoding."""
+ return self.encoder.w2v_model.decoder.get_normalized_probs(net_output, log_probs, sample)
+
+ @property
+ def decoder(self):
+ return self.encoder.w2v_model.decoder
+
+
+class SpeechUTEncoder(FairseqEncoder):
+ """
+ Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
+    1. Make it compatible with the fairseq speech_to_text task.
+    2. Make it compatible with an encoder-decoder model.
+ """
+ def __init__(self, cfg: SpeechUTS2TConfig, task):
+ self.apply_mask = cfg.apply_mask
+
+ arg_overrides = {
+ "dropout": cfg.dropout,
+ "activation_dropout": cfg.activation_dropout,
+ "dropout_input": cfg.dropout_input,
+ "attention_dropout": cfg.attention_dropout,
+ "mask_length": cfg.mask_length,
+ "mask_prob": cfg.mask_prob,
+ "mask_selection": cfg.mask_selection,
+ "mask_other": cfg.mask_other,
+ "no_mask_overlap": cfg.no_mask_overlap,
+ "mask_channel_length": cfg.mask_channel_length,
+ "mask_channel_prob": cfg.mask_channel_prob,
+ "mask_channel_selection": cfg.mask_channel_selection,
+ "mask_channel_other": cfg.mask_channel_other,
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+ "encoder_layerdrop": cfg.layerdrop,
+ "feature_grad_mult": cfg.feature_grad_mult,
+ }
+
+ if cfg.w2v_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
+ w2v_args = state.get("cfg", None)
+ if w2v_args is None:
+ w2v_args = convert_namespace_to_omegaconf(state["args"])
+ cfg.w2v_args = w2v_args
+ else:
+ state = None
+ w2v_args = cfg.w2v_args
+ if isinstance(w2v_args, Namespace):
+ cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
+
+ assert task.data_cfg.standardize_audio() == w2v_args.task.normalize, (
+ "Fine-tuning works best when data normalization is the same. "
+ "Please check that --normalize is set or unset for "
+ "both pre-training and here"
+ )
+
+ pretrain_task = tasks.setup_task(w2v_args.task, load_local_states=False)
+        assert state is not None and "task_state" in state, "the stored dictionaries were not found in the checkpoint!"
+ # This will load the stored "dictionaries" object
+ pretrain_task.load_state_dict(state["task_state"])
+
+ model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
+ if state is not None and not cfg.no_pretrained_weights:
+ try:
+ model.load_state_dict(state["model"], strict=True)
+ except Exception as e:
+                logger.warning(e)
+ model.load_state_dict(state["model"], strict=False)
+
+ model.remove_pretraining_modules()
+
+ super().__init__(pretrain_task.source_dictionary)
+
+ d = w2v_args.model.encoder_embed_dim
+
+ self.w2v_model = model
+
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
+ self.num_updates = 0
+
+ def set_num_updates(self, num_updates):
+ """Set the number of parameters updates."""
+ super().set_num_updates(num_updates)
+ self.num_updates = num_updates
+
+ def forward(self, src_tokens=None, src_lengths=None, **kwargs):
+
+ w2v_args = {
+ "source": src_tokens,
+ "padding_mask": lengths_to_padding_mask(src_lengths),
+ "mask": self.apply_mask and self.training,
+ }
+
+ ft = self.freeze_finetune_updates <= self.num_updates
+
+ with torch.no_grad() if not ft else contextlib.ExitStack():
+ x, padding_mask = self.w2v_model.extract_features(**w2v_args)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ return {
+ "encoder_out": [x], # T x B x C
+ "encoder_padding_mask": [padding_mask], # B x T
+ "padding_mask": [padding_mask],
+ }
+
+ def forward_torchscript(self, net_input):
+ """A TorchScript-compatible version of forward.
+
+ Forward the encoder out.
+ """
+ _net_input = {
+ "source": net_input["src_tokens"],
+ "padding_mask": lengths_to_padding_mask(net_input["src_lengths"]),
+ "mask": False,
+ }
+
+ x, padding_mask = self.w2v_model.extract_features(**_net_input)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ encoder_out = {
+ "encoder_out" : [x],
+ "encoder_padding_mask" : [padding_mask],
+ }
+ return encoder_out
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ if encoder_out["encoder_out"] is not None:
+ encoder_out["encoder_out"] = [
+ x.index_select(1, new_order) for x in encoder_out["encoder_out"]
+ ]
+ if encoder_out["encoder_padding_mask"] is not None:
+ encoder_out["encoder_padding_mask"] = [
+ x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
+ ]
+ return encoder_out
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return None
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ return state_dict
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+def Linear(in_features, out_features, bias=True):
+ m = nn.Linear(in_features, out_features, bias)
+ nn.init.xavier_uniform_(m.weight)
+ if bias:
+ nn.init.constant_(m.bias, 0.0)
+ return m
diff --git a/SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py b/SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d16a2df00b692114f8d84d254cf486d09e1137b
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py
@@ -0,0 +1,25 @@
+# --------------------------------------------------------
+# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+from fairseq.models import (
+ register_model_architecture,
+)
+from fairseq.models.transformer_lm import base_lm_architecture
+
+
+@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
+def transformer_lm_t5(args):
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
+ args.decoder_layers = getattr(args, "decoder_layers", 20)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ base_lm_architecture(args)
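+
+# Usage sketch (assuming a standard fairseq language-modeling setup; the data
+# path and user-dir below are placeholders):
+#
+#   fairseq-train data-bin/lm_corpus \
+#       --task language_modeling \
+#       --arch transformer_lm_t5 \
+#       --user-dir /path/to/speech2s \
+#       --optimizer adam --lr 5e-4 --max-tokens 4096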
diff --git a/SpeechT5/Speech2S/speech2s/modules/__init__.py b/SpeechT5/Speech2S/speech2s/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad97814e515d8e68d68e4e031d4f9c9055f3864
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/__init__.py
@@ -0,0 +1,27 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+from .learned_positional_embedding import LearnedPositionalEmbedding
+from .multihead_attention import MultiheadAttention
+from .relative_pos_enc import RelativePositionalEncoding
+from .transformer_layer import TransformerEncoderLayerBase, TransformerDecoderLayerBase
+from .w2v_encoder import TransformerEncoder, TransformerSentenceEncoderLayer
+from .transformer_encoder import TransformerEncoderBase
+from .transformer_decoder import TransformerDecoderScriptable, TransformerDecoderBaseScriptable
+
+__all__ = [
+ "MultiheadAttention",
+ "RelativePositionalEncoding",
+ "LearnedPositionalEmbedding",
+ "TransformerEncoderLayerBase",
+ "TransformerDecoderLayerBase",
+ "TransformerEncoder",
+ "TransformerSentenceEncoderLayer",
+ "TransformerEncoderBase",
+ "TransformerDecoderScriptable",
+ "TransformerDecoderBaseScriptable",
+]
diff --git a/SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py b/SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..b42cbd819abf7bdd718bef3db3f553c8360ac384
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+
+# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import numpy as np
+import six
+
+
+class CTCPrefixScore(object):
+ """Compute CTC label sequence scores
+ which is based on Algorithm 2 in WATANABE et al.
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+    but extended to efficiently compute the probabilities of multiple labels
+    simultaneously.
+    """
+
+ def __init__(self, x, blank, eos, xp):
+ self.xp = xp
+ self.logzero = -10000000000.0
+ self.blank = blank
+ self.eos = eos
+ self.input_length = len(x)
+ self.x = x
+
+ def initial_state(self):
+ """Obtain an initial CTC state
+ :return: CTC state
+ """
+ # initial CTC state is made of a frame x 2 tensor that corresponds to
+ # r_t^n() and r_t^b(), where 0 and 1 of axis=1 represent
+ # superscripts n and b (non-blank and blank), respectively.
+ r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
+ r[0, 1] = self.x[0, self.blank]
+ for i in six.moves.range(1, self.input_length):
+ r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
+ return r
+
+ def __call__(self, y, cs, r_prev):
+ """Compute CTC prefix scores for next labels
+ :param y : prefix label sequence
+ :param cs : array of next labels
+ :param r_prev: previous CTC state
+ :return ctc_scores, ctc_states
+ """
+ # initialize CTC states
+ output_length = len(y) - 1 # ignore sos
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
+ # that corresponds to r_t^n(h) and r_t^b(h).
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
+ xs = self.x[:, cs]
+ if output_length == 0:
+ r[0, 0] = xs[0]
+ r[0, 1] = self.logzero
+ else:
+ r[output_length - 1] = self.logzero
+
+ # prepare forward probabilities for the last label
+ r_sum = self.xp.logaddexp(
+ r_prev[:, 0], r_prev[:, 1]
+ ) # log(r_t^n(g) + r_t^b(g))
+ last = y[-1]
+ if output_length > 0 and last in cs:
+ log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
+ for i in six.moves.range(len(cs)):
+ log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
+ else:
+ log_phi = r_sum
+
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
+ # and log prefix probabilities log(psi)
+ start = max(output_length, 1)
+ log_psi = r[start - 1, 0]
+ for t in six.moves.range(start, self.input_length):
+ r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
+ r[t, 1] = (
+ self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
+ )
+ log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
+
+ # get P(...eos|X) that ends with the prefix itself
+ eos_pos = self.xp.where(cs == self.eos)[0]
+ if len(eos_pos) > 0:
+ log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
+
+ # exclude blank probs
+ blank_pos = self.xp.where(cs == self.blank)[0]
+ if len(blank_pos) > 0:
+ log_psi[blank_pos] = self.logzero
+
+ # return the log prefix probability and CTC states, where the label axis
+ # of the CTC states is moved to the first axis to slice it easily
+ return log_psi, self.xp.rollaxis(r, 2)
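+
+# Minimal usage sketch (illustrative; the inputs are made up):
+#
+#   import numpy as np
+#   T, V = 20, 6                                             # frames x vocabulary
+#   x = np.log(np.full((T, V), 1.0 / V, dtype=np.float32))   # per-frame log-probs
+#   scorer = CTCPrefixScore(x, blank=0, eos=1, xp=np)
+#   r_prev = scorer.initial_state()                          # (T, 2) forward variables
+#   log_psi, states = scorer([1], np.array([2, 3, 1]), r_prev)
+#   # log_psi[i] is the CTC prefix score of extending the prefix with cs[i]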
diff --git a/SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py b/SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c8558e20b2172a8c607e2f5c32aa146ff2b9cf
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py
@@ -0,0 +1,69 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+"""
+ Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
+ 1. Add clamping if the input length exceeds the max-source-tokens
+"""
+
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq import utils
+from torch import Tensor
+
+
+class LearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+    Padding ids are ignored either by offsetting based on padding_idx
+    or by setting padding_idx to None and ensuring that the appropriate
+ position ids are passed to the forward function.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.onnx_trace = False
+ if self.padding_idx is not None:
+ self.max_positions = self.num_embeddings - self.padding_idx - 1
+ else:
+ self.max_positions = self.num_embeddings
+
+ def forward(
+ self,
+ input: Tensor,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ positions: Optional[Tensor] = None,
+ ):
+ """Input is expected to be of size [bsz x seqlen]."""
+ assert (positions is None) or (
+ self.padding_idx is None
+ ), "If positions is pre-computed then padding_idx should not be set."
+
+ if positions is None:
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ # Without the int() cast, it doesn't work in some cases when exporting to ONNX
+ positions = torch.zeros(
+ (1, 1), device=input.device, dtype=input.dtype
+ ).fill_(int(self.padding_idx + input.size(1)))
+ else:
+ positions = utils.make_positions(
+ input, self.padding_idx, onnx_trace=self.onnx_trace
+ )
+ positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
+ return F.embedding(
+ positions,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
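+
+# Example (illustrative): with num_embeddings = 514 and padding_idx = 1 the module
+# supports max_positions = 512; longer inputs are clamped to the last learned
+# position instead of indexing out of range:
+#
+#   emb = LearnedPositionalEmbedding(514, 768, padding_idx=1)
+#   tokens = torch.full((2, 600), 5, dtype=torch.long)  # sequence length 600 > 512
+#   pos = emb(tokens)                                   # (2, 600, 768), no error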
diff --git a/SpeechT5/Speech2S/speech2s/modules/multihead_attention.py b/SpeechT5/Speech2S/speech2s/modules/multihead_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..89f46ab628ebe7faa1a3db2fd4f31a7269bb006a
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/multihead_attention.py
@@ -0,0 +1,346 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from torch import Tensor
+
+from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
+
+
+class MultiheadAttention(FairseqMultiheadAttention):
+ """Multi-headed attention.
+
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ q_noise=0.0,
+ qn_block_size=8,
+ scaling_for_att=1.0
+ ):
+ super().__init__(
+ embed_dim,
+ num_heads,
+ kdim,
+ vdim,
+ dropout,
+ bias,
+ add_bias_kv,
+ add_zero_attn,
+ self_attention,
+ encoder_decoder_attention,
+ q_noise,
+ qn_block_size,
+ )
+ self.scaling_for_att = scaling_for_att
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ position_bias: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: True).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ is_tpu = query.device.type == "xla"
+
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert key_bsz == bsz
+ assert value is not None
+            assert (src_len, bsz) == value.shape[:2]
+
+ if (
+ not self.onnx_trace
+ and not is_tpu # don't use PyTorch version on TPUs
+ and incremental_state is None
+ and not static_kv
+ # A workaround for quantization to work. Otherwise JIT compilation
+ # treats bias in linear module as method.
+ and not torch.jit.is_scripting()
+ and position_bias is None
+ ):
+ assert key is not None and value is not None
+ return F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training or self.dropout_module.apply_during_inference,
+ key_padding_mask,
+ need_weights,
+ attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+ q *= (1 / self.scaling_for_att)
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+ ],
+ dim=1,
+ )
+
+ q = (
+ q.contiguous()
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if k is not None:
+ k = (
+ k.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if v is not None:
+ v = (
+ v.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.size(1) == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
+ key_padding_mask
+ ),
+ ],
+ dim=1,
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        if position_bias is not None:
+            # add the first-order relative position bias to the attention logits
+            # position_bias: (tgt_len, src_len, head_dim)
+            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)  # (tgt_len, B*H, head_dim)
+            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))  # (tgt_len, B*H, src_len)
+            B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1))
+            attn_weights += B
+
+ attn_weights *= self.scaling_for_att
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ if self.onnx_trace:
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ if not is_tpu:
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+ else:
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if self.scaling_for_att > 1.0:
+ attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = utils.softmax(
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
+ )
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ if self.onnx_trace and attn.size(1) == 1:
+ # when ONNX tracing a single decoder step (sequence length == 1)
+ # the transpose is a no-op copy before view, thus unnecessary
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
+ else:
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ attn_weights = attn_weights_float.view(
+ bsz, self.num_heads, tgt_len, src_len
+ ).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights
diff --git a/SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py b/SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7021fc0941fef310ca5571c101b8a8e18ffc1db6
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py
@@ -0,0 +1,33 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+import torch
+
+class RelativePositionalEncoding(torch.nn.Module):
+ def __init__(self, d_model, maxlen=1000, embed_v=False):
+ super(RelativePositionalEncoding, self).__init__()
+
+ self.d_model = d_model
+ self.maxlen = maxlen
+ self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
+ if embed_v:
+ self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
+ self.embed_v = embed_v
+
+
+ def forward(self, pos_seq, incremental_state=None):
+ pos_seq[pos_seq < -self.maxlen] = -self.maxlen
+ pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
+ pos_seq = pos_seq + self.maxlen
+
+ if incremental_state is not None:
+ pos_seq = pos_seq[-1:]
+
+ if self.embed_v:
+ return self.pe_k(pos_seq), self.pe_v(pos_seq)
+ else:
+ return self.pe_k(pos_seq), None
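A quick usage sketch for the module above: the caller builds the signed offset matrix i - j, and the module clamps it to [-maxlen, maxlen - 1] and shifts it into the embedding index range before the lookup. This assumes the RelativePositionalEncoding class defined above is in scope; the values are illustrative.

```python
import torch

pe = RelativePositionalEncoding(d_model=64, maxlen=1000, embed_v=False)

seq_len = 8
pos = torch.arange(seq_len)
pos_seq = pos[:, None] - pos[None, :]   # (T, T) signed offsets in [-(T-1), T-1]

pos_k, pos_v = pe(pos_seq)              # pos_k: (T, T, 64); pos_v is None since embed_v=False
print(pos_k.shape, pos_v is None)       # torch.Size([8, 8, 64]) True
```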
diff --git a/SpeechT5/Speech2S/speech2s/modules/transformer_decoder.py b/SpeechT5/Speech2S/speech2s/modules/transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..84417b44b2672e49cf92bad8355d2dae48661b55
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/transformer_decoder.py
@@ -0,0 +1,543 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+"""
+ Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/transformer/transformer_decoder.py
+"""
+
+import math
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from fairseq import utils
+from fairseq.distributed import fsdp_wrap
+from fairseq.models import FairseqIncrementalDecoder
+from fairseq.models.transformer import TransformerConfig
+from fairseq.modules import (
+ AdaptiveSoftmax,
+ BaseLayer,
+ FairseqDropout,
+ LayerDropModuleList,
+ LayerNorm,
+ PositionalEmbedding,
+ SinusoidalPositionalEmbedding,
+)
+from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+from torch import Tensor
+
+from speechut.modules import transformer_layer
+from speechut.modules import RelativePositionalEncoding
+
+# rewrite name for backward compatibility in `make_generation_fast_`
+def module_name_fordropout(module_name: str) -> str:
+ if module_name == "TransformerDecoderBase":
+ return "TransformerDecoder"
+ else:
+ return module_name
+
+
+class TransformerDecoderBase(FairseqIncrementalDecoder):
+ """
+ Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
+ is a :class:`TransformerDecoderLayer`.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
+ embed_tokens (torch.nn.Embedding): output embedding
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
+ (default: False).
+ """
+
+ def __init__(
+ self,
+ cfg,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ output_projection=None,
+ use_rel_pos_enc=False,
+ ):
+ self.cfg = cfg
+ super().__init__(dictionary)
+ self.register_buffer("version", torch.Tensor([3]))
+ self._future_mask = torch.empty(0)
+
+ self.dropout_module = FairseqDropout(
+ cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
+ )
+ self.decoder_layerdrop = cfg.decoder.layerdrop
+ self.share_input_output_embed = cfg.share_decoder_input_output_embed
+
+ input_embed_dim = embed_tokens.embedding_dim
+ embed_dim = cfg.decoder.embed_dim
+ self.embed_dim = embed_dim
+ self.output_embed_dim = cfg.decoder.output_dim
+
+ self.padding_idx = embed_tokens.padding_idx
+ self.max_target_positions = cfg.max_target_positions
+
+ self.embed_tokens = embed_tokens
+
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
+
+ if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
+ self.quant_noise = apply_quant_noise_(
+ nn.Linear(embed_dim, embed_dim, bias=False),
+ cfg.quant_noise.pq,
+ cfg.quant_noise.pq_block_size,
+ )
+ else:
+ self.quant_noise = None
+
+ self.project_in_dim = (
+ Linear(input_embed_dim, embed_dim, bias=False)
+ if embed_dim != input_embed_dim
+ else None
+ )
+ self.embed_positions = (
+ PositionalEmbedding(
+ self.max_target_positions,
+ embed_dim,
+ self.padding_idx,
+ learned=cfg.decoder.learned_pos,
+ )
+ if not cfg.no_token_positional_embeddings
+ else None
+ )
+ if cfg.layernorm_embedding:
+ self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
+ else:
+ self.layernorm_embedding = None
+
+ self.cross_self_attention = cfg.cross_self_attention
+
+ if self.decoder_layerdrop > 0.0:
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
+ else:
+ self.layers = nn.ModuleList([])
+ self.use_rel_pos_enc = use_rel_pos_enc
+ self.layers.extend(
+ [
+ self.build_decoder_layer(cfg, no_encoder_attn)
+ for _ in range(cfg.decoder.layers)
+ ]
+ )
+ self.num_layers = len(self.layers)
+
+ if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
+ self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
+ else:
+ self.layer_norm = None
+
+ self.project_out_dim = (
+ Linear(embed_dim, self.output_embed_dim, bias=False)
+ if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
+ else None
+ )
+
+ self.adaptive_softmax = None
+ self.output_projection = output_projection
+ if self.output_projection is None:
+ self.build_output_projection(cfg, dictionary, embed_tokens)
+ if self.use_rel_pos_enc:
+ self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.decoder.attention_heads, 24)
+
+ def build_output_projection(self, cfg, dictionary, embed_tokens):
+ if cfg.adaptive_softmax_cutoff is not None:
+ self.adaptive_softmax = AdaptiveSoftmax(
+ len(dictionary),
+ self.output_embed_dim,
+ utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
+ dropout=cfg.adaptive_softmax_dropout,
+ adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
+ factor=cfg.adaptive_softmax_factor,
+ tie_proj=cfg.tie_adaptive_proj,
+ )
+ elif self.share_input_output_embed:
+ self.output_projection = nn.Linear(
+ self.embed_tokens.weight.shape[1],
+ self.embed_tokens.weight.shape[0],
+ bias=False,
+ )
+ self.output_projection.weight = self.embed_tokens.weight
+ else:
+ self.output_projection = nn.Linear(
+ self.output_embed_dim, len(dictionary), bias=False
+ )
+ nn.init.normal_(
+ self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
+ )
+ num_base_layers = cfg.base_layers
+ for i in range(num_base_layers):
+ self.layers.insert(
+ ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
+ BaseLayer(cfg),
+ )
+
+ def build_decoder_layer(self, cfg, no_encoder_attn=False):
+ layer = transformer_layer.TransformerDecoderLayerBase(cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc)
+ checkpoint = cfg.checkpoint_activations
+ if checkpoint:
+ offload_to_cpu = cfg.offload_activations
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+ # if we are checkpointing, enforce that FSDP always wraps the
+ # checkpointed layer, regardless of layer size
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
+ return layer
+
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ features_only: bool = False,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ src_lengths: Optional[Any] = None,
+ return_all_hiddens: bool = False,
+ ):
+ """
+ Args:
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
+ `(batch, tgt_len)`, for teacher forcing
+ encoder_out (optional): output from the encoder, used for
+ encoder-side attention, should be of size T x B x C
+ incremental_state (dict): dictionary used for storing state during
+ :ref:`Incremental decoding`
+ features_only (bool, optional): only return features without
+ applying output layer (default: False).
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+
+ Returns:
+ tuple:
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
+ - a dictionary with any model-specific outputs
+ """
+
+ x, extra = self.extract_features(
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ incremental_state=incremental_state,
+ full_context_alignment=full_context_alignment,
+ alignment_layer=alignment_layer,
+ alignment_heads=alignment_heads,
+ )
+
+ if not features_only:
+ x = self.output_layer(x)
+ return x, extra
+
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]],
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ return self.extract_features_scriptable(
+ prev_output_tokens,
+ encoder_out,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ )
+
+ """
+ A scriptable subclass of this class has an extract_features method and calls
+ super().extract_features, but super() is not supported in torchscript. A copy of
+ this function is made to be used in the subclass instead.
+ """
+
+ def extract_features_scriptable(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]],
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ """
+ Similar to *forward* but only return features.
+
+ Includes several features from "Jointly Learning to Align and
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
+
+ Args:
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+ alignment_layer (int, optional): return mean alignment over
+ heads at this layer (default: last layer).
+ alignment_heads (int, optional): only average alignment over
+ this many heads (default: all heads).
+
+ Returns:
+ tuple:
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+ - a dictionary with any model-specific outputs
+ """
+ bs, slen = prev_output_tokens.size()
+ if alignment_layer is None:
+ alignment_layer = self.num_layers - 1
+
+ enc: Optional[Tensor] = None
+ padding_mask: Optional[Tensor] = None
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
+ enc = encoder_out["encoder_out"][0]
+ assert (
+ enc.size()[1] == bs
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
+ padding_mask = encoder_out["encoder_padding_mask"][0]
+
+ # embed positions
+ positions = None
+ if self.embed_positions is not None:
+ positions = self.embed_positions(
+ prev_output_tokens, incremental_state=incremental_state
+ )
+
+ if incremental_state is not None:
+ prev_output_tokens = prev_output_tokens[:, -1:]
+ if positions is not None:
+ positions = positions[:, -1:]
+
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+ if self.quant_noise is not None:
+ x = self.quant_noise(x)
+
+ if self.project_in_dim is not None:
+ x = self.project_in_dim(x)
+
+ if positions is not None:
+ x += positions
+
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+
+ x = self.dropout_module(x)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ if self.use_rel_pos_enc:
+ pos_seq = torch.arange(0, slen).long().to(x.device)
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
+ pos_k, _ = self.pos_emb(pos_seq, incremental_state)
+ else:
+ pos_k = None
+
+ self_attn_padding_mask: Optional[Tensor] = None
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+
+ # decoder layers
+ attn: Optional[Tensor] = None
+ inner_states: List[Optional[Tensor]] = [x]
+ for idx, layer in enumerate(self.layers):
+ if incremental_state is None and not full_context_alignment:
+ self_attn_mask = self.buffered_future_mask(x)
+ else:
+ self_attn_mask = None
+
+ x, layer_attn, _ = layer(
+ x,
+ enc,
+ padding_mask,
+ incremental_state,
+ self_attn_mask=self_attn_mask,
+ self_attn_padding_mask=self_attn_padding_mask,
+ need_attn=bool((idx == alignment_layer)),
+ need_head_weights=bool((idx == alignment_layer)),
+ pos_bias=pos_k,
+ )
+ inner_states.append(x)
+ if layer_attn is not None and idx == alignment_layer:
+ attn = layer_attn.float().to(x)
+
+ if attn is not None:
+ if alignment_heads is not None:
+ attn = attn[:alignment_heads]
+
+ # average probabilities over heads
+ attn = attn.mean(dim=0)
+
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if self.project_out_dim is not None:
+ x = self.project_out_dim(x)
+
+ return x, {"attn": [attn], "inner_states": inner_states}
+
+ def output_layer(self, features):
+ """Project features to the vocabulary size."""
+ if self.adaptive_softmax is None:
+ # project back to size of vocabulary
+ return self.output_projection(features)
+ else:
+ return features
+
+ def max_positions(self):
+ """Maximum output length supported by the decoder."""
+ if self.embed_positions is None:
+ return self.max_target_positions
+ return min(self.max_target_positions, self.embed_positions.max_positions)
+
+ def buffered_future_mask(self, tensor):
+ dim = tensor.size(0)
+        # Comparing self._future_mask.device with tensor.device does not work in TorchScript; this is a workaround.
+ if (
+ self._future_mask.size(0) == 0
+ or (not self._future_mask.device == tensor.device)
+ or self._future_mask.size(0) < dim
+ ):
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
+ )
+ self._future_mask = self._future_mask.to(tensor)
+ return self._future_mask[:dim, :dim]
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+ weights_key = "{}.embed_positions.weights".format(name)
+ if weights_key in state_dict:
+ del state_dict[weights_key]
+ state_dict[
+ "{}.embed_positions._float_tensor".format(name)
+ ] = torch.FloatTensor(1)
+
+ if f"{name}.output_projection.weight" not in state_dict:
+ if self.share_input_output_embed:
+ embed_out_key = f"{name}.embed_tokens.weight"
+ else:
+ embed_out_key = f"{name}.embed_out"
+ if embed_out_key in state_dict:
+ state_dict[f"{name}.output_projection.weight"] = state_dict[
+ embed_out_key
+ ]
+ if not self.share_input_output_embed:
+ del state_dict[embed_out_key]
+
+ for i in range(self.num_layers):
+ # update layer norms
+ layer_norm_map = {
+ "0": "self_attn_layer_norm",
+ "1": "encoder_attn_layer_norm",
+ "2": "final_layer_norm",
+ }
+ for old, new in layer_norm_map.items():
+ for m in ("weight", "bias"):
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
+ if k in state_dict:
+ state_dict[
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
+ ] = state_dict[k]
+ del state_dict[k]
+
+ version_key = "{}.version".format(name)
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
+ # earlier checkpoints did not normalize after the stack of layers
+ self.layer_norm = None
+ self.normalize = False
+ state_dict[version_key] = torch.Tensor([1])
+
+ return state_dict
+
+
+def Linear(in_features, out_features, bias=True):
+ m = nn.Linear(in_features, out_features, bias)
+ nn.init.xavier_uniform_(m.weight)
+ if bias:
+ nn.init.constant_(m.bias, 0.0)
+ return m
+
+class TransformerDecoderBaseScriptable(TransformerDecoderBase):
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ # call scriptable method from parent class
+ x, _ = self.extract_features_scriptable(
+ prev_output_tokens,
+ encoder_out,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ )
+ return x, None
+
+
+class TransformerDecoder(TransformerDecoderBase):
+ def __init__(
+ self,
+ args,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ output_projection=None,
+ ):
+ self.args = args
+ super().__init__(
+ TransformerConfig.from_namespace(args),
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=no_encoder_attn,
+ output_projection=output_projection,
+ use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
+ )
+
+ def build_output_projection(self, args, dictionary, embed_tokens):
+ super().build_output_projection(
+ TransformerConfig.from_namespace(args), dictionary, embed_tokens
+ )
+
+ def build_decoder_layer(self, args, no_encoder_attn=False):
+ return super().build_decoder_layer(
+ TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
+ )
+
+class TransformerDecoderScriptable(TransformerDecoder):
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ # call scriptable method from parent class
+ x, _ = self.extract_features_scriptable(
+ prev_output_tokens,
+ encoder_out,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ )
+ return x, None
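buffered_future_mask above caches an upper-triangular -inf matrix so that, outside of incremental decoding, each target position can only attend to itself and earlier positions; the buffer is regrown and moved to the right device only when needed. A minimal sketch of the same masking idea, independent of the fairseq utilities:

```python
import torch

def causal_mask(dim: int) -> torch.Tensor:
    """0 on and below the diagonal, -inf strictly above it."""
    mask = torch.zeros(dim, dim)
    return mask.masked_fill(torch.ones(dim, dim, dtype=torch.bool).triu(1), float("-inf"))

scores = torch.randn(2, 4, 4)                    # (batch, tgt_len, src_len) attention logits
probs = (scores + causal_mask(4)).softmax(dim=-1)
print(probs[0])                                   # each row only puts mass on positions <= its own index
```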
diff --git a/SpeechT5/Speech2S/speech2s/modules/transformer_encoder.py b/SpeechT5/Speech2S/speech2s/modules/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f94e1fed8a005ec59d1e422157e08d88ff95bfda
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/transformer_encoder.py
@@ -0,0 +1,401 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+import math
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.distributed import fsdp_wrap
+from fairseq.models import FairseqEncoder
+from fairseq.modules import (
+ FairseqDropout,
+ LayerDropModuleList,
+ LayerNorm,
+ SinusoidalPositionalEmbedding,
+)
+from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+from torch import Tensor
+from fairseq.models.transformer import (
+ TransformerConfig,
+)
+
+
+from speechut.modules import transformer_layer, LearnedPositionalEmbedding
+from speechut.modules import RelativePositionalEncoding
+
+# rewrite name for backward compatibility in `make_generation_fast_`
+def module_name_fordropout(module_name: str) -> str:
+ if module_name == "TransformerEncoderBase":
+ return "TransformerEncoder"
+ else:
+ return module_name
+
+
+class TransformerEncoderBase(FairseqEncoder):
+ """
+ Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
+ is a :class:`TransformerEncoderLayer`.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ dictionary (~fairseq.data.Dictionary): encoding dictionary
+ embed_tokens (torch.nn.Embedding): input embedding
+ """
+
+ def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
+ self.cfg = cfg
+ super().__init__(dictionary)
+ self.register_buffer("version", torch.Tensor([3]))
+
+ self.dropout_module = FairseqDropout(
+ cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
+ )
+ self.encoder_layerdrop = cfg.encoder.layerdrop
+
+ embed_dim = embed_tokens.embedding_dim
+ self.padding_idx = embed_tokens.padding_idx
+ self.max_source_positions = cfg.max_source_positions
+
+ self.embed_tokens = embed_tokens
+
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
+
+ self.embed_positions = (
+ PositionalEmbedding(
+ cfg.max_source_positions,
+ embed_dim,
+ self.padding_idx,
+ learned=cfg.encoder.learned_pos,
+ )
+ if not cfg.no_token_positional_embeddings
+ else None
+ )
+ if cfg.layernorm_embedding:
+ self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
+ else:
+ self.layernorm_embedding = None
+
+ if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
+ self.quant_noise = apply_quant_noise_(
+ nn.Linear(embed_dim, embed_dim, bias=False),
+ cfg.quant_noise.pq,
+ cfg.quant_noise.pq_block_size,
+ )
+ else:
+ self.quant_noise = None
+
+ if self.encoder_layerdrop > 0.0:
+ self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
+ else:
+ self.layers = nn.ModuleList([])
+ self.use_rel_pos_enc = use_rel_pos_enc
+ self.scaling_for_att = scaling_for_att
+ self.layers.extend(
+ [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
+ )
+ self.num_layers = len(self.layers)
+
+ if cfg.encoder.normalize_before:
+ self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
+ else:
+ self.layer_norm = None
+ if self.use_rel_pos_enc:
+ self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)
+
+ def build_encoder_layer(self, cfg):
+ layer = transformer_layer.TransformerEncoderLayerBase(cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att)
+ checkpoint = cfg.checkpoint_activations
+ if checkpoint:
+ offload_to_cpu = cfg.offload_activations
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+ # if we are checkpointing, enforce that FSDP always wraps the
+ # checkpointed layer, regardless of layer size
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
+ return layer
+
+ def forward_embedding(
+ self, src_tokens, token_embedding: Optional[torch.Tensor] = None
+ ):
+ # embed tokens and positions
+ if token_embedding is None:
+ token_embedding = self.embed_tokens(src_tokens)
+ x = embed = self.embed_scale * token_embedding
+ if self.embed_positions is not None:
+ x = embed + self.embed_positions(src_tokens)
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+ x = self.dropout_module(x)
+ if self.quant_noise is not None:
+ x = self.quant_noise(x)
+ return x, embed
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths: Optional[torch.Tensor] = None,
+ return_all_hiddens: bool = False,
+ token_embeddings: Optional[torch.Tensor] = None,
+ uniformity_layers: Optional[List[int]] = None,
+ ):
+ """
+ Args:
+ src_tokens (LongTensor): tokens in the source language of shape
+ `(batch, src_len)`
+ src_lengths (torch.LongTensor): lengths of each source sentence of
+ shape `(batch)`
+ return_all_hiddens (bool, optional): also return all of the
+ intermediate hidden states (default: False).
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
+ default `None` will recompute embeddings
+
+ Returns:
+ dict:
+ - **encoder_out** (Tensor): the last encoder layer's output of
+ shape `(src_len, batch, embed_dim)`
+ - **encoder_padding_mask** (ByteTensor): the positions of
+ padding elements of shape `(batch, src_len)`
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
+ of shape `(batch, src_len, embed_dim)`
+ - **encoder_states** (List[Tensor]): all intermediate
+ hidden states of shape `(src_len, batch, embed_dim)`.
+ Only populated if *return_all_hiddens* is True.
+ """
+ return self.forward_scriptable(
+ src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
+ )
+
+    # TorchScript does not support calling super(), so a scriptable subclass
+    # cannot access the base class implementation directly.
+    # The current workaround is to add a helper function with a different name
+    # and call that helper from the scriptable subclass.
+ def forward_scriptable(
+ self,
+ src_tokens,
+ src_lengths: Optional[torch.Tensor] = None,
+ return_all_hiddens: bool = False,
+ token_embeddings: Optional[torch.Tensor] = None,
+ uniformity_layers: Optional[List[int]] = None,
+ ):
+ """
+ Args:
+ src_tokens (LongTensor): tokens in the source language of shape
+ `(batch, src_len)`
+ src_lengths (torch.LongTensor): lengths of each source sentence of
+ shape `(batch)`
+ return_all_hiddens (bool, optional): also return all of the
+ intermediate hidden states (default: False).
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
+ default `None` will recompute embeddings
+
+ Returns:
+ dict:
+ - **encoder_out** (Tensor): the last encoder layer's output of
+ shape `(src_len, batch, embed_dim)`
+ - **encoder_padding_mask** (ByteTensor): the positions of
+ padding elements of shape `(batch, src_len)`
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
+ of shape `(batch, src_len, embed_dim)`
+ - **encoder_states** (List[Tensor]): all intermediate
+ hidden states of shape `(src_len, batch, embed_dim)`.
+ Only populated if *return_all_hiddens* is True.
+ """
+ # compute padding mask
+ encoder_padding_mask = src_tokens.eq(self.padding_idx)
+ has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
+
+ x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+
+ # account for padding while computing the representation
+ if has_pads:
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ if self.use_rel_pos_enc:
+ x_len = x.shape[0]
+ pos_seq = torch.arange(0, x_len).long().to(x.device)
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
+ pos_k, pos_v = self.pos_emb(pos_seq)
+ else:
+ pos_k = None
+
+ encoder_states = []
+ uniformity_hiddens = []
+
+ if return_all_hiddens:
+ encoder_states.append(x)
+
+ if uniformity_layers is not None and 0 in uniformity_layers:
+ x = F.normalize(x.float(), dim=-1).type_as(x)
+ uniformity_hiddens.append(x)
+
+ # encoder layers
+ for i, layer in enumerate(self.layers):
+ x = layer(
+ x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
+ pos_bias=pos_k,
+ )
+ if uniformity_layers is not None and i+1 in uniformity_layers:
+ x = F.normalize(x.float(), dim=-1).type_as(x)
+ uniformity_hiddens.append(x)
+ if return_all_hiddens:
+ assert encoder_states is not None
+ encoder_states.append(x)
+
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+
+        # The PyTorch Mobile lite interpreter does not support returning a NamedTuple
+        # from `forward`, so we use a dictionary instead.
+        # TorchScript does not support mixed value types, so every value is a list;
+        # an empty list is equivalent to None.
+ src_lengths = (
+ src_tokens.ne(self.padding_idx)
+ .sum(dim=1, dtype=torch.int32)
+ .reshape(-1, 1)
+ .contiguous()
+ )
+ return {
+ "encoder_out": [x], # T x B x C
+ "encoder_padding_mask": [encoder_padding_mask], # B x T
+ "encoder_embedding": [encoder_embedding], # B x T x C
+ "encoder_states": encoder_states, # List[T x B x C]
+ "uniformity_hiddens": uniformity_hiddens, # List[T x B x C]
+ "src_tokens": [],
+ "src_lengths": [src_lengths],
+ }
+
+ @torch.jit.export
+ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
+ """
+ Reorder encoder output according to *new_order*.
+
+ Args:
+ encoder_out: output from the ``forward()`` method
+ new_order (LongTensor): desired order
+
+ Returns:
+ *encoder_out* rearranged according to *new_order*
+ """
+ if len(encoder_out["encoder_out"]) == 0:
+ new_encoder_out = []
+ else:
+ new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
+ if len(encoder_out["encoder_padding_mask"]) == 0:
+ new_encoder_padding_mask = []
+ else:
+ new_encoder_padding_mask = [
+ encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
+ ]
+ if len(encoder_out["encoder_embedding"]) == 0:
+ new_encoder_embedding = []
+ else:
+ new_encoder_embedding = [
+ encoder_out["encoder_embedding"][0].index_select(0, new_order)
+ ]
+
+ if len(encoder_out["src_tokens"]) == 0:
+ src_tokens = []
+ else:
+ src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
+
+ if len(encoder_out["src_lengths"]) == 0:
+ src_lengths = []
+ else:
+ src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
+
+ encoder_states = encoder_out["encoder_states"]
+ if len(encoder_states) > 0:
+ for idx, state in enumerate(encoder_states):
+ encoder_states[idx] = state.index_select(1, new_order)
+
+ return {
+ "encoder_out": new_encoder_out, # T x B x C
+ "encoder_padding_mask": new_encoder_padding_mask, # B x T
+ "encoder_embedding": new_encoder_embedding, # B x T x C
+ "encoder_states": encoder_states, # List[T x B x C]
+ "src_tokens": src_tokens, # B x T
+ "src_lengths": src_lengths, # B x 1
+ }
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ if self.embed_positions is None:
+ return self.max_source_positions
+ return min(self.max_source_positions, self.embed_positions.max_positions)
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+ weights_key = "{}.embed_positions.weights".format(name)
+ if weights_key in state_dict:
+ print("deleting {0}".format(weights_key))
+ del state_dict[weights_key]
+ state_dict[
+ "{}.embed_positions._float_tensor".format(name)
+ ] = torch.FloatTensor(1)
+ for i in range(self.num_layers):
+ # update layer norms
+ self.layers[i].upgrade_state_dict_named(
+ state_dict, "{}.layers.{}".format(name, i)
+ )
+
+ version_key = "{}.version".format(name)
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
+ # earlier checkpoints did not normalize after the stack of layers
+ self.layer_norm = None
+ self.normalize = False
+ state_dict[version_key] = torch.Tensor([1])
+ return state_dict
+
+
+class TransformerEncoder(TransformerEncoderBase):
+ def __init__(self, args, dictionary, embed_tokens):
+ self.args = args
+ super().__init__(
+ TransformerConfig.from_namespace(args),
+ dictionary,
+ embed_tokens,
+ use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
+ scaling_for_att=getattr(args, "scaling_for_att", 1.0),
+ )
+
+ def build_encoder_layer(self, args):
+ return super().build_encoder_layer(
+ TransformerConfig.from_namespace(args),
+ )
+
+
+def PositionalEmbedding(
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: int,
+ learned: bool = False,
+):
+ if learned:
+ # if padding_idx is specified then offset the embedding ids by
+ # this index and adjust num_embeddings appropriately
+ # TODO: The right place for this offset would be inside
+ # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
+ if padding_idx is not None:
+ num_embeddings = num_embeddings + padding_idx + 1
+ m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ else:
+ m = SinusoidalPositionalEmbedding(
+ embedding_dim,
+ padding_idx,
+ init_size=num_embeddings + padding_idx + 1,
+ )
+ return m
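The PositionalEmbedding helper above enlarges a learned table by padding_idx + 1 rows because, in fairseq's convention, real positions are numbered starting at padding_idx + 1 while padding tokens keep index padding_idx (a zeroed row). A small sketch of that numbering, assuming the same convention; the names and sizes are illustrative:

```python
import torch
import torch.nn as nn

max_positions, embed_dim, padding_idx = 10, 8, 1

# Positions for real tokens start at padding_idx + 1, so the table needs
# max_positions + padding_idx + 1 rows to cover the largest position id.
table = nn.Embedding(max_positions + padding_idx + 1, embed_dim, padding_idx=padding_idx)

tokens = torch.tensor([[5, 6, 7, padding_idx]])              # last token is padding
mask = tokens.ne(padding_idx).long()
positions = torch.cumsum(mask, dim=1) * mask + padding_idx    # pads stay at padding_idx
print(positions)                                              # tensor([[2, 3, 4, 1]])
print(table(positions).shape)                                 # torch.Size([1, 4, 8])
```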
diff --git a/SpeechT5/Speech2S/speech2s/modules/transformer_layer.py b/SpeechT5/Speech2S/speech2s/modules/transformer_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a71a848f1a5436756168aafd12d71637520b6b67
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/transformer_layer.py
@@ -0,0 +1,330 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+"""
+ Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_layer.py
+ https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
+"""
+
+from typing import Dict, List, Optional
+
+import torch
+from torch import Tensor
+from fairseq.modules import LayerNorm
+from fairseq.modules.transformer_layer import TransformerEncoderLayerBase as FairseqTransformerEncoderLayerBase
+from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
+
+from speechut.modules import MultiheadAttention
+
+class TransformerEncoderLayerBase(FairseqTransformerEncoderLayerBase):
+ """Encoder layer block.
+
+ In the original paper each operation (multi-head attention or FFN) is
+ postprocessed with: `dropout -> add residual -> layernorm`. In the
+ tensor2tensor code they suggest that learning is more robust when
+ preprocessing each layer with layernorm and postprocessing with:
+ `dropout -> add residual`. We default to the approach in the paper, but the
+ tensor2tensor approach can be enabled by setting
+ *cfg.encoder.normalize_before* to ``True``.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ """
+
+ def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0):
+ self.scaling_for_att = scaling_for_att
+ super().__init__(cfg)
+ if has_relative_attention_bias:
+ self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads)
+
+ def build_self_attention(self, embed_dim, cfg, scaling_for_att=1.0):
+ return MultiheadAttention(
+ embed_dim,
+ cfg.encoder.attention_heads,
+ dropout=cfg.attention_dropout,
+ self_attention=True,
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ scaling_for_att=self.scaling_for_att,
+ )
+
+ def forward(
+ self,
+ x,
+ encoder_padding_mask: Optional[Tensor],
+ attn_mask: Optional[Tensor] = None,
+ pos_bias=None,
+ ):
+ """
+ Args:
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
+ encoder_padding_mask (ByteTensor): binary ByteTensor of shape
+ `(batch, seq_len)` where padding elements are indicated by ``1``.
+ attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
+ where `tgt_len` is the length of output and `src_len` is the
+ length of input, though here both are equal to `seq_len`.
+ `attn_mask[tgt_i, src_j] = 1` means that when calculating the
+ embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
+ useful for strided self-attention.
+
+ Returns:
+ encoded output of shape `(seq_len, batch, embed_dim)`
+ """
+ # anything in original attn_mask = 1, becomes -1e8
+ # anything in original attn_mask = 0, becomes 0
+ # Note that we cannot use -inf here, because at some edge cases,
+ # the attention weight (before softmax) for some padded element in query
+ # will become -inf, which results in NaN in model parameters
+ if attn_mask is not None:
+ attn_mask = attn_mask.masked_fill(
+ attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
+ )
+
+ residual = x
+ if self.normalize_before:
+ x = self.self_attn_layer_norm(x)
+ if pos_bias is not None:
+ pos_bias = self.norm_k(pos_bias)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask,
+ need_weights=False,
+ attn_mask=attn_mask,
+ position_bias=pos_bias,
+ )
+ x = self.dropout_module(x)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.final_layer_norm(x)
+ x = self.activation_fn(self.fc1(x))
+ x = self.activation_dropout_module(x)
+ x = self.fc2(x)
+ x = self.dropout_module(x)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.final_layer_norm(x)
+ return x
+
+
+
+class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
+ """Decoder layer block.
+
+ In the original paper each operation (multi-head attention, encoder
+ attention or FFN) is postprocessed with: `dropout -> add residual ->
+ layernorm`. In the tensor2tensor code they suggest that learning is more
+ robust when preprocessing each layer with layernorm and postprocessing with:
+ `dropout -> add residual`. We default to the approach in the paper, but the
+ tensor2tensor approach can be enabled by setting
+ *cfg.decoder.normalize_before* to ``True``.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
+ (default: False).
+ """
+
+ def __init__(
+ self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False, scaling_for_att=1.0,
+ ):
+ self.scaling_for_att = scaling_for_att
+ super().__init__(cfg,
+ no_encoder_attn,
+ add_bias_kv,
+ add_zero_attn,
+ )
+
+ if has_relative_attention_bias:
+ self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
+
+ def build_self_attention(
+ self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
+ ):
+ return MultiheadAttention(
+ embed_dim,
+ cfg.decoder.attention_heads,
+ dropout=cfg.attention_dropout,
+ add_bias_kv=add_bias_kv,
+ add_zero_attn=add_zero_attn,
+ self_attention=not cfg.cross_self_attention,
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ scaling_for_att=self.scaling_for_att,
+ )
+
+ def build_encoder_attention(self, embed_dim, cfg):
+ return MultiheadAttention(
+ embed_dim,
+ cfg.decoder.attention_heads,
+ kdim=cfg.encoder.embed_dim,
+ vdim=cfg.encoder.embed_dim,
+ dropout=cfg.attention_dropout,
+ encoder_decoder_attention=True,
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ scaling_for_att=self.scaling_for_att,
+ )
+
+ def forward(
+ self,
+ x,
+ encoder_out: Optional[torch.Tensor] = None,
+ encoder_padding_mask: Optional[torch.Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ prev_self_attn_state: Optional[List[torch.Tensor]] = None,
+ prev_attn_state: Optional[List[torch.Tensor]] = None,
+ self_attn_mask: Optional[torch.Tensor] = None,
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
+ need_attn: bool = False,
+ need_head_weights: bool = False,
+ pos_bias=None,
+ ):
+ """
+ Args:
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
+ encoder_padding_mask (ByteTensor, optional): binary
+ ByteTensor of shape `(batch, src_len)` where padding
+ elements are indicated by ``1``.
+ need_attn (bool, optional): return attention weights
+ need_head_weights (bool, optional): return attention weights
+ for each head (default: return average over heads).
+
+ Returns:
+ encoded output of shape `(seq_len, batch, embed_dim)`
+ """
+ if need_head_weights:
+ need_attn = True
+
+ residual = x
+ if self.normalize_before:
+ x = self.self_attn_layer_norm(x)
+ if pos_bias is not None:
+ pos_bias = self.norm_k(pos_bias)
+ if prev_self_attn_state is not None:
+ prev_key, prev_value = prev_self_attn_state[:2]
+ saved_state: Dict[str, Optional[Tensor]] = {
+ "prev_key": prev_key,
+ "prev_value": prev_value,
+ }
+ if len(prev_self_attn_state) >= 3:
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
+ assert incremental_state is not None
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
+ _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
+ if self.cross_self_attention and not (
+ incremental_state is not None
+ and _self_attn_input_buffer is not None
+ and "prev_key" in _self_attn_input_buffer
+ ):
+ if self_attn_mask is not None:
+ assert encoder_out is not None
+ self_attn_mask = torch.cat(
+ (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
+ )
+ if self_attn_padding_mask is not None:
+ if encoder_padding_mask is None:
+ assert encoder_out is not None
+ encoder_padding_mask = self_attn_padding_mask.new_zeros(
+ encoder_out.size(1), encoder_out.size(0)
+ )
+ self_attn_padding_mask = torch.cat(
+ (encoder_padding_mask, self_attn_padding_mask), dim=1
+ )
+ assert encoder_out is not None
+ y = torch.cat((encoder_out, x), dim=0)
+ else:
+ y = x
+
+ x, attn = self.self_attn(
+ query=x,
+ key=y,
+ value=y,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ need_weights=False,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias,
+ )
+ if self.c_attn is not None:
+ tgt_len, bsz = x.size(0), x.size(1)
+ x = x.view(tgt_len, bsz, self.nh, self.head_dim)
+ x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
+ x = x.reshape(tgt_len, bsz, self.embed_dim)
+ if self.attn_ln is not None:
+ x = self.attn_ln(x)
+ x = self.dropout_module(x)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.self_attn_layer_norm(x)
+
+ if self.encoder_attn is not None and encoder_out is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.encoder_attn_layer_norm(x)
+ if prev_attn_state is not None:
+ prev_key, prev_value = prev_attn_state[:2]
+ saved_state: Dict[str, Optional[Tensor]] = {
+ "prev_key": prev_key,
+ "prev_value": prev_value,
+ }
+ if len(prev_attn_state) >= 3:
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
+ assert incremental_state is not None
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ need_weights=need_attn or (not self.training and self.need_attn),
+ need_head_weights=need_head_weights,
+ )
+ x = self.dropout_module(x)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.encoder_attn_layer_norm(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.final_layer_norm(x)
+
+ x = self.activation_fn(self.fc1(x))
+ x = self.activation_dropout_module(x)
+ if self.ffn_layernorm is not None:
+ x = self.ffn_layernorm(x)
+ x = self.fc2(x)
+ x = self.dropout_module(x)
+ if self.w_resid is not None:
+ residual = torch.mul(self.w_resid, residual)
+ x = self.residual_connection(x, residual)
+ if not self.normalize_before:
+ x = self.final_layer_norm(x)
+ if self.onnx_trace and incremental_state is not None:
+ saved_state = self.self_attn._get_input_buffer(incremental_state)
+ assert saved_state is not None
+ if self_attn_padding_mask is not None:
+ self_attn_state = [
+ saved_state["prev_key"],
+ saved_state["prev_value"],
+ saved_state["prev_key_padding_mask"],
+ ]
+ else:
+ self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
+ return x, attn, self_attn_state
+ return x, attn, None
+
+ def make_generation_fast_(self, need_attn: bool = False, **kwargs):
+ self.need_attn = need_attn
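The encoder layer above converts a boolean attn_mask into an additive mask, using -1e8 (or -1e4 in half precision) instead of -inf so that fully masked padded queries do not turn into NaNs after softmax. A small standalone illustration of that conversion, assuming a boolean mask where True means "masked out":

```python
import torch

def to_additive_mask(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """True = masked out. Use a large negative finite value rather than -inf."""
    neg = -1e8 if dtype == torch.float32 else -1e4
    return torch.zeros_like(bool_mask, dtype=dtype).masked_fill(bool_mask, neg)

bool_mask = torch.tensor([[False, True, True],
                          [False, False, True],
                          [False, False, False]])
add_mask = to_additive_mask(bool_mask, torch.float32)
logits = torch.zeros(3, 3) + add_mask
print(logits.softmax(dim=-1))   # masked positions get (almost exactly) zero weight
```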
diff --git a/SpeechT5/Speech2S/speech2s/modules/w2v_encoder.py b/SpeechT5/Speech2S/speech2s/modules/w2v_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..386f1eb0a4f4f67b552271e65c0b402d197e5bb2
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/modules/w2v_encoder.py
@@ -0,0 +1,281 @@
+# --------------------------------------------------------
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/facebookresearch/fairseq
+# --------------------------------------------------------
+
+"""
+    wav2vec encoder with relative position bias added, modified from
+ https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_encoder.py
+ https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
+"""
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.dataclass import ChoiceEnum
+from fairseq.modules import (
+ LayerNorm,
+ SamePad,
+)
+from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.utils import index_put
+from fairseq.distributed import fsdp_wrap
+from fairseq.models.wav2vec.utils import pad_to_multiple
+
+# load the multi-head attention variant with relative position bias
+from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
+from speechut.modules import RelativePositionalEncoding
+from speechut.modules import MultiheadAttention
+
+EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
+MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
+
+
+class TransformerEncoder(W2vTransformerEncoder):
+ def __init__(self, args):
+ super().__init__(args)
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+ self.required_seq_len_multiple = args.required_seq_len_multiple
+ self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
+
+ self.pos_conv = nn.Conv1d(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=args.conv_pos,
+ padding=args.conv_pos // 2,
+ groups=args.conv_pos_groups,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+ layers = []
+ for _ in range(args.encoder_layers):
+ layer = TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ has_relative_attention_bias=self.use_rel_pos_enc,
+ )
+ if args.checkpoint_activations:
+ layer = fsdp_wrap(layer)
+ layer = checkpoint_wrapper(layer)
+ layers.append(layer)
+ self.layers = nn.ModuleList(layers)
+
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+ if self.use_rel_pos_enc:
+ self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
+
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None, layer=None):
+ x, layer_results = self.extract_features(x, padding_mask, layer)
+
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+
+ return x, layer_results
+
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
+
+ if padding_mask is not None:
+ x = index_put(x, padding_mask, 0)
+
+ x_conv = self.pos_conv(x.transpose(1, 2))
+ x_conv = x_conv.transpose(1, 2)
+ x = x + x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ # pad to the sequence length dimension
+ x, pad_length = pad_to_multiple(
+ x, self.required_seq_len_multiple, dim=-2, value=0
+ )
+ if pad_length > 0 and padding_mask is None:
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+ padding_mask[:, -pad_length:] = True
+ else:
+ padding_mask, _ = pad_to_multiple(
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
+ )
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ if self.use_rel_pos_enc:
+ x_len = x.shape[0]
+ pos_seq = torch.arange(0, x_len).long().to(x.device)
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
+ pos_k, pos_v = self.pos_emb(pos_seq)
+ else:
+ pos_k = None
+
+ layer_results = []
+ r = None
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random()
+ if not self.training or (dropout_probability > self.layerdrop):
+ x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
+ if tgt_layer is not None:
+ # unpad if needed
+ if pad_length > 0:
+ layer_results.append(
+ (
+ x[:-pad_length],
+ z[:, :-pad_length, :-pad_length]
+ if z is not None
+ else z,
+ )
+ )
+ else:
+ layer_results.append((x, z))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+        # undo padding
+ if pad_length > 0:
+ x = x[:, :-pad_length]
+
+ return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ """
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+ models.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: float = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ has_relative_attention_bias: bool = False,
+ ) -> None:
+
+ super().__init__()
+ # Initialize parameters
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ # Initialize blocks
+ self.activation_fn = utils.get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ if has_relative_attention_bias:
+ self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ att_args=None,
+ pos_bias=None,
+ ):
+ """
+        LayerNorm is applied either before or after the self-attention/FFN
+        modules, similar to the original Transformer implementation.
+ """
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ if pos_bias is not None:
+ pos_bias = self.norm_k(pos_bias)
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias,
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ position_bias=pos_bias,
+ )
+
+ x = self.dropout1(x)
+ x = residual + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x, attn
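extract_features above applies LayerDrop: during training each transformer layer is skipped with probability layerdrop, while at inference every layer runs. A minimal sketch of that control flow with a toy module (illustrative only, not the classes above):

```python
import numpy as np
import torch
import torch.nn as nn

class TinyLayerDropStack(nn.Module):
    """Runs a stack of layers, randomly skipping each one during training."""

    def __init__(self, num_layers: int = 4, dim: int = 8, layerdrop: float = 0.5):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.layerdrop = layerdrop

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # Same test as the encoder above: keep the layer if we are evaluating
            # or if the random draw clears the drop threshold.
            if not self.training or np.random.random() > self.layerdrop:
                x = layer(x)
        return x

stack = TinyLayerDropStack()
x = torch.randn(2, 8)
stack.train(); _ = stack(x)     # some layers may be skipped
stack.eval();  _ = stack(x)     # all layers run
```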
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_asr.sh b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_asr.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d5bc7311331208c3f2f65c17586c73ee63cd98f0
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_asr.sh
@@ -0,0 +1,40 @@
+# ####################################
+# SpeechUT Base model #
+# ####################################
+[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again!" && exit 1
+DATA_DIR=$1
+TEXT_DATA_DIR=$2
+mount=$3
+world_size=$4
+update_freq=$5
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=32
+[ -z $update_freq ] && update_freq=1
+
+CODE_ROOT=${PWD}
+MODEL_DIR="${mount}/exp/pretrain/base_speechut4asr_${world_size}gpu_${update_freq}accum"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
+ --config-dir $CODE_ROOT/speechut/config/pretrain \
+ --config-name speechut_base_librispeech \
+ common.user_dir=$CODE_ROOT/speechut \
+ \
+ task.labels='["km"]' \
+ model.label_rate=50 \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.text_cfg.text_data=$TEXT_DATA_DIR \
+ \
+ dataset.train_subset=\"train_960+pseudo_libritext.kmu-ltr+merge_960.kmu-none\" \
+ dataset.valid_subset=\"dev_clean+dev.kmu-ltr+dev.kmu-none\" \
+ dataset.num_workers=0 \
+ dataset.max_tokens=1400000 \
+ distributed_training.distributed_world_size=${world_size} \
+ optimization.update_freq=[${update_freq}] \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=base_speechut4asr_${world_size}gpu_${update_freq}accum
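For reference, dataset.max_tokens in these recipes is a per-GPU budget, so the tokens contributing to one optimizer step scale with both the world size and the gradient-accumulation factor (update_freq). A tiny helper making that arithmetic explicit, under that assumption and not part of the recipe itself:

```python
def effective_tokens_per_update(max_tokens: int, world_size: int, update_freq: int) -> int:
    """Tokens contributing to one optimizer step across all GPUs and accumulation steps."""
    return max_tokens * world_size * update_freq

# Defaults from base_speechut_for_asr.sh: max_tokens=1400000, world_size=32, update_freq=1
print(effective_tokens_per_update(1_400_000, 32, 1))   # 44,800,000 tokens per update
```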
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st.sh b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st.sh
new file mode 100644
index 0000000000000000000000000000000000000000..438a43f55275938c51faefab181dacc1af3567d0
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st.sh
@@ -0,0 +1,47 @@
+# ####################################
+# SpeechUT Base model #
+# ####################################
+[ $# -lt 3 ] && echo "Usage: $0 <data_dir> <text_data_dir> <lang> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again!" && exit 1
+DATA_DIR=$1
+TEXT_DATA_DIR=$2
+lang=$3
+mount=$4
+world_size=$5
+update_freq=$6
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=32
+[ -z $update_freq ] && update_freq=1
+
+CODE_ROOT=${PWD}
+MODEL_DIR="${mount}/exp/pretrain/base_speechut4en${lang}_${world_size}gpu_${update_freq}accum"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
+ --config-dir $CODE_ROOT/speechut/config/pretrain \
+ --config-name speechut_base_librispeech \
+ common.user_dir=$CODE_ROOT/speechut \
+ \
+ task.labels='["km"]' \
+ model.label_rate=50 \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.text_cfg.text_data=$TEXT_DATA_DIR \
+ \
+ model.add_text_ctc=false \
+ model.text_transformer.share_decoder_input_output_embed=true \
+ criterion.u2t_ed_weight=1.0 \
+ criterion.u2t_ctc_weight=0 \
+ \
+ dataset.train_subset=\"train_960,mustcuns_${lang}+pseudo_wmt_en${lang}.kmu-spm+train_960.kmu-none,mustcuns_${lang}.kmu-none\" \
+ dataset.valid_subset=\"dev_clean+pseudo_valid.kmu-spm+dev.kmu-none\" \
+ dataset.num_workers=0 \
+ dataset.max_tokens=1400000 \
+ distributed_training.distributed_world_size=${world_size} \
+ optimization.update_freq=[${update_freq}] \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=base_speechut4en${lang}_${world_size}gpu_${update_freq}accum
+
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st_enfr.sh b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st_enfr.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c0c7217d0c124e603bb3b95ff11b7e7e462290c0
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st_enfr.sh
@@ -0,0 +1,48 @@
+# ####################################
+# SpeechUT Base model #
+# ####################################
+[ $# -lt 3 ] && echo "Usage: $0 <data_dir> <text_data_dir> <lang (e.g. fr)> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again!" && exit 1
+DATA_DIR=$1
+TEXT_DATA_DIR=$2
+lang=$3
+mount=$4
+world_size=$5
+update_freq=$6
+[ -z $lang ] && lang=fr
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=32
+[ -z $update_freq ] && update_freq=1
+
+CODE_ROOT=${PWD}
+MODEL_DIR="${mount}/exp/pretrain/base_speechut4en${lang}_${world_size}gpu_${update_freq}accum"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
+ --config-dir $CODE_ROOT/speechut/config/pretrain \
+ --config-name speechut_base_librispeech \
+ common.user_dir=$CODE_ROOT/speechut \
+ \
+ task.labels='["km"]' \
+ model.label_rate=50 \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.text_cfg.text_data=$TEXT_DATA_DIR \
+ \
+ model.add_text_ctc=false \
+ criterion.u2t_ed_weight=1.0 \
+ criterion.u2t_ctc_weight=0 \
+ \
+ dataset.train_subset=\"train_960,pretrain_mustc+pseudo_wmt14_enfr.kmu-spm+train_960.kmu-none,pretrain_mustc.kmu-none\" \
+ dataset.valid_subset=\"dev_clean+pseudo_valid.kmu-spm+dev.kmu-none\" \
+ dataset.num_workers=0 \
+ dataset.max_tokens=1400000 \
+ optimization.max_update=600000 \
+ distributed_training.distributed_world_size=${world_size} \
+ optimization.update_freq=[${update_freq}] \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=base_speechut4en${lang}_${world_size}gpu_${update_freq}accum
+
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/large_speechut_for_asr.sh b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/large_speechut_for_asr.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9d64d789ed0421252edd71aa9c8268a42dc42f3
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/pretrain_speechut/large_speechut_for_asr.sh
@@ -0,0 +1,41 @@
+# ####################################
+# SpeechUT Large model #
+# ####################################
+[ $# -lt 2 ] && echo "Usage: $0 [mount=${PWD}] [world_size=32] [update_freq=4]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+DATA_DIR=$1
+TEXT_DATA_DIR=$2
+mount=$3
+world_size=$4
+update_freq=$5
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=32
+[ -z $update_freq ] && update_freq=4
+
+CODE_ROOT=${PWD}
+MODEL_DIR="${mount}/exp/pretrain/large_speechut4asr_${world_size}gpu_${update_freq}accum"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
+ --config-dir $CODE_ROOT/speechut/config/pretrain \
+ --config-name speechut_large_librilight \
+ common.user_dir=$CODE_ROOT/speechut \
+ \
+ task.labels='["km"]' \
+ model.label_rate=50 \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.text_cfg.text_data=$TEXT_DATA_DIR \
+ \
+ dataset.train_subset=\"train_small+pseudo_libritext.kmu-ltr\" \
+ dataset.valid_subset=\"dev_clean+dev.kmu-ltr\" \
+ dataset.num_workers=0 \
+ dataset.max_tokens=900000 \
+ distributed_training.distributed_world_size=${world_size} \
+ optimization.update_freq=[${update_freq}] \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=large_speechut4asr_${world_size}gpu_${update_freq}accum
+
\ No newline at end of file
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/finetune960h_large_edctc.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/finetune960h_large_edctc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..08a25818bc9fc519e65fa175886545a8650c0906
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/finetune960h_large_edctc.sh
@@ -0,0 +1,45 @@
+# ####################################
+# SpeechUT Large model #
+# ####################################
+[ $# -lt 3 ] && echo "Usage: $0 [mount=${PWD}] [world_size=8] [update_freq=3]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+w2v_path=$1
+DATA_DIR=$2
+cpt=$3
+mount=$4
+world_size=$5
+update_freq=$6
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=3
+
+CODE_ROOT=${PWD}
+
+exp_name=${w2v_path%/*}
+exp_name=${exp_name##*/}
+MODEL_DIR="${mount}/exp/finetune_asr/$exp_name/960h_edctc80k_from_${cpt}_bz3.3m_lr1e-5"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
+ --config-dir $CODE_ROOT/speechut/config/finetune_asr \
+ --config-name speechut_large_960h \
+ common.user_dir=$CODE_ROOT/speechut \
+ \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ model.w2v_path=${w2v_path} \
+ \
+ optimization.lr=[0.00001] \
+ optimization.max_update=80000 \
+ dataset.max_tokens=1100000 \
+ optimization.update_freq=[${update_freq}] \
+ distributed_training.distributed_world_size=${world_size} \
+ \
+ dataset.train_subset="train_960" \
+ dataset.valid_subset="dev_other" \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=960h_edctc80k_from_${cpt}_bz3.3m_lr1e-5
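A sketch of how the 960h fine-tuning script above could be launched, assuming a pretrained Large checkpoint and a LibriSpeech ASR manifest directory; the paths and the "400k" tag (which only labels the output directory, it does not select the checkpoint) are placeholders:

    # placeholder paths; the third argument is just a tag used in the output dir name
    cd /path/to/SpeechUT
    bash /path/to/finetune960h_large_edctc.sh \
        /path/to/exp/pretrain/large_speechut4asr_32gpu_4accum/checkpoint_last.pt \
        /path/to/librispeech/asr_manifest \
        400k ${PWD} 8 3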
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/finetune_base_edctc.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/finetune_base_edctc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cad7bd0a11336a2b5e0c34372d57b7b4b953a414
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/finetune_base_edctc.sh
@@ -0,0 +1,45 @@
+# ####################################
+# SpeechUT Base model #
+# ####################################
+[ $# -lt 3 ] && echo "Usage: $0 [mount=${PWD}] [world_size=8] [update_freq=2]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+w2v_path=$1
+DATA_DIR=$2
+cpt=$3
+mount=$4
+world_size=$5
+update_freq=$6
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=2
+
+CODE_ROOT=${PWD}
+
+exp_name=${w2v_path%/*}
+exp_name=${exp_name##*/}
+MODEL_DIR="${mount}/exp/finetune_asr/$exp_name/edctc40k_from_${cpt}_bz2.6m_lr1e-5"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
+ --config-dir $CODE_ROOT/speechut/config/finetune_asr \
+ --config-name speechut_base_100h \
+ common.user_dir=$CODE_ROOT/speechut \
+ \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ model.w2v_path=${w2v_path} \
+ \
+ optimization.lr=[0.00001] \
+ optimization.max_update=40000 \
+ dataset.max_tokens=1300000 \
+ optimization.update_freq=[${update_freq}] \
+ distributed_training.distributed_world_size=${world_size} \
+ \
+ dataset.train_subset="train_clean_100" \
+ dataset.valid_subset="dev_other" \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=edctc40k_from_${cpt}_bz2.6m_lr1e-5
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_edctc.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_edctc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9dce06398c476a26290839b7f3a8f8632a5060e0
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_edctc.sh
@@ -0,0 +1,61 @@
+#####################################
+# SpeechUT ASR model #
+#####################################
+[ $# -lt 2 ] && echo "Usage: $0 [gen-set=dev_other] [beam_size=10] [ctc_weight=0.2] [--normalize]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+model_path=$1
+DATA_DIR=$2
+gen_set=$3
+beam_size=$4
+ctc_weight=$5
+extra=$6
+[ -z $extra ] && echo "Assert decoding base model! If you are decoding large model, please add '--normalize' at the end..."
+[ -z $gen_set ] && gen_set="dev_other"
+[ -z $beam_size ] && beam_size=10
+[ -z $ctc_weight ] && ctc_weight=0.2
+[ $ctc_weight == 0 ] && [ $beam_size != 1 ] && echo "Change beam size to 1 as no ctc-decoding used..." && beam_size=1
+[ $ctc_weight != 0 ] && extra="$extra --batch-size 1"
+
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+CODE_ROOT=${PWD}
+
+for subset in ${gen_set//,/ }; do
+ results_path=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}/${subset}_${world_size}_${rank}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $CODE_ROOT/fairseq/fairseq_cli/generate.py $DATA_DIR \
+ --user-dir $CODE_ROOT/speechut \
+ --label-dir ${DATA_DIR} \
+ --labels '["ltr"]' \
+ --single-target \
+ --post-process letter \
+ --gen-subset ${subset} \
+ --max-tokens 2000000 \
+ \
+ --task joint_sc2t_pretraining \
+ --add-decoder-target \
+ --fine-tuning \
+ --pad-audio \
+ --random-crop \
+ \
+ --ctc-weight ${ctc_weight} $extra \
+ --beam ${beam_size} \
+ \
+ --path ${model_path} \
+ --results-path $results_path \
+ \
+ --scoring wer --max-len-a 0.00078125 --max-len-b 200 \
+ &
+done
+wait
+
+
+for subset in ${gen_set//,/ }; do
+ results_path=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}/${subset}_${world_size}_${rank}
+ echo $results_path
+ tail -n 1 $results_path/generate-*.txt
+done
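Note that world_size and rank are never set in this single-process script, so the per-subset results directory name ends in "__"; the _nj variants below set both explicitly. A hedged example invocation with placeholder paths (append --normalize as a sixth argument when decoding a Large model, as the script's own hint says):

    # placeholder paths; decodes dev_other and test_other with joint CTC/attention scoring
    cd /path/to/SpeechUT
    bash /path/to/inference_edctc.sh \
        /path/to/finetuned_asr/checkpoint_best.pt \
        /path/to/librispeech/asr_manifest \
        dev_other,test_other 10 0.2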
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_edctclm.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_edctclm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dadd1a4286de52cef0250640ef64fd4117e11ecb
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_edctclm.sh
@@ -0,0 +1,66 @@
+#####################################
+# SpeechUT ASR model #
+#####################################
+[ $# -lt 2 ] && echo "Usage: $0 [gen-set=dev_other] [beam_size=30] [ctc_weight=0.3] [lm_weight=0.7] [lm_path] [--normalize]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+model_path=$1
+DATA_DIR=$2
+gen_set=$3
+beam_size=$4
+ctc_weight=$5
+lm_weight=$6
+lm_path=$7
+extra=$8
+[ -z $extra ] && echo "Assert decoding base model! If you are decoding large model, please add '--normalize' at the end..."
+[ -z $gen_set ] && gen_set="dev_other"
+[ -z $beam_size ] && beam_size=30
+[ -z $ctc_weight ] && ctc_weight=0.3
+[ -z $lm_weight ] && lm_weight=0.7
+[ -z $lm_path ] && lm_path="/mnt/default/v-junyiao/librispeech/lm/lm_ctc_form/checkpoint_best.pt"
+[ $ctc_weight == 0 ] && [ $beam_size != 1 ] && echo "Change beam size to 1 and lm_weight to 0 as no ctc-decoding used..." && beam_size=1 && lm_weight=0
+[ $ctc_weight != 0 ] && extra="$extra --batch-size 1"
+
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+CODE_ROOT=${PWD}
+
+for subset in ${gen_set//,/ }; do
+ results_path=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}_lm${lm_weight}/${subset}_${world_size}_${rank}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $CODE_ROOT/fairseq/fairseq_cli/generate.py $DATA_DIR \
+ --user-dir $CODE_ROOT/speechut \
+ --label-dir ${DATA_DIR} \
+ --labels '["ltr"]' \
+ --single-target \
+ --post-process letter \
+ --gen-subset ${subset} \
+ --max-tokens 800000 \
+ \
+ --task joint_sc2t_pretraining \
+ --add-decoder-target \
+ --fine-tuning \
+ --pad-audio \
+ --random-crop \
+ \
+ --ctc-weight ${ctc_weight} $extra \
+ --lm-weight ${lm_weight} --lm-path ${lm_path} \
+ --beam ${beam_size} \
+ \
+ --path ${model_path} \
+ --results-path ${results_path} \
+ \
+ --scoring wer --max-len-a 0.00078125 --max-len-b 200 \
+ &
+done
+wait
+
+
+for subset in ${gen_set//,/ }; do
+ results_path=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}_lm${lm_weight}/${subset}_${world_size}_${rank}
+ echo $results_path
+ tail -n 1 $results_path/generate-*.txt
+done
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_lm_nj.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_lm_nj.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a5627a59975a01736907a5cc3fb76df335709b43
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_lm_nj.sh
@@ -0,0 +1,74 @@
+#####################################
+# SpeechUT ASR model #
+#####################################
+[ $# -lt 2 ] && echo "Usage: $0 [gen-set=dev_other] [beam_size=30] [ctc_weight=0.3] [lm_weight=0.7] [lm_path] [nj=8] [ngpu=8] [--normalize]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+model_path=$1
+DATA_DIR=$2
+gen_set=$3
+beam_size=$4
+ctc_weight=$5
+lm_weight=$6
+lm_path=$7
+nj=$8
+ngpu=$9
+extra=${10}
+[ -z $extra ] && echo "Assert decoding base model! If you are decoding large model, please add '--normalize' at the end..."
+[ -z $gen_set ] && gen_set="dev_other"
+[ -z $beam_size ] && beam_size=30
+[ -z $ctc_weight ] && ctc_weight=0.3
+[ -z $lm_weight ] && lm_weight=0.7
+[ -z $lm_path ] && lm_path="/mnt/default/v-junyiao/librispeech/lm/lm_ctc_form/checkpoint_best.pt"
+[ $ctc_weight == 0 ] && [ $beam_size != 1 ] && echo "Change beam size to 1 and lm_weight to 0 as no ctc-decoding used..." && beam_size=1 && lm_weight=0
+[ $ctc_weight != 0 ] && extra="$extra --batch-size 1"
+[ -z $nj ] && nj=8
+[ -z $ngpu ] && ngpu=8
+
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+CODE_ROOT=${PWD}
+
+world_size=$nj
+for rank in $(seq 0 $((nj - 1))); do
+ export CUDA_VISIBLE_DEVICES=$((rank % $ngpu))
+ for subset in ${gen_set//,/ }; do
+ results_path=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}_lm${lm_weight}/${subset}_${world_size}_${rank}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $CODE_ROOT/fairseq/fairseq_cli/generate.py $DATA_DIR \
+ --user-dir $CODE_ROOT/speechut \
+ --label-dir ${DATA_DIR} \
+ --labels '["ltr"]' \
+ --single-target \
+ --post-process letter \
+ --gen-subset ${subset} \
+ --max-tokens 800000 \
+ \
+ --task joint_sc2t_pretraining \
+ --add-decoder-target \
+ --fine-tuning \
+ --pad-audio \
+ --random-crop \
+ \
+ --ctc-weight ${ctc_weight} $extra \
+ --lm-weight ${lm_weight} --lm-path ${lm_path} \
+ --beam ${beam_size} \
+ \
+ --path ${model_path} \
+ --results-path $results_path \
+ \
+ --scoring wer --max-len-a 0.00078125 --max-len-b 200 \
+ --distributed-world-size ${world_size} --distributed-rank ${rank} \
+ &
+ done
+done
+wait
+
+
+for subset in ${gen_set//,/ }; do
+ results_dir=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}_lm${lm_weight}
+ cat $results_dir/${subset}_${world_size}_*/generate-${subset}.txt | grep -v "^Generate" > $results_dir/generate-${subset}.all.txt
+done
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_nj.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_nj.sh
new file mode 100644
index 0000000000000000000000000000000000000000..08e6df431c9856f24122118017b8ae85bacc5444
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_asr/inference_nj.sh
@@ -0,0 +1,69 @@
+#####################################
+# SpeechUT ASR model #
+#####################################
+[ $# -lt 2 ] && echo "Usage: $0 [gen-set=dev_other] [beam_size=10] [ctc_weight=0.2] [nj=32] [ngpu=8] [--normalize]" && exit 1
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+model_path=$1
+DATA_DIR=$2
+gen_set=$3
+beam_size=$4
+ctc_weight=$5
+nj=$6
+ngpu=$7
+extra=$8
+[ -z $extra ] && echo "Assert decoding base model! If you are decoding large model, please add '--normalize' at the end..."
+[ -z $gen_set ] && gen_set="dev_other"
+[ -z $beam_size ] && beam_size=10
+[ -z $ctc_weight ] && ctc_weight=0.2
+[ $ctc_weight == 0 ] && [ $beam_size != 1 ] && echo "Change beam size to 1 as no ctc-decoding used..." && beam_size=1
+[ $ctc_weight != 0 ] && extra="$extra --batch-size 1"
+[ -z $nj ] && nj=32
+[ -z $ngpu ] && ngpu=8
+
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+CODE_ROOT=${PWD}
+
+world_size=$nj
+for rank in $(seq 0 $((nj - 1))); do
+ export CUDA_VISIBLE_DEVICES=$((rank % $ngpu))
+ for subset in ${gen_set//,/ }; do
+ results_path=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}/${subset}_${world_size}_${rank}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $CODE_ROOT/fairseq/fairseq_cli/generate.py $DATA_DIR \
+ --user-dir $CODE_ROOT/speechut \
+ --label-dir ${DATA_DIR} \
+ --labels '["ltr"]' \
+ --single-target \
+ --post-process letter \
+ --gen-subset ${subset} \
+ --max-tokens 2000000 \
+ \
+ --task joint_sc2t_pretraining \
+ --add-decoder-target \
+ --fine-tuning \
+ --pad-audio \
+ --random-crop \
+ \
+ --ctc-weight ${ctc_weight} $extra \
+ --beam ${beam_size} \
+ \
+ --path ${model_path} \
+ --results-path $results_path \
+ \
+ --scoring wer --max-len-a 0.00078125 --max-len-b 200 \
+ --distributed-world-size ${world_size} --distributed-rank ${rank} \
+ &
+ done
+done
+wait
+
+
+for subset in ${gen_set//,/ }; do
+ results_dir=$src_dir/decode_${cpt}/beam${beam_size}_ctc${ctc_weight}
+ cat $results_dir/${subset}_${world_size}_*/generate-${subset}.txt | grep -v "^Generate" > $results_dir/generate-${subset}.all.txt
+done
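This variant shards decoding across nj background jobs, pinning job rank to GPU rank % ngpu via CUDA_VISIBLE_DEVICES, then concatenates the per-shard outputs into generate-<subset>.all.txt. A sketch with placeholder paths:

    # placeholder paths; 32 shards spread over 8 GPUs
    cd /path/to/SpeechUT
    bash /path/to/inference_nj.sh \
        /path/to/finetuned_asr/checkpoint_best.pt \
        /path/to/librispeech/asr_manifest \
        test_other 10 0.2 32 8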
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_st/finetune_base_mustc_enxx.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_st/finetune_base_mustc_enxx.sh
new file mode 100644
index 0000000000000000000000000000000000000000..59c8a2a0346b708894b1568fa691c062537aa559
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_st/finetune_base_mustc_enxx.sh
@@ -0,0 +1,77 @@
+# ####################################
+# SpeechUT Base model #
+# ####################################
+[ $# -lt 4 ] && echo "Usage: $0 [mount=${PWD}] [world_size=8] [update_freq=4/6]" && exit 0
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+w2v_path=$1
+DATA_DIR=$2
+lang=$3
+cpt=$4
+mount=$5
+world_size=$6
+update_freq=$7
+[ -z $mount ] && mount=${PWD}
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=4
+
+CODE_ROOT=${PWD}
+
+exp_name=${w2v_path%/*}
+exp_name=${exp_name##*/}
+MODEL_DIR="$mount/exp/finetune_mustc/$exp_name/legacy_en${lang}_from_${cpt}_bz3.2m_lr3e-5"
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+max_tokens=800000
+python $CODE_ROOT/fairseq/fairseq_cli/train.py ${DATA_DIR} \
+ --save-dir ${MODEL_DIR} \
+ --user-dir $CODE_ROOT/speechut \
+ --task speech_to_text \
+ --config-yaml config_en${lang}.yaml \
+ --train-subset "train_st" \
+ --valid-subset "dev_st" \
+ --fp16 \
+ --seed 1 \
+ \
+ --ddp-backend no_c10d \
+ --distributed-world-size ${world_size} \
+ --tensorboard-logdir ${MODEL_DIR} \
+ \
+ --criterion label_smoothed_cross_entropy --report-accuracy \
+ --label-smoothing 0.3 \
+ \
+ --optimizer adam \
+ --clip-norm 1.0 \
+ --lr 3e-05 \
+ --lr-scheduler polynomial_decay --warmup-updates 5000 \
+ --max-update 50000 \
+ --total-num-update 50000 \
+ --update-freq ${update_freq} \
+ \
+ --max-tokens ${max_tokens} \
+ --max-sentences 16 \
+ --max-tokens-valid ${max_tokens} \
+ --grouped-shuffling \
+ --max-source-positions ${max_tokens} \
+ --skip-invalid-size-inputs-valid-test \
+ --num-workers 0 \
+ --best-checkpoint-metric "accuracy" \
+ --maximize-best-checkpoint-metric \
+ \
+ --arch "speechut_st_legacy" \
+ --w2v-path ${w2v_path} \
+ --layerdrop 0.1 \
+ --activation-dropout 0.1 \
+ --attention-dropout 0.1 \
+ --feature-grad-mult 1.0 \
+ \
+ --apply-mask --mask-prob 0.5 \
+ \
+ --log-format json \
+ --log-interval 100 \
+ --save-interval 1 \
+ --keep-last-epochs 5 \
+ --keep-best-checkpoints 5 \
+ \
+ 2>&1 | tee ${MODEL_DIR}/train_en${lang}.log
+
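A sketch of launching the MuST-C en-xx fine-tuning above, assuming a pretrained Base checkpoint and a MuST-C manifest directory that contains the config_en<lang>.yaml referenced by the script; all paths and the "400k" tag are placeholders:

    # placeholder paths; fine-tunes on en-de MuST-C with the legacy ST arch
    cd /path/to/SpeechUT
    bash /path/to/finetune_base_mustc_enxx.sh \
        /path/to/exp/pretrain/base_speechut4ende_32gpu_1accum/checkpoint_last.pt \
        /path/to/mustc/en-de \
        de 400k ${PWD} 8 4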
diff --git a/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_st/inference_st.sh b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_st/inference_st.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3aefa10e360f57dbf66cff9d84c800b4da89619f
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts copy/tune_speechut_st/inference_st.sh
@@ -0,0 +1,44 @@
+# ####################################
+# SpeechUT Base model #
+# ####################################
+[ $# -lt 3 ] && echo "Usage: $0 [gen-set=dev] [beam_size=10] [lenpen=1.0]" && exit 0
+[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1
+
+model_path=$1
+DATA_DIR=$2
+lang=$3
+gen_set=$4
+beam_size=$5
+lenpen=$6
+[ -z $gen_set ] && gen_set="dev"
+[ -z $beam_size ] && beam_size=10
+[ -z $lenpen ] && lenpen=1
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+CODE_ROOT=${PWD}
+results_path=$src_dir/decode_${cpt}_beam${beam_size}/${gen_set}
+[ ! -d $results_path ] && mkdir -p $results_path
+
+python $CODE_ROOT/fairseq/fairseq_cli/generate.py $DATA_DIR \
+ --gen-subset ${gen_set}_st \
+ --max-tokens 2000000 \
+ --max-source-positions 2000000 \
+ --num-workers 0 \
+ \
+ --user-dir $CODE_ROOT/speechut \
+ --task speech_to_text \
+ --config-yaml config_en${lang}.yaml \
+ \
+ --path ${model_path} \
+ --results-path $results_path \
+ \
+ --scoring sacrebleu --max-len-a 0 --max-len-b 512 \
+ --beam ${beam_size} \
+ --lenpen $lenpen \
+ # --model-overrides "{'model':{'w2v_path':'/path/to/your/pretrained/model.pt'}}" \
+
+ echo $results_path
+ tail -n 1 $results_path/generate-*.txt
+ sleep 1s
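A hedged example of decoding a fine-tuned ST model with the script above; the dev subset, beam, and length penalty mirror the defaults, and the commented --model-overrides line in the script shows how to repoint w2v_path if the pretrained backbone has moved. Paths are placeholders:

    # placeholder paths; output is scored with sacreBLEU via --scoring sacrebleu
    cd /path/to/SpeechUT
    bash /path/to/inference_st.sh \
        /path/to/finetune_mustc/checkpoint_avg.pt \
        /path/to/mustc/en-de \
        de dev 10 1.0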
diff --git a/SpeechT5/Speech2S/speech2s/scripts/__init__.py b/SpeechT5/Speech2S/speech2s/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SpeechT5/Speech2S/speech2s/scripts/average_checkpoints.py b/SpeechT5/Speech2S/speech2s/scripts/average_checkpoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4711e4840a45118c9e28d0258f89fe64e964cf3
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/average_checkpoints.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import collections
+import os
+import re
+
+import torch
+from fairseq.file_io import PathManager
+
+
+def average_checkpoints(inputs):
+ """Loads checkpoints from inputs and returns a model with averaged weights.
+
+ Args:
+ inputs: An iterable of string paths of checkpoints to load from.
+
+ Returns:
+ A dict of string keys mapping to various values. The 'model' key
+ from the returned dict should correspond to an OrderedDict mapping
+ string parameter names to torch Tensors.
+ """
+ params_dict = collections.OrderedDict()
+ params_keys = None
+ new_state = None
+ num_models = len(inputs)
+
+ for fpath in inputs:
+ with PathManager.open(fpath, "rb") as f:
+ state = torch.load(
+ f,
+ map_location=(
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
+ ),
+ )
+ # Copies over the settings from the first checkpoint
+ if new_state is None:
+ new_state = state
+
+ model_params = state["model"]
+
+ model_params_keys = list(model_params.keys())
+ if params_keys is None:
+ params_keys = model_params_keys
+ elif params_keys != model_params_keys:
+ raise KeyError(
+ "For checkpoint {}, expected list of params: {}, "
+ "but found: {}".format(f, params_keys, model_params_keys)
+ )
+
+ for k in params_keys:
+ p = model_params[k]
+ if isinstance(p, torch.HalfTensor):
+ p = p.float()
+ if k not in params_dict:
+ params_dict[k] = p.clone()
+ # NOTE: clone() is needed in case of p is a shared parameter
+ else:
+ params_dict[k] += p
+
+ averaged_params = collections.OrderedDict()
+ for k, v in params_dict.items():
+ averaged_params[k] = v
+ if averaged_params[k].is_floating_point():
+ averaged_params[k].div_(num_models)
+ else:
+ averaged_params[k] //= num_models
+ new_state["model"] = averaged_params
+ return new_state
+
+
+def last_n_checkpoints(paths, n, update_based, upper_bound=None):
+ assert len(paths) == 1
+ path = paths[0]
+ if update_based:
+ pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
+ else:
+ pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
+ files = PathManager.ls(path)
+
+ entries = []
+ for f in files:
+ m = pt_regexp.fullmatch(f)
+ if m is not None:
+ sort_key = int(m.group(1))
+ if upper_bound is None or sort_key <= upper_bound:
+ entries.append((sort_key, m.group(0)))
+ if len(entries) < n:
+ raise Exception(
+ "Found {} checkpoint files but need at least {}", len(entries), n
+ )
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Tool to average the params of input checkpoints to "
+ "produce a new checkpoint",
+ )
+ # fmt: off
+ parser.add_argument('--inputs', required=True, nargs='+',
+ help='Input checkpoint file paths.')
+ parser.add_argument('--output', required=True, metavar='FILE',
+ help='Write the new checkpoint containing the averaged weights to this path.')
+ num_group = parser.add_mutually_exclusive_group()
+ num_group.add_argument('--num-epoch-checkpoints', type=int,
+ help='if set, will try to find checkpoints with names checkpoint_xx.pt in the '
+ 'path specified by input, and average last this many of them.')
+ num_group.add_argument('--num-update-checkpoints', type=int,
+ help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by'
+ ' input, and average last this many of them.')
+ parser.add_argument('--checkpoint-upper-bound', type=int,
+ help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
+ 'when using --num-update-checkpoints, this will set an upper bound on which update to use'
+ 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be'
+ ' averaged.'
+ 'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would'
+ ' be averaged assuming --save-interval-updates 500'
+ )
+ # fmt: on
+ args = parser.parse_args()
+ print(args)
+
+ num = None
+ is_update_based = False
+ if args.num_update_checkpoints is not None:
+ num = args.num_update_checkpoints
+ is_update_based = True
+ elif args.num_epoch_checkpoints is not None:
+ num = args.num_epoch_checkpoints
+
+ assert args.checkpoint_upper_bound is None or (
+ args.num_epoch_checkpoints is not None
+ or args.num_update_checkpoints is not None
+ ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
+ assert (
+ args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
+ ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"
+
+ if num is not None:
+ args.inputs = last_n_checkpoints(
+ args.inputs,
+ num,
+ is_update_based,
+ upper_bound=args.checkpoint_upper_bound,
+ )
+ print("averaging checkpoints: ", args.inputs)
+
+ new_state = average_checkpoints(args.inputs)
+ with PathManager.open(args.output, "wb") as f:
+ torch.save(new_state, f)
+ print("Finished writing averaged checkpoint to {}".format(args.output))
+
+
+if __name__ == "__main__":
+ main()
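For example, to average the last 10 update-based checkpoints of a run before decoding (a common recipe; the experiment path below is a placeholder and the script is invoked here relative to the speech2s directory added above), one could pass a single directory as --inputs, which the script scans for checkpoint_*_*.pt files:

    # placeholder experiment directory
    python scripts/average_checkpoints.py \
        --inputs /path/to/exp/finetune_asr/my_run \
        --num-update-checkpoints 10 \
        --output /path/to/exp/finetune_asr/my_run/checkpoint_avg10.pt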
diff --git a/SpeechT5/Speech2S/speech2s/scripts/build_sym_alignment.py b/SpeechT5/Speech2S/speech2s/scripts/build_sym_alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ca5c18f7bd4b0fbf58b203793506ca395466129
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/build_sym_alignment.py
@@ -0,0 +1,97 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Use this script in order to build symmetric alignments for your translation
+dataset.
+This script depends on fast_align and mosesdecoder tools. You will need to
+build those before running the script.
+fast_align:
+ github: http://github.com/clab/fast_align
+ instructions: follow the instructions in README.md
+mosesdecoder:
+ github: http://github.com/moses-smt/mosesdecoder
+ instructions: http://www.statmt.org/moses/?n=Development.GetStarted
+The script produces the following files under --output_dir:
+ text.joined - concatenation of lines from the source_file and the
+ target_file.
+ align.forward - forward pass of fast_align.
+ align.backward - backward pass of fast_align.
+ aligned.sym_heuristic - symmetrized alignment.
+"""
+
+import argparse
+import os
+from itertools import zip_longest
+
+
+def main():
+ parser = argparse.ArgumentParser(description="symmetric alignment builer")
+ # fmt: off
+ parser.add_argument('--fast_align_dir',
+ help='path to fast_align build directory')
+ parser.add_argument('--mosesdecoder_dir',
+ help='path to mosesdecoder root directory')
+ parser.add_argument('--sym_heuristic',
+ help='heuristic to use for symmetrization',
+ default='grow-diag-final-and')
+ parser.add_argument('--source_file',
+ help='path to a file with sentences '
+ 'in the source language')
+ parser.add_argument('--target_file',
+ help='path to a file with sentences '
+ 'in the target language')
+ parser.add_argument('--output_dir',
+ help='output directory')
+ # fmt: on
+ args = parser.parse_args()
+
+ fast_align_bin = os.path.join(args.fast_align_dir, "fast_align")
+ symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal")
+ sym_fast_align_bin = os.path.join(
+ args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl"
+ )
+
+ # create joined file
+ joined_file = os.path.join(args.output_dir, "text.joined")
+ with open(args.source_file, "r", encoding="utf-8") as src, open(
+ args.target_file, "r", encoding="utf-8"
+ ) as tgt:
+ with open(joined_file, "w", encoding="utf-8") as joined:
+ for s, t in zip_longest(src, tgt):
+ print("{} ||| {}".format(s.strip(), t.strip()), file=joined)
+
+ bwd_align_file = os.path.join(args.output_dir, "align.backward")
+
+ # run forward alignment
+ fwd_align_file = os.path.join(args.output_dir, "align.forward")
+ fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format(
+ FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file
+ )
+ assert os.system(fwd_fast_align_cmd) == 0
+
+ # run backward alignment
+ bwd_align_file = os.path.join(args.output_dir, "align.backward")
+ bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format(
+ FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file
+ )
+ assert os.system(bwd_fast_align_cmd) == 0
+
+ # run symmetrization
+ sym_out_file = os.path.join(args.output_dir, "aligned")
+ sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format(
+ SYMFASTALIGN=sym_fast_align_bin,
+ FWD=fwd_align_file,
+ BWD=bwd_align_file,
+ SRC=args.source_file,
+ TGT=args.target_file,
+ OUT=sym_out_file,
+ HEURISTIC=args.sym_heuristic,
+ SYMAL=symal_bin,
+ )
+ assert os.system(sym_cmd) == 0
+
+
+if __name__ == "__main__":
+ main()
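A sketch of building symmetrized alignments for a parallel corpus, assuming fast_align and mosesdecoder have already been built at the placeholder locations below:

    # placeholder tool and corpus paths
    python scripts/build_sym_alignment.py \
        --fast_align_dir /path/to/fast_align/build \
        --mosesdecoder_dir /path/to/mosesdecoder \
        --source_file corpus/train.en \
        --target_file corpus/train.de \
        --output_dir alignments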
diff --git a/SpeechT5/Speech2S/speech2s/scripts/compare_namespaces.py b/SpeechT5/Speech2S/speech2s/scripts/compare_namespaces.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc24db624f8db36f546c263ba3a806dae6d466bf
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/compare_namespaces.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+"""Helper script to compare two argparse.Namespace objects."""
+
+from argparse import Namespace # noqa
+
+
+def main():
+
+ ns1 = eval(input("Namespace 1: "))
+ ns2 = eval(input("Namespace 2: "))
+
+ def keys(ns):
+ ks = set()
+ for k in dir(ns):
+ if not k.startswith("_"):
+ ks.add(k)
+ return ks
+
+ k1 = keys(ns1)
+ k2 = keys(ns2)
+
+ def print_keys(ks, ns1, ns2=None):
+ for k in ks:
+ if ns2 is None:
+ print("{}\t{}".format(k, getattr(ns1, k, None)))
+ else:
+ print(
+ "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None))
+ )
+
+ print("Keys unique to namespace 1:")
+ print_keys(k1 - k2, ns1)
+ print()
+
+ print("Keys unique to namespace 2:")
+ print_keys(k2 - k1, ns2)
+ print()
+
+ print("Overlapping keys with different values:")
+ ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")]
+ print_keys(ks, ns1, ns2)
+ print()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/compound_split_bleu.sh b/SpeechT5/Speech2S/speech2s/scripts/compound_split_bleu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1972fddcebff9a43a70bcf14c287175c68f60e3f
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/compound_split_bleu.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+if [ $# -ne 1 ]; then
+ echo "usage: $0 GENERATE_PY_OUTPUT"
+ exit 1
+fi
+
+GEN=$1
+
+SYS=$GEN.sys
+REF=$GEN.ref
+
+if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
+ echo "not done generating"
+ exit
+fi
+
+grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
+grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
+fairseq-score --sys $SYS --ref $REF
diff --git a/SpeechT5/Speech2S/speech2s/scripts/constraints/extract.py b/SpeechT5/Speech2S/speech2s/scripts/constraints/extract.py
new file mode 100644
index 0000000000000000000000000000000000000000..437b373856966e568ca93c13ebbd1417291e49da
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/constraints/extract.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Extracts random constraints from reference files."""
+
+import argparse
+import random
+import sys
+
+
+def get_phrase(words, index, length):
+ assert index < len(words) - length + 1
+ phr = " ".join(words[index : index + length])
+ for i in range(index, index + length):
+ words.pop(index)
+ return phr
+
+
+def main(args):
+
+ if args.seed:
+ random.seed(args.seed)
+
+ for line in sys.stdin:
+ constraints = []
+
+ def add_constraint(constraint):
+ constraints.append(constraint)
+
+ source = line.rstrip()
+ if "\t" in line:
+ source, target = line.split("\t")
+ if args.add_sos:
+ target = f" {target}"
+ if args.add_eos:
+ target = f"{target} "
+
+ if len(target.split()) >= args.len:
+ words = [target]
+
+ num = args.number
+
+ choices = {}
+ for i in range(num):
+ if len(words) == 0:
+ break
+ segmentno = random.choice(range(len(words)))
+ segment = words.pop(segmentno)
+ tokens = segment.split()
+ phrase_index = random.choice(range(len(tokens)))
+ choice = " ".join(
+ tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
+ )
+ for j in range(
+ phrase_index, min(len(tokens), phrase_index + args.len)
+ ):
+ tokens.pop(phrase_index)
+ if phrase_index > 0:
+ words.append(" ".join(tokens[0:phrase_index]))
+ if phrase_index + 1 < len(tokens):
+ words.append(" ".join(tokens[phrase_index:]))
+ choices[target.find(choice)] = choice
+
+ # mask out with spaces
+ target = target.replace(choice, " " * len(choice), 1)
+
+ for key in sorted(choices.keys()):
+ add_constraint(choices[key])
+
+ print(source, *constraints, sep="\t")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
+ parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
+ parser.add_argument(
+ "--add-sos", default=False, action="store_true", help="add token"
+ )
+ parser.add_argument(
+ "--add-eos", default=False, action="store_true", help="add token"
+ )
+ parser.add_argument("--seed", "-s", default=0, type=int)
+ args = parser.parse_args()
+
+ main(args)
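The extractor reads tab-separated source/target pairs from stdin and appends the sampled constraint phrases as extra tab-separated fields. A hedged usage sketch; the corpus file names are placeholders:

    # placeholder corpus files; sample two constraints of length 2 per sentence
    paste test.en test.de \
        | python scripts/constraints/extract.py --number 2 --len 2 --seed 1 \
        > test.en-de.constraints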
diff --git a/SpeechT5/Speech2S/speech2s/scripts/constraints/validate.py b/SpeechT5/Speech2S/speech2s/scripts/constraints/validate.py
new file mode 100644
index 0000000000000000000000000000000000000000..d531ad9f39b1df42c98fe8f26ad61fe53a9ac0c5
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/constraints/validate.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+
+"""Reads in a fairseq output file, and verifies that the constraints
+(C- lines) are present in the output (the first H- line). Assumes that
+constraints are listed prior to the first hypothesis.
+"""
+
+constraints = []
+found = 0
+total = 0
+for line in sys.stdin:
+ if line.startswith("C-"):
+ constraints.append(line.rstrip().split("\t")[1])
+ elif line.startswith("H-"):
+ text = line.split("\t")[2]
+
+ for constraint in constraints:
+ total += 1
+ if constraint in text:
+ found += 1
+ else:
+ print(f"No {constraint} in {text}", file=sys.stderr)
+
+ constraints = []
+
+print(f"Found {found} / {total} = {100 * found / total:.1f}%")
diff --git a/SpeechT5/Speech2S/speech2s/scripts/convert_dictionary.lua b/SpeechT5/Speech2S/speech2s/scripts/convert_dictionary.lua
new file mode 100644
index 0000000000000000000000000000000000000000..14ee8c997f642c8ff196617c2dcd0584037a60c4
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/convert_dictionary.lua
@@ -0,0 +1,34 @@
+-- Copyright (c) Facebook, Inc. and its affiliates.
+--
+-- This source code is licensed under the MIT license found in the
+-- LICENSE file in the root directory of this source tree.
+--
+-- Usage: convert_dictionary.lua <dict.th7>
+require 'fairseq'
+require 'torch'
+require 'paths'
+
+if #arg < 1 then
+  print('usage: convert_dictionary.lua <dict.th7>')
+ os.exit(1)
+end
+if not paths.filep(arg[1]) then
+  print('error: file does not exist: ' .. arg[1])
+ os.exit(1)
+end
+
+dict = torch.load(arg[1])
+dst = paths.basename(arg[1]):gsub('.th7', '.txt')
+assert(dst:match('.txt$'))
+
+f = io.open(dst, 'w')
+for idx, symbol in ipairs(dict.index_to_symbol) do
+ if idx > dict.cutoff then
+ break
+ end
+ f:write(symbol)
+ f:write(' ')
+ f:write(dict.index_to_freq[idx])
+ f:write('\n')
+end
+f:close()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/convert_model.lua b/SpeechT5/Speech2S/speech2s/scripts/convert_model.lua
new file mode 100644
index 0000000000000000000000000000000000000000..61b92139294fb90a25989ebd2ee52a765fb278a2
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/convert_model.lua
@@ -0,0 +1,108 @@
+-- Copyright (c) Facebook, Inc. and its affiliates.
+--
+-- This source code is licensed under the MIT license found in the
+-- LICENSE file in the root directory of this source tree.
+--
+-- Usage: convert_model.lua <model_epoch1.th7>
+require 'torch'
+local fairseq = require 'fairseq'
+
+model = torch.load(arg[1])
+
+function find_weight_norm(container, module)
+ for _, wn in ipairs(container:listModules()) do
+ if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
+ return wn
+ end
+ end
+end
+
+function push_state(dict, key, module)
+ if torch.type(module) == 'nn.Linear' then
+ local wn = find_weight_norm(model.module, module)
+ assert(wn)
+ dict[key .. '.weight_v'] = wn.v:float()
+ dict[key .. '.weight_g'] = wn.g:float()
+ elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
+ local wn = find_weight_norm(model.module, module)
+ assert(wn)
+ local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
+ dict[key .. '.weight_v'] = v
+ dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
+ else
+ dict[key .. '.weight'] = module.weight:float()
+ end
+ if module.bias then
+ dict[key .. '.bias'] = module.bias:float()
+ end
+end
+
+encoder_dict = {}
+decoder_dict = {}
+combined_dict = {}
+
+function encoder_state(encoder)
+ luts = encoder:findModules('nn.LookupTable')
+ push_state(encoder_dict, 'embed_tokens', luts[1])
+ push_state(encoder_dict, 'embed_positions', luts[2])
+
+ fcs = encoder:findModules('nn.Linear')
+ assert(#fcs >= 2)
+ local nInputPlane = fcs[1].weight:size(1)
+ push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
+ push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))
+
+ for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
+ push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
+ if nInputPlane ~= module.weight:size(3) / 2 then
+ push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
+ end
+ nInputPlane = module.weight:size(3) / 2
+ end
+ assert(#fcs == 0)
+end
+
+function decoder_state(decoder)
+ luts = decoder:findModules('nn.LookupTable')
+ push_state(decoder_dict, 'embed_tokens', luts[1])
+ push_state(decoder_dict, 'embed_positions', luts[2])
+
+ fcs = decoder:findModules('nn.Linear')
+ local nInputPlane = fcs[1].weight:size(1)
+ push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
+ push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
+ push_state(decoder_dict, 'fc3', fcs[#fcs])
+
+ table.remove(fcs, #fcs)
+ table.remove(fcs, #fcs)
+
+ for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
+ if nInputPlane ~= module.weight:size(3) / 2 then
+ push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
+ end
+ nInputPlane = module.weight:size(3) / 2
+
+ local prefix = 'attention.' .. tostring(i - 1)
+ push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
+ push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
+ push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
+ end
+ assert(#fcs == 0)
+end
+
+
+_encoder = model.module.modules[2]
+_decoder = model.module.modules[3]
+
+encoder_state(_encoder)
+decoder_state(_decoder)
+
+for k, v in pairs(encoder_dict) do
+ combined_dict['encoder.' .. k] = v
+end
+for k, v in pairs(decoder_dict) do
+ combined_dict['decoder.' .. k] = v
+end
+
+
+torch.save('state_dict.t7', combined_dict)
diff --git a/SpeechT5/Speech2S/speech2s/scripts/count_docs.py b/SpeechT5/Speech2S/speech2s/scripts/count_docs.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d85af85e91377a34dbd01f7674436152fd08e8
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/count_docs.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Count the number of documents and average number of lines and tokens per
+document in a large file. Documents should be separated by a single empty line.
+"""
+
+import argparse
+import gzip
+import sys
+
+import numpy as np
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input")
+ parser.add_argument("--gzip", action="store_true")
+ args = parser.parse_args()
+
+ def gopen():
+ if args.gzip:
+ return gzip.open(args.input, "r")
+ else:
+ return open(args.input, "r", encoding="utf-8")
+
+ num_lines = []
+ num_toks = []
+ with gopen() as h:
+ num_docs = 1
+ num_lines_in_doc = 0
+ num_toks_in_doc = 0
+ for i, line in enumerate(h):
+ if len(line.strip()) == 0: # empty line indicates new document
+ num_docs += 1
+ num_lines.append(num_lines_in_doc)
+ num_toks.append(num_toks_in_doc)
+ num_lines_in_doc = 0
+ num_toks_in_doc = 0
+ else:
+ num_lines_in_doc += 1
+ num_toks_in_doc += len(line.rstrip().split())
+ if i % 1000000 == 0:
+ print(i, file=sys.stderr, end="", flush=True)
+ elif i % 100000 == 0:
+ print(".", file=sys.stderr, end="", flush=True)
+ print(file=sys.stderr, flush=True)
+
+ print("found {} docs".format(num_docs))
+ print("average num lines per doc: {}".format(np.mean(num_lines)))
+ print("average num toks per doc: {}".format(np.mean(num_toks)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/read_binarized.py b/SpeechT5/Speech2S/speech2s/scripts/read_binarized.py
new file mode 100644
index 0000000000000000000000000000000000000000..a414095d03fb022a6753e816fc8bfd80e11db24d
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/read_binarized.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+from fairseq.data import Dictionary, data_utils, indexed_dataset
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="writes text from binarized file to stdout"
+ )
+ # fmt: off
+ parser.add_argument('--dataset-impl', help='dataset implementation',
+ choices=indexed_dataset.get_available_dataset_impl())
+ parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
+ parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
+ # fmt: on
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ dictionary = Dictionary.load(args.dict) if args.dict is not None else None
+ dataset = data_utils.load_indexed_dataset(
+ args.input,
+ dictionary,
+ dataset_impl=args.dataset_impl,
+ default="lazy",
+ )
+
+ for tensor_line in dataset:
+ if dictionary is None:
+ line = " ".join([str(int(x)) for x in tensor_line])
+ else:
+ line = dictionary.string(tensor_line)
+
+ print(line)
+
+
+if __name__ == "__main__":
+ main()
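For example, to dump a binarized split back to text with its dictionary (the data-bin prefix and dictionary path are placeholders; --input takes the prefix shared by the .bin/.idx pair):

    # placeholder data-bin paths
    python scripts/read_binarized.py \
        --dataset-impl mmap \
        --dict data-bin/dict.en.txt \
        --input data-bin/train.en-de.en \
        > train.en.txt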
diff --git a/SpeechT5/Speech2S/speech2s/scripts/rm_pt.py b/SpeechT5/Speech2S/speech2s/scripts/rm_pt.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cd063d21f0610fa7c42c2cfb2ee8af7c9c78677
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/rm_pt.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import re
+import shutil
+import sys
+
+
+pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
+pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt")
+pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt")
+
+
+def parse_checkpoints(files):
+ entries = []
+ for f in files:
+ m = pt_regexp_epoch_based.fullmatch(f)
+ if m is not None:
+ entries.append((int(m.group(1)), m.group(0)))
+ else:
+ m = pt_regexp_update_based.fullmatch(f)
+ if m is not None:
+ entries.append((int(m.group(1)), m.group(0)))
+ return entries
+
+
+def last_n_checkpoints(files, n):
+ entries = parse_checkpoints(files)
+ return [x[1] for x in sorted(entries, reverse=True)[:n]]
+
+
+def every_n_checkpoints(files, n):
+ entries = parse_checkpoints(files)
+ return [x[1] for x in sorted(sorted(entries)[::-n])]
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description=(
+ "Recursively delete checkpoint files from `root_dir`, "
+ "but preserve checkpoint_best.pt and checkpoint_last.pt"
+ )
+ )
+ parser.add_argument("root_dirs", nargs="*")
+ parser.add_argument(
+ "--save-last", type=int, default=0, help="number of last checkpoints to save"
+ )
+ parser.add_argument(
+ "--save-every", type=int, default=0, help="interval of checkpoints to save"
+ )
+ parser.add_argument(
+ "--preserve-test",
+ action="store_true",
+ help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
+ )
+ parser.add_argument(
+ "--delete-best", action="store_true", help="delete checkpoint_best.pt"
+ )
+ parser.add_argument(
+ "--delete-last", action="store_true", help="delete checkpoint_last.pt"
+ )
+ parser.add_argument(
+ "--no-dereference", action="store_true", help="don't dereference symlinks"
+ )
+ args = parser.parse_args()
+
+ files_to_desymlink = []
+ files_to_preserve = []
+ files_to_delete = []
+ for root_dir in args.root_dirs:
+ for root, _subdirs, files in os.walk(root_dir):
+ if args.save_last > 0:
+ to_save = last_n_checkpoints(files, args.save_last)
+ else:
+ to_save = []
+ if args.save_every > 0:
+ to_save += every_n_checkpoints(files, args.save_every)
+ for file in files:
+ if not pt_regexp.fullmatch(file):
+ continue
+ full_path = os.path.join(root, file)
+ if (
+ not os.path.basename(root).startswith("test_") or args.preserve_test
+ ) and (
+ (file == "checkpoint_last.pt" and not args.delete_last)
+ or (file == "checkpoint_best.pt" and not args.delete_best)
+ or file in to_save
+ ):
+ if os.path.islink(full_path) and not args.no_dereference:
+ files_to_desymlink.append(full_path)
+ else:
+ files_to_preserve.append(full_path)
+ else:
+ files_to_delete.append(full_path)
+
+ if len(files_to_desymlink) == 0 and len(files_to_delete) == 0:
+ print("Nothing to do.")
+ sys.exit(0)
+
+ files_to_desymlink = sorted(files_to_desymlink)
+ files_to_preserve = sorted(files_to_preserve)
+ files_to_delete = sorted(files_to_delete)
+
+ print("Operations to perform (in order):")
+ if len(files_to_desymlink) > 0:
+ for file in files_to_desymlink:
+ print(" - preserve (and dereference symlink): " + file)
+ if len(files_to_preserve) > 0:
+ for file in files_to_preserve:
+ print(" - preserve: " + file)
+ if len(files_to_delete) > 0:
+ for file in files_to_delete:
+ print(" - delete: " + file)
+ while True:
+ resp = input("Continue? (Y/N): ")
+ if resp.strip().lower() == "y":
+ break
+ elif resp.strip().lower() == "n":
+ sys.exit(0)
+
+ print("Executing...")
+ if len(files_to_desymlink) > 0:
+ for file in files_to_desymlink:
+ realpath = os.path.realpath(file)
+ print("rm " + file)
+ os.remove(file)
+ print("cp {} {}".format(realpath, file))
+ shutil.copyfile(realpath, file)
+ if len(files_to_delete) > 0:
+ for file in files_to_delete:
+ print("rm " + file)
+ os.remove(file)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/sacrebleu.sh b/SpeechT5/Speech2S/speech2s/scripts/sacrebleu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c10bf2b76ea032deabab6f5c9d8a3e1e884f1642
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/sacrebleu.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+if [ $# -ne 4 ]; then
+ echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
+ exit 1
+fi
+
+TESTSET=$1
+SRCLANG=$2
+TGTLANG=$3
+
+GEN=$4
+
+if ! command -v sacremoses &> /dev/null
+then
+ echo "sacremoses could not be found, please install with: pip install sacremoses"
+ exit
+fi
+
+grep ^H $GEN \
+| sed 's/^H\-//' \
+| sort -n -k 1 \
+| cut -f 3 \
+| sacremoses detokenize \
+> $GEN.sorted.detok
+
+sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok
diff --git a/SpeechT5/Speech2S/speech2s/scripts/shard_docs.py b/SpeechT5/Speech2S/speech2s/scripts/shard_docs.py
new file mode 100644
index 0000000000000000000000000000000000000000..97232c3c845ee01dc5ab627388934cc0f9588280
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/shard_docs.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Split a large file into shards while respecting document boundaries. Documents
+should be separated by a single empty line.
+"""
+
+import argparse
+import contextlib
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input")
+ parser.add_argument("--num-shards", type=int)
+ args = parser.parse_args()
+
+ assert args.num_shards is not None and args.num_shards > 1
+
+ with open(args.input, "r", encoding="utf-8") as h:
+ with contextlib.ExitStack() as stack:
+ outputs = [
+ stack.enter_context(
+ open(args.input + ".shard" + str(i), "w", encoding="utf-8")
+ )
+ for i in range(args.num_shards)
+ ]
+
+ doc = []
+ first_doc = [True] * args.num_shards
+
+ def output_doc(i):
+ if not first_doc[i]:
+ outputs[i].write("\n")
+ first_doc[i] = False
+ for line in doc:
+ outputs[i].write(line)
+ doc.clear()
+
+ num_docs = 0
+ for line in h:
+ if line.strip() == "": # empty line indicates new document
+ output_doc(num_docs % args.num_shards)
+ num_docs += 1
+ else:
+ doc.append(line)
+ output_doc(num_docs % args.num_shards)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/split_train_valid_docs.py b/SpeechT5/Speech2S/speech2s/scripts/split_train_valid_docs.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff159785284a13b44626b207d84430c592acaf8f
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/split_train_valid_docs.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Split a large file into a train and valid set while respecting document
+boundaries. Documents should be separated by a single empty line.
+"""
+
+import argparse
+import random
+import sys
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input")
+ parser.add_argument("sample_output", help="train output file")
+ parser.add_argument("remainder_output", help="valid output file")
+ parser.add_argument("-k", type=int, help="remainder size")
+ parser.add_argument(
+ "--lines", action="store_true", help="split lines instead of docs"
+ )
+ args = parser.parse_args()
+
+ assert args.k is not None
+
+ sample = []
+ remainder = []
+ num_docs = [0]
+
+ def update_sample(doc):
+ if len(sample) < args.k:
+ sample.append(doc.copy())
+ else:
+ i = num_docs[0]
+ j = random.randrange(i + 1)
+ if j < args.k:
+ remainder.append(sample[j])
+ sample[j] = doc.copy()
+ else:
+ remainder.append(doc.copy())
+ num_docs[0] += 1
+ doc.clear()
+
+ with open(args.input, "r", encoding="utf-8") as h:
+ doc = []
+ for i, line in enumerate(h):
+ if line.strip() == "": # empty line indicates new document
+ update_sample(doc)
+ else:
+ doc.append(line)
+ if args.lines:
+ update_sample(doc)
+ if i % 1000000 == 0:
+ print(i, file=sys.stderr, end="", flush=True)
+ elif i % 100000 == 0:
+ print(".", file=sys.stderr, end="", flush=True)
+ if len(doc) > 0:
+ update_sample(doc)
+ print(file=sys.stderr, flush=True)
+
+ assert len(sample) == args.k
+
+ with open(args.sample_output, "w", encoding="utf-8") as out:
+ first = True
+ for doc in sample:
+ if not first and not args.lines:
+ out.write("\n")
+ first = False
+ for line in doc:
+ out.write(line)
+
+ with open(args.remainder_output, "w", encoding="utf-8") as out:
+ first = True
+ for doc in remainder:
+ if not first and not args.lines:
+ out.write("\n")
+ first = False
+ for line in doc:
+ out.write(line)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/spm_decode.py b/SpeechT5/Speech2S/speech2s/scripts/spm_decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d7b68b240265924601ca6a738ed3d7b4b8e9cda
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/spm_decode.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model", required=True, help="sentencepiece model to use for decoding"
+ )
+ parser.add_argument("--input", required=True, help="input file to decode")
+ parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
+ args = parser.parse_args()
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.input_format == "piece":
+
+ def decode(input):
+ return "".join(sp.DecodePieces(input))
+
+ elif args.input_format == "id":
+
+ def decode(input):
+ return "".join(sp.DecodeIds(input))
+
+ else:
+ raise NotImplementedError
+
+ def tok2int(tok):
+        # remap reference-side <unk> (represented as <<unk>>) to 0
+        return int(tok) if tok != "<<unk>>" else 0
+
+ with open(args.input, "r", encoding="utf-8") as h:
+ for line in h:
+ if args.input_format == "id":
+ print(decode(list(map(tok2int, line.rstrip().split()))))
+ elif args.input_format == "piece":
+ print(decode(line.rstrip().split()))
+
+
+if __name__ == "__main__":
+ main()
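A sketch of decoding piece-format SentencePiece output back to raw text; the model and file names are placeholders:

    # placeholder model and hypothesis file
    python scripts/spm_decode.py \
        --model spm/unigram.model \
        --input_format piece \
        --input hyp.spm.txt \
        > hyp.detok.txt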
diff --git a/SpeechT5/Speech2S/speech2s/scripts/spm_encode.py b/SpeechT5/Speech2S/speech2s/scripts/spm_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91e0bb728a33448c1415aee6036ac9d0feac11f
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/spm_encode.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model", required=True, help="sentencepiece model to use for encoding"
+ )
+ parser.add_argument(
+ "--inputs", nargs="+", default=["-"], help="input files to filter/encode"
+ )
+ parser.add_argument(
+ "--outputs", nargs="+", default=["-"], help="path to save encoded outputs"
+ )
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+ parser.add_argument(
+ "--min-len",
+ type=int,
+ metavar="N",
+ help="filter sentence pairs with fewer than N tokens",
+ )
+ parser.add_argument(
+ "--max-len",
+ type=int,
+ metavar="N",
+ help="filter sentence pairs with more than N tokens",
+ )
+ args = parser.parse_args()
+
+ assert len(args.inputs) == len(
+ args.outputs
+ ), "number of input and output paths should match"
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.output_format == "piece":
+
+ def encode(input):
+ return sp.EncodeAsPieces(input)
+
+ elif args.output_format == "id":
+
+ def encode(input):
+ return list(map(str, sp.EncodeAsIds(input)))
+
+ else:
+ raise NotImplementedError
+
+ if args.min_len is not None or args.max_len is not None:
+
+ def valid(line):
+ return (args.min_len is None or len(line) >= args.min_len) and (
+ args.max_len is None or len(line) <= args.max_len
+ )
+
+ else:
+
+        def valid(line):
+ return True
+
+ with contextlib.ExitStack() as stack:
+ inputs = [
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-"
+ else sys.stdin
+ for input in args.inputs
+ ]
+ outputs = [
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-"
+ else sys.stdout
+ for output in args.outputs
+ ]
+
+ stats = {
+ "num_empty": 0,
+ "num_filtered": 0,
+ }
+
+ def encode_line(line):
+ line = line.strip()
+ if len(line) > 0:
+ line = encode(line)
+ if valid(line):
+ return line
+ else:
+ stats["num_filtered"] += 1
+ else:
+ stats["num_empty"] += 1
+ return None
+
+ for i, lines in enumerate(zip(*inputs), start=1):
+ enc_lines = list(map(encode_line, lines))
+ if not any(enc_line is None for enc_line in enc_lines):
+ for enc_line, output_h in zip(enc_lines, outputs):
+ print(" ".join(enc_line), file=output_h)
+ if i % 10000 == 0:
+ print("processed {} lines".format(i), file=sys.stderr)
+
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/scripts/spm_train.py b/SpeechT5/Speech2S/speech2s/scripts/spm_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9db668fd4166a860198784990de68ea26157995d
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/spm_train.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import sys
+
+import sentencepiece as spm
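+
+# Thin wrapper that forwards all command-line arguments to SentencePieceTrainer.
+# Example (flag values are illustrative):
+#   python spm_train.py --input=train.txt --model_prefix=spm_unigram_10000 \
+#     --vocab_size=10000 --character_coverage=1.0 --model_type=unigram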
+
+
+if __name__ == "__main__":
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
diff --git a/SpeechT5/Speech2S/speech2s/scripts/test_fsdp.sh b/SpeechT5/Speech2S/speech2s/scripts/test_fsdp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1f428a035e4474427ded991f8e8307ea59f61f69
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/scripts/test_fsdp.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+rm -rf fsdp_dummy
+mkdir -p fsdp_dummy
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
+ --cpu-offload --checkpoint-activations \
+ --task language_modeling --tokens-per-sample 256 --batch-size 8 \
+ --arch transformer_lm_gpt2_tiny \
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
+ --max-update 5 --log-format json --log-interval 1 \
+ --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \
+ --restore-file x.pt "$@"
+
+# Now we try to load the checkpoint
+CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
+ --cpu-offload --checkpoint-activations \
+ --task language_modeling --tokens-per-sample 256 --batch-size 8 \
+ --arch transformer_lm_gpt2_tiny \
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
+ --max-update 2 --log-format json --log-interval 1 \
+ --save-interval-updates 2 --save-dir fsdp_dummy
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/base_sc2c_enes.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/base_sc2c_enes.sh
new file mode 100644
index 0000000000000000000000000000000000000000..08e00403f961625ec2c819f5ee85a2ce74e64e9a
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/base_sc2c_enes.sh
@@ -0,0 +1,64 @@
+#!/usr/bin/env bash
+# ####################################
+# Hubert SCT2T ED model #
+# ####################################
+
+world_size=$1
+update_freq=$2
+exp_name=$3
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=1
+[ -z $exp_name ] && exp_name=sc2t_base_enes_${world_size}gpu_${update_freq}accum6666
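+
+# Example invocation (data/checkpoint paths inside this script are cluster-specific):
+#   bash base_sc2c_enes.sh 8 1 my_sc2t_enes_run
+# i.e. 8 GPUs, gradient accumulation of 1, and an experiment name used for the output directory.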
+
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_DIR=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config
+DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/speech_enes"
+TEXT_DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/text_enes/bin-idx"
+MODEL_DIR="/mnt/output/v-kunwei/data/s2s_data/exp/S2S_enes/$exp_name"
+
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_DIR/pretrain \
+ --config-name sc2t_base_librispeech \
+ \
+ +task.store_labels=true \
+ task.labels='["km"]' \
+ model.label_rate=50 \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.text_cfg.text_data=$TEXT_DATA_DIR \
+ +task.text_cfg.data_config=config.yaml \
+ task.text_cfg.text_maxtokens_ratio=3.0 \
+ \
+ +criterion.dec_loss_type="ce" \
+ \
+ criterion.text_weight=1.0 \
+ \
+ model.use_rel_pos_enc=true \
+ +model.code_use_rel_pos_enc=true \
+ +model.pad_with_code=true \
+ model.text_transformer.no_scale_embedding=true \
+ model.text_transformer.layernorm_embedding=true \
+ +model.share_decoder_input_output_embed=true \
+ \
+ dataset.train_subset=\"train_all+en.kmu-spm\" \
+ dataset.valid_subset=\"valid+en_valid.kmu-spm\" \
+ dataset.num_workers=0 \
+ dataset.max_tokens=1000000 \
+ optimization.update_freq=[${update_freq}] \
+ optimization.max_update=400000 \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name}
+
+
+sleep 5m
+echo "All finished"
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/base_sc2c_esen.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/base_sc2c_esen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2a15bd129b961e9c5eeff211f7c03f7f8fcc20c9
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/base_sc2c_esen.sh
@@ -0,0 +1,64 @@
+#!/usr/bin/env bash
+# ####################################
+# Hubert SCT2T ED model #
+# ####################################
+
+world_size=$1
+update_freq=$2
+exp_name=$3
+[ -z $world_size ] && world_size=24
+[ -z $update_freq ] && update_freq=3
+[ -z $exp_name ] && exp_name=sc2t_base_esen_${world_size}gpu_${update_freq}accum1
+
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_DIR=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config
+DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/speech_esen"
+TEXT_DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/text_esen"
+MODEL_DIR="/mnt/output/v-kunwei/data/s2s_data/exp/S2S_esen/$exp_name"
+
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_DIR/pretrain \
+ --config-name sc2t_base_librispeech \
+ \
+ +task.store_labels=true \
+ task.labels='["km"]' \
+ model.label_rate=50 \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.text_cfg.text_data=$TEXT_DATA_DIR \
+ +task.text_cfg.data_config=config.yaml \
+ task.text_cfg.text_maxtokens_ratio=3.0 \
+ \
+ +criterion.dec_loss_type="ce" \
+ \
+ criterion.text_weight=1.0 \
+ \
+ model.use_rel_pos_enc=true \
+ +model.code_use_rel_pos_enc=true \
+ +model.pad_with_code=true \
+ model.text_transformer.no_scale_embedding=true \
+ model.text_transformer.layernorm_embedding=true \
+ +model.share_decoder_input_output_embed=true \
+ \
+ dataset.train_subset=\"train+en.kmu-spm\" \
+ dataset.valid_subset=\"valid+en_valid.kmu-spm\" \
+ dataset.num_workers=0 \
+ dataset.max_tokens=1000000 \
+ optimization.update_freq=[${update_freq}] \
+ optimization.max_update=400000 \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name}
+
+
+sleep 5m
+echo "All finished"
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..58ba896d1a38a7ac980d213d818b1d2e427c9eb6
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config.yaml
@@ -0,0 +1,4 @@
+audio_root: ./
+standardize_audio: true
+use_audio_input: true
+vocab_filename: dict.txt
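+# Consumed by the pretraining scripts via +task.text_cfg.data_config=config.yaml;
+# audio_root and vocab_filename are presumably resolved relative to the task data directory.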
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/finetune_asr/base_100h.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/finetune_asr/base_100h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7c9fae8e626ccb3d209334d754ff6823b40c2c4e
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/finetune_asr/base_100h.yaml
@@ -0,0 +1,101 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tblog
+ seed: 1337
+
+checkpoint:
+ save_interval: 1
+ keep_last_epochs: 5
+ keep_best_checkpoints: 5
+ best_checkpoint_metric: wer
+ restore_file: checkpoint_last.pt
+
+distributed_training:
+ ddp_backend: c10d
+ find_unused_parameters: true
+ distributed_world_size: 1
+ distributed_port: -1
+ nprocs_per_node: 8
+
+task:
+ _name: hubert_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: false # must be consistent with pre-training
+ labels: ["ltr"]
+ single_target: true
+ add_decoder: false
+ pad_audio: false
+ random_crop: true
+ tokenizer: "none"
+ sp_path: None
+
+dataset:
+ num_workers: 0
+ max_tokens: 1200000
+ skip_invalid_size_inputs_valid_test: true
+ train_subset: train_100
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: label_smoothed_cross_entropy
+ #zero_infinity: true
+
+
+optimization:
+ max_update: 80000
+ lr: [0.00003]
+ sentence_avg: true
+ update_freq: [1]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+ weight_decay: 0.0
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: hubert_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.1
+ decoder_layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
+ add_decoder: false
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/finetune_asr/large_960h.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/finetune_asr/large_960h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..360182329dd245e1d2f8d10f412654fc5ba2afb3
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/finetune_asr/large_960h.yaml
@@ -0,0 +1,98 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_interval: 1
+ keep_last_epochs: 10
+ keep_best_checkpoints: 5
+ best_checkpoint_metric: wer
+ restore_file: checkpoint_last.pt
+
+distributed_training:
+ ddp_backend: c10d
+ find_unused_parameters: true
+ distributed_world_size: 24
+ distributed_port: -1
+ nprocs_per_node: 8
+
+task:
+ _name: hubert_pretraining
+ data: ???
+ fine_tuning: true
+ label_dir: ???
+ normalize: true # must be consistent with pre-training
+ labels: ["ltr"]
+ single_target: true
+ add_decoder: false
+ pad_audio: false
+ random_crop: true
+ tokenizer: "none"
+ sp_path: None
+
+dataset:
+ num_workers: 0
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 200000
+ lr: [0.00003]
+ sentence_avg: true
+ update_freq: [1]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+ weight_decay: 0.0
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: hubert_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.0
+ decoder_layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
+ add_decoder: false
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ - criterion.wer_kenlm_model
+ - criterion.wer_lexicon
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/pretrain/mbart.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/pretrain/mbart.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..51025f2f8ec584a888a4e07c8c246829351af948
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/pretrain/mbart.yaml
@@ -0,0 +1,120 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_dir: ???
+ save_interval: 4
+ keep_last_epochs: 4
+ save_interval_updates: 20000
+ keep_interval_updates: -1
+ keep_interval_updates_pattern: 50000
+ # no_epoch_checkpoints: true
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_world_size: 8
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: denoising
+ data: ???
+ mask: 0.15
+
+dataset:
+ num_workers: 6
+ max_tokens: 1400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: ${checkpoint.save_interval}
+ validate_interval_updates: ${checkpoint.save_interval_updates}
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: sc2t
+ pred_masked_weight: 1.0
+ pred_nomask_weight: 0.0
+ loss_weights: [10,]
+ label_smoothing: 0.1
+ text_weight: 0.1
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: stbert
+ label_rate: ???
+ skip_masked: false
+ skip_nomask: false
+ mask_prob: 0.80
+ extractor_mode: default
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+ final_dim: 256
+ encoder_layers: 6
+ encoder_attention_heads: 8
+ decoder_layerdrop: 0.05
+ dropout_input: 0.1
+ dropout_features: 0.1
+ dropout: 0.1
+ attention_dropout: 0.1
+ feature_grad_mult: 0.1
+ untie_final_proj: true
+ activation_dropout: 0.0
+ use_rel_pos_enc: true
+ add_code_encoder: true
+ add_adaptor: false
+ text_transformer:
+ activation_fn: ${model.activation_fn}
+ dropout: ${model.dropout}
+ attention_dropout: ${model.attention_dropout}
+ activation_dropout: ${model.activation_dropout}
+ adaptive_input: ${model.adaptive_input}
+ max_source_positions: 3000
+ checkpoint_activations: ${model.checkpoint_activations}
+ no_scale_embedding: false
+ layernorm_embedding: false
+ quant_noise:
+ pq: ${model.quant_noise_pq}
+ encoder:
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 6
+ attention_heads: 8
+ normalize_before: false
+ learned_pos: true
+ layerdrop: ${model.encoder_layerdrop}
+
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/pretrain/sc2t_base_librispeech.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/pretrain/sc2t_base_librispeech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0cd16561c9d4715d21824cbbc7271940d3ceeda7
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/pretrain/sc2t_base_librispeech.yaml
@@ -0,0 +1,137 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_dir: ???
+ save_interval: 4
+ keep_last_epochs: 4
+ save_interval_updates: 20000
+ keep_interval_updates: -1
+ keep_interval_updates_pattern: 50000
+ # no_epoch_checkpoints: true
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_world_size: 8
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: joint_sc2t_pretraining
+ data: ???
+ label_dir: ???
+ labels: ???
+ label_rate: ${model.label_rate}
+ sample_rate: 16000
+ max_sample_size: 250000
+ min_sample_size: 32000
+ pad_audio: false
+ random_crop: true
+ normalize: false # must be consistent with extractor
+ add_decoder: true
+ text_cfg:
+ seed: ${common.seed}
+ text_data: ???
+ sample_break_mode: eos
+ tokens_per_sample: 1024
+ shorten_method: "random_crop"
+ text_maxtokens_ratio: 1.0
+
+
+dataset:
+ num_workers: 6
+ max_tokens: 1400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: ${checkpoint.save_interval}
+ validate_interval_updates: ${checkpoint.save_interval_updates}
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: sc2t
+ pred_masked_weight: 1.0
+ pred_nomask_weight: 0.0
+ loss_weights: [10,]
+ label_smoothing: 0.1
+ text_weight: 0.1
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: stbert
+ label_rate: ???
+ skip_masked: false
+ skip_nomask: false
+ mask_prob: 0.80
+ extractor_mode: default
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+ final_dim: 256
+ encoder_layers: 6
+ encoder_attention_heads: 8
+ decoder_layerdrop: 0.05
+ dropout_input: 0.1
+ dropout_features: 0.1
+ dropout: 0.1
+ attention_dropout: 0.1
+ feature_grad_mult: 0.1
+ untie_final_proj: true
+ activation_dropout: 0.0
+ use_rel_pos_enc: true
+ add_code_encoder: true
+ add_adaptor: false
+ text_transformer:
+ activation_fn: ${model.activation_fn}
+ dropout: ${model.dropout}
+ attention_dropout: ${model.attention_dropout}
+ activation_dropout: ${model.activation_dropout}
+ adaptive_input: ${model.adaptive_input}
+ max_source_positions: 3000
+ checkpoint_activations: ${model.checkpoint_activations}
+ no_scale_embedding: false
+ layernorm_embedding: false
+ quant_noise:
+ pq: ${model.quant_noise_pq}
+ encoder:
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 6
+ attention_heads: 8
+ normalize_before: false
+ learned_pos: true
+ layerdrop: ${model.encoder_layerdrop}
+
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/translation/text2code.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/translation/text2code.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bed25135e0da21c20d33475ad33437c63e6703d7
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config/translation/text2code.yaml
@@ -0,0 +1,81 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tblog
+ seed: 1337
+
+checkpoint:
+ save_interval: 1000000
+ keep_last_epochs: 5
+ save_interval_updates: 1000
+ keep_interval_updates_pattern: 10000
+ keep_interval_updates: 5
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+distributed_training:
+ ddp_backend: c10d
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 8
+
+
+criterion:
+ _name: "label_smoothed_cross_entropy"
+
+
+task:
+ _name: "translation_from_jst"
+
+dataset:
+ num_workers: 0
+ max_tokens: 4096
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: ${model.freeze_finetune_updates}
+ validate_interval: ${checkpoint.save_interval}
+ validate_interval_updates: ${checkpoint.save_interval_updates}
+ train_subset: train_clean_100
+ valid_subset: dev_clean
+ required_batch_size_multiple: 1
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.0
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: hubert_t2c
+ w2v_path: ???
+ layerdrop: 0.1
+ decoder_layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ - model.w2v_path
+ - dataset.train_subset
+ - dataset.valid_subset
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config_mbart.yaml b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config_mbart.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..51025f2f8ec584a888a4e07c8c246829351af948
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/config_mbart.yaml
@@ -0,0 +1,120 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ tensorboard_logdir: tblog
+
+checkpoint:
+ save_dir: ???
+ save_interval: 4
+ keep_last_epochs: 4
+ save_interval_updates: 20000
+ keep_interval_updates: -1
+ keep_interval_updates_pattern: 50000
+ # no_epoch_checkpoints: true
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_world_size: 8
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: denoising
+ data: ???
+ mask: 0.15
+
+dataset:
+ num_workers: 6
+ max_tokens: 1400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: ${checkpoint.save_interval}
+ validate_interval_updates: ${checkpoint.save_interval_updates}
+ required_batch_size_multiple: 1
+
+criterion:
+ _name: sc2t
+ pred_masked_weight: 1.0
+ pred_nomask_weight: 0.0
+ loss_weights: [10,]
+ label_smoothing: 0.1
+ text_weight: 0.1
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: stbert
+ label_rate: ???
+ skip_masked: false
+ skip_nomask: false
+ mask_prob: 0.80
+ extractor_mode: default
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+ final_dim: 256
+ encoder_layers: 6
+ encoder_attention_heads: 8
+ decoder_layerdrop: 0.05
+ dropout_input: 0.1
+ dropout_features: 0.1
+ dropout: 0.1
+ attention_dropout: 0.1
+ feature_grad_mult: 0.1
+ untie_final_proj: true
+ activation_dropout: 0.0
+ use_rel_pos_enc: true
+ add_code_encoder: true
+ add_adaptor: false
+ text_transformer:
+ activation_fn: ${model.activation_fn}
+ dropout: ${model.dropout}
+ attention_dropout: ${model.attention_dropout}
+ activation_dropout: ${model.activation_dropout}
+ adaptive_input: ${model.adaptive_input}
+ max_source_positions: 3000
+ checkpoint_activations: ${model.checkpoint_activations}
+ no_scale_embedding: false
+ layernorm_embedding: false
+ quant_noise:
+ pq: ${model.quant_noise_pq}
+ encoder:
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 6
+ attention_heads: 8
+ normalize_before: false
+ learned_pos: true
+ layerdrop: ${model.encoder_layerdrop}
+
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/extract_hubert_feature_itp.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/extract_hubert_feature_itp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..52929896c612957d7fc8df452015411b0e6038bc
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/extract_hubert_feature_itp.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+if [ ! -d ${HOME}/azcopy_linux_amd64_10.11.0 ]; then
+ CURRENT_DIR=`pwd`
+ cd ${HOME} && wget https://azcopyvnext.azureedge.net/release20210616/azcopy_linux_amd64_10.11.0.tar.gz && tar -zxvf azcopy_linux_amd64_10.11.0.tar.gz && rm -f azcopy_linux_amd64_10.11.0.tar.gz && cd ${CURRENT_DIR}
+fi
+export PATH=$PATH:${HOME}/azcopy_linux_amd64_10.11.0/:${HOME}/.local/bin
+export PYTHONPATH=$PYTHONPATH:/mnt/output/users/v-kunwei/code/fairseq
+
+rank=$1
+nshard=$2
+split=$3
+[ -z $rank ] && echo "please specify rank" && exit 1
+[ -z $nshard ] && nshard=1
+[ -z $split ] && split="train"
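+
+# Example (one shard per process; values are illustrative): bash extract_hubert_feature_itp.sh 0 4 train
+# dumps mHuBERT layer-9 features for shard 0 of 4 and then quantizes them with the k-means model below.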
+
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq
+ckpt_path=/mnt/output/users/v-kunwei/code/fairseq/examples/speech_to_speech/mhubert_base_vp_en_es_fr_it3.pt
+tsv_dir=/home/v-kunwei
+
+feat_dir=${HOME}/$split
+python $FAIRSEQ_ROOT/examples/hubert/simple_kmeans/dump_hubert_feature.py ${tsv_dir} ${split} ${ckpt_path} 9 ${nshard} ${rank} ${feat_dir} || exit 1
+
+
+echo "-------------------------------------------------------------------------------------------"
+echo "---------------------------------- done ---------------------------------------------"
+echo "-------------------------------------------------------------------------------------------"
+
+km_path=/mnt/output/users/v-kunwei/code/fairseq/examples/speech_to_speech/mhubert_base_vp_en_es_fr_it3_L11_km1000.bin
+lab_dir=${HOME}/${split}
+python $FAIRSEQ_ROOT/examples/hubert/simple_kmeans/dump_km_label.py ${feat_dir} ${split} ${km_path} ${nshard} ${rank} ${lab_dir}
+
+
+# sas="?sv=2020-08-04&st=2022-01-02T04%3A58%3A15Z&se=2022-06-01T04%3A58%3A00Z&sr=c&sp=racwdl&sig=NyZKOHivgesEoZ8yvLsVT6aZMYQZMevLLmXNOTaWyvU%3D"
+# blob="https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-ziqzhang/data/stbert/data/librispeech/libri_960/hubert_release_iter2_layer9_kmeans/${split}"
+# azcopy copy $feat_dir/${split}_${rank}_${nshard}.len "$blob/$sas"
+# azcopy copy $feat_dir/${split}_${rank}_${nshard}.npy "$blob/$sas"
+# azcopy copy $lab_dir "$blob/$sas" --recursive
+
+
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/merge_code.py b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/merge_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..a02ba3e3058b75e2e603d7470e9ef93beebabcfa
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/merge_code.py
@@ -0,0 +1,14 @@
+import sys
+import torch
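+
+# Collapse runs of identical discrete units (e.g. k-means codes) read from stdin, one sequence per line.
+# Example: echo "5 5 5 7 7 2" | python merge_code.py   prints   "5 7 2"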
+
+
+def main():
+ for line in sys.stdin:
+ line = line.rstrip()
+ codes = list(map(int, line.split()))
+ merged_codes = torch.unique_consecutive(torch.tensor(codes)).numpy()
+ merged_codes = map(str, merged_codes)
+ print(" ".join(merged_codes))
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/txt2idx.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/txt2idx.sh
new file mode 100644
index 0000000000000000000000000000000000000000..466f8a3ef8debba9c9f5a76cfb02d1e25217c6b4
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/txt2idx.sh
@@ -0,0 +1,43 @@
+[ $# -lt 3 ] && echo "Usage: $0 <input> <outdir> <dict> [suffix]" && exit 0
+
+if [ ! -d ${HOME}/sentencepiece ]; then
+ CURRENT_DIR=`pwd`
+ cd ${HOME}
+ git clone https://github.com/google/sentencepiece.git
+ cd sentencepiece
+ mkdir build && cd build
+ cmake .. && make -j 16
+ sudo make install
+ sudo ldconfig -v
+ cd ${HOME}
+ cd ${CURRENT_DIR}
+fi
+
+input=$1
+outdir=$2
+DICT=$3
+suffix=$4
+outname=${input##*/}
+outname=${outname%.txt*}
+[ -z $input ] && echo "You must specify a source file" && exit 1
+
+[ -z $DICT ] && echo "No dict was specified!" && exit 1
+[ -z $outdir ] && outdir=${input%/*}
+[ -z $outdir ] && outdir="."
+[ ! -d $outdir ] && mkdir -p $outdir
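+
+# Example (paths are hypothetical): bash txt2idx.sh data/train_units.txt data/bin-idx dict.km.txt .km
+# produces data/bin-idx/train_units.km.{bin,idx} binarized against the given dictionary.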
+
+echo "Dict : $DICT"
+echo "------------------------------- creating idx/bin--------------------------------------------"
+echo "$input --> $outdir/${outname}${suffix}.idx"
+fairseq-preprocess \
+ --only-source \
+ --trainpref $input \
+ --destdir $outdir \
+ --thresholdsrc 0 \
+ --srcdict ${DICT} \
+ --workers 40
+
+mv $outdir/train.idx $outdir/${outname}${suffix}.idx
+mv $outdir/train.bin $outdir/${outname}${suffix}.bin
+echo "----------------------------------- done --------------------------------------------"
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/txt2spm.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/txt2spm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6baf72227b4013512af8a6724d2bff2156a47078
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/txt2spm.sh
@@ -0,0 +1,33 @@
+[ $# -lt 2 ] && echo "Usage: $0 <input> <outdir> [spm_model] [suffix]" && exit 0
+
+if [ ! -d ${HOME}/sentencepiece ]; then
+ CURRENT_DIR=`pwd`
+ cd ${HOME}
+ git clone https://github.com/google/sentencepiece.git
+ cd sentencepiece
+ mkdir build && cd build
+ cmake .. && make -j 16
+ sudo make install
+ sudo ldconfig -v
+ cd ${HOME}
+ cd ${CURRENT_DIR}
+fi
+
+input=$1
+outdir=$2
+MODEL=$3
+suffix=$4
+outname=${input##*/}
+outname=${outname%.wrd*}
+[ -z $input ] && echo "You must specify a source file" && exit 1
+
+[ -z $MODEL ] && MODEL=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/hubert_release_iter2_layer9_kmeans/spm_unigram_10000.model && echo "No spm model was specified!, set default to $MODEL"
+[ -z $outdir ] && outdir=${input%/*}
+[ -z $outdir ] && outdir="."
+[ ! -d $outdir ] && mkdir -p $outdir
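+
+# Example (paths are hypothetical): bash txt2spm.sh data/train.wrd data/spm spm_unigram_10000.model
+# writes the SentencePiece-tokenized text to data/spm/train.spm.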
+
+echo "Output: $outdir/$outname.spm"
+
+echo "------------------------------- tokenize text...--------------------------------------------"
+spm_encode --model=$MODEL < ${input} > $outdir/$outname.spm || exit 1
+echo "----------------------------------- done --------------------------------------------"
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/wmt/normalize_en_text.py b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/wmt/normalize_en_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..83e332575ba317ded70c4095eeebbc5ec588b965
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/wmt/normalize_en_text.py
@@ -0,0 +1,46 @@
+import re
+import sys
+import regex
+import argparse
+from tqdm import tqdm
+from num2words import num2words
+
+def writefile(filename, lines):
+ with open(filename, 'w', encoding='utf-8') as f:
+ f.writelines(lines)
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", "-i", required=True, type=str)
+ parser.add_argument("--output", "-o", required=True, type=str)
+ args = parser.parse_args()
+ outlines = []
+
+ with open(f"{args.input}", 'r') as f:
+ inputs = f.readlines()
+
+ for line in tqdm(inputs):
+ line = line.strip().upper()
+ line = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039\'])", " ", line)
+ items = []
+ for item in line.split():
+ if item.isdigit():
+ try:
+ item = num2words(item)
+ except Exception as e:
+ print(line)
+ raise(e)
+ items.append(item)
+ line = " ".join(items)
+ line = line.replace("-", " ")
+ line = line.upper()
+ line = line.replace("' S", "'S")
+ line = line.replace(" ", "|")
+ line = " ".join(line) + " |"
+ outlines.append(line + '\n')
+ # print(line)
+
+ writefile(args.output, outlines)
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/wmt/normalize_es_text.py b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/wmt/normalize_es_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..0136b534be0bf4fef1c84b51c83a7ac9ad437700
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/data_process/wmt/normalize_es_text.py
@@ -0,0 +1,49 @@
+import re
+import sys
+import regex
+import argparse
+import re,string
+from tqdm import tqdm
+from num2words import num2words
+
+def writefile(filename, lines):
+ with open(filename, 'w', encoding='utf-8') as f:
+ f.writelines(lines)
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", "-i", required=True, type=str)
+ parser.add_argument("--output", "-o", required=True, type=str)
+ args = parser.parse_args()
+ outlines = []
+
+ with open(f"{args.input}", 'r') as f:
+ inputs = f.readlines()
+
+ for line in tqdm(inputs):
+ line = line.strip()
+ line = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039\u00d1\u00f1\'])", " ", line)
+ items = []
+ punc='~`!#$%^&*()_+-=|\';":/.,?><~.'
+ for item in line.split():
+ if item.isdigit():
+ try:
+ item = num2words(item, lang='es')
+ except Exception as e:
+ print(line)
+ raise(e)
+ items.append(item)
+ line = " ".join(items)
+ line = (re.sub(r"[%s]+" %punc, "",line))
+ line = line.replace("-", " ")
+ line = line.lower()
+        line = line.replace("' s", "'s")
+ line = line.replace(" ", "|")
+ line = " ".join(line) + " |"
+ outlines.append(line + '\n')
+ # print(line)
+
+ writefile(args.output, outlines)
+
+if __name__ == "__main__":
+ main()
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/decode_text2code_beam2.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/decode_text2code_beam2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c9dcc10425a3a519ec456c73d15f3339de2a0eba
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/decode_text2code_beam2.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+#####################################
+# Hubert ED model #
+#####################################
+[ $# -lt 1 ] && echo "Usage: $0 <model_path> [gen_set] [tgt] [max_tokens] [world_size] [rank] [outdir]" && exit 0
+#source /mnt/default/v-ziqzhang/.bashrc_sing
+
+model_path=$1
+gen_set=$2
+tgt=$3
+src="ltr"
+max_tokens=$4
+word_size=$5
+rank=$6
+outdir=$7
+
+[ -z $tgt ] && tgt="kmu"
+[ -z $gen_set ] && gen_set="dev_clean"
+[ -z $word_size ] && word_size=1
+[ -z $rank ] && rank=0
+[ -z $max_tokens ] && max_tokens=16000
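+
+# Example (checkpoint/output paths are placeholders):
+#   bash decode_text2code_beam2.sh /path/to/checkpoint_best.pt dev_clean kmu 16000 1 0 /path/to/outdir
+# decodes pseudo unit labels ("kmu") from letter ("ltr") input with beam size 2, sharded by rank/world-size.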
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+DATA_DIR=/home/v-kunwei/
+[ $gen_set == "test" ] && DATA_DIR=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+[ -z $outdir ] && outdir=$DATA_DIR
+
+
+results_path=$outdir/pseudo_${gen_set}_${rank}
+[ ! -d $results_path ] && mkdir -p $results_path
+
+for subset in $gen_set; do
+ python $FAIRSEQ_ROOT/fairseq_cli/generate_mt_label.py $DATA_DIR \
+ --path ${model_path} \
+ --task "translation_from_jst" \
+ --max-target-positions 18000 \
+ --gen-subset $subset \
+ -t $tgt -s "ltr" \
+ --dataset-impl "raw" \
+ --max-tokens ${max_tokens} \
+ --beam 2 \
+ --max-len-a 3 --max-len-b 100 \
+ --results-path $results_path \
+ --distributed-world-size $word_size --distributed-rank $rank \
+
+ echo "$model" > $results_path/model.record
+ sleep 1s
+done | tee $results_path/decode.log
+
+sleep 2s
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/eval2.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/eval2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0736ef4e338c9837cafc61d3c903d4683d684ea9
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/eval2.sh
@@ -0,0 +1,12 @@
+lmweight=0
+num_gpus=8
+python examples/speech_recognition/new/infer.py --config-dir /mnt/output/users/v-kunwei/code/fairseq/examples/speech_recognition/new/conf \
+--config-name infer task=audio_finetuning task.data=/home/v-kunwei common.user_dir=/mnt/output/users/v-kunwei/code/fairseq/examples/data2vec \
+task.labels=ltr decoding.type=viterbi \
+decoding.lexicon=models/es_eval/espeak_dict.txt \
+decoding.unique_wer_file=True \
+dataset.gen_subset=test \
+common_eval.path=/mnt/output/users/v-kunwei/code/fairseq/models/es_eval/espeak_26lang_m10.pt decoding.beam=1500 distributed_training.distributed_world_size=${num_gpus} \
+decoding.results_path=/home/v-kunwei
+
+#sclite -h "/home/v-kunwei/hypo.units" -r "/home/v-kunwei/ref.units" -i rm -o all stdout > "./result.txt"
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/eval3.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/eval3.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4a2354319ddc7a672506e92e7577d3dc978b47a8
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/eval3.sh
@@ -0,0 +1,4 @@
+#$subset=test
+python examples/speech_recognition/infer.py /home/v-kunwei --task audio_finetuning \
+--nbest 1 --path /mnt/output/users/v-kunwei/code/fairseq/models/es_eval/espeak_26lang_m10.pt --gen-subset test --results-path /home/v-kunwei --criterion ctc --labels ltr --max-tokens 4000000 \
+--post-process letter
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/finetune_enes.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/finetune_enes.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eaae1476bc5f80640abee6a85bdd1f453c15d97a
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/finetune_enes.sh
@@ -0,0 +1,85 @@
+# ####################################
+# Hubert ED model #
+# ####################################
+#source /mnt/default/v-ziqzhang/.bashrc_sing
+
+[ $# -lt 4 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path> <cpt> [mount]" && exit 0
+world_size=$1
+update_freq=$2
+w2v_path=$3
+cpt=$4
+Mount=$5
+
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=3
+[ -z $w2v_path ] && echo "you must specify a wav_path !" && exit 1
+[ -z $cpt ] && cpt=030.pt
+[ -z $Mount ] && Mount=/mnt/default
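+
+# Example invocation (the checkpoint path is a placeholder):
+#   bash finetune_enes.sh 8 3 /path/to/pretrained/checkpoint_last.pt checkpoint_last.pt
+# where the last argument is only used to name the fine-tuning experiment directory.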
+
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_DIR=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config
+DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/fin_enes100"
+
+exp_name=${w2v_path%/*}
+exp_name=${exp_name##*/}
+MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/finetune/tune_ST_from_eneshu"
+exp_name="tune_enes_lr5e-5_from_$cpt"
+MODEL_DIR=$MODEL_DIR/$exp_name
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+max_tokens=490000
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_DIR/finetune_asr \
+ --config-name base_100h \
+ \
+ +task.store_labels=true \
+ task.labels='["spm"]' \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.add_decoder=true \
+ +task.max_keep_size=490000 \
+ \
+ +model.reuse_text_emb=true \
+ model._name="stbert_st" \
+ model.w2v_path=${w2v_path} \
+ model.add_decoder=true \
+ \
+ criterion._name="label_smoothed_cross_entropy" \
+ +criterion.label_smoothing=0.2 \
+ +criterion.report_accuracy=true \
+ \
+ lr_scheduler._name="polynomial_decay" \
+ +lr_scheduler.warmup_updates=20000 \
+ \
+ optimization.lr=[0.0003] \
+ optimization.max_update=100000 \
+ checkpoint.best_checkpoint_metric="accuracy" \
+ checkpoint.maximize_best_checkpoint_metric=true \
+ checkpoint.save_interval=1 \
+ \
+ dataset.train_subset="train" \
+ dataset.valid_subset="valid" \
+ dataset.max_tokens=$max_tokens \
+ optimization.update_freq=[${update_freq}] \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ distributed_training.distributed_port=-1 \
+ \
+ common.log_interval=100 \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name}
+
+
+
+sleep 20s
+
+ # \
+ # lr_scheduler._name="polynomial_decay" \
+ # +lr_scheduler.warmup_updates=5000 \
+
+
+# /mnt/default/v-ziqzhang/data/stbert-ed/exp/ST_enes/sc2t_base_ende_32gpu_1accum/checkpoint_204_400000.pt
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/finetune_esen.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/finetune_esen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a9051f67008817d200c797b67ee4919ed5e2797a
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/finetune_esen.sh
@@ -0,0 +1,85 @@
+# ####################################
+# Hubert ED model #
+# ####################################
+#source /mnt/default/v-ziqzhang/.bashrc_sing
+
+[ $# -lt 4 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path> <cpt> [mount]" && exit 0
+world_size=$1
+update_freq=$2
+w2v_path=$3
+cpt=$4
+Mount=$5
+
+[ -z $world_size ] && world_size=1
+[ -z $update_freq ] && update_freq=1
+[ -z $w2v_path ] && echo "you must specify a wav_path !" && exit 1
+[ -z $cpt ] && cpt=030.pt
+[ -z $Mount ] && Mount=/mnt/default
+
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_DIR=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config
+DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/fin_esen"
+
+exp_name=${w2v_path%/*}
+exp_name=${exp_name##*/}
+MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/finetune/tune_ST_from_esen"
+exp_name="tune_esen_lr5e-5_from_$cpt"
+MODEL_DIR=$MODEL_DIR/$exp_name
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+max_tokens=4900
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_DIR/finetune_asr \
+ --config-name base_100h \
+ \
+ +task.store_labels=true \
+ task.labels='["spm"]' \
+ task.data=$DATA_DIR \
+ task.label_dir=$DATA_DIR \
+ task.add_decoder=true \
+ +task.max_keep_size=4900 \
+ \
+ +model.reuse_text_emb=true \
+ model._name="stbert_st" \
+ model.w2v_path=${w2v_path} \
+ model.add_decoder=true \
+ \
+ criterion._name="label_smoothed_cross_entropy" \
+ +criterion.label_smoothing=0.2 \
+ +criterion.report_accuracy=true \
+ \
+ lr_scheduler._name="polynomial_decay" \
+ +lr_scheduler.warmup_updates=20000 \
+ \
+ optimization.lr=[0.0002] \
+ optimization.max_update=100000 \
+ checkpoint.best_checkpoint_metric="accuracy" \
+ checkpoint.maximize_best_checkpoint_metric=true \
+ checkpoint.save_interval=1 \
+ \
+ dataset.train_subset="train" \
+ dataset.valid_subset="valid" \
+ dataset.max_tokens=$max_tokens \
+ optimization.update_freq=[${update_freq}] \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ distributed_training.distributed_port=-1 \
+ \
+ common.log_interval=100 \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name}
+
+
+
+sleep 20s
+
+ # \
+ # lr_scheduler._name="polynomial_decay" \
+ # +lr_scheduler.warmup_updates=5000 \
+
+
+# /mnt/default/v-ziqzhang/data/stbert-ed/exp/ST_enes/sc2t_base_ende_32gpu_1accum/checkpoint_204_400000.pt
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/inference_ed.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/inference_ed.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3fd9ef1231c827d980077a30b278b8986d31c4d7
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/inference_ed.sh
@@ -0,0 +1,38 @@
+#####################################
+# Hubert base model #
+#####################################
+[ $# -lt 1 ] && echo "Usage: $0 <model_path> [gen_set]" && exit 0
+
+model_path=$1
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+#beam_size=$2
+gen_set=$2
+#lang=$4
+[ -z $gen_set ] && gen_set="test_et"
+[ -z $beam_size ] && beam_size=2
+[ -z $lang ] && lang="fr"
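+
+# Example (checkpoint path is a placeholder): bash inference_ed.sh /path/to/checkpoint_best.pt test_et
+# runs sequence generation on the given subset and reports sacreBLEU from the generated file.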
+
+
+#DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/fin_enes
+DATA_DIR=/home/v-kunwei
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+
+for subset in $gen_set; do
+ results_path=$src_dir/decode_${cpt}_beam${beam_size}/${subset}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $FAIRSEQ_ROOT/fairseq_cli/generate.py \
+ $DATA_DIR --label-dir ${DATA_DIR} \
+ --labels '["spm"]' --gen-subset ${subset} \
+ --max-tokens 9000000 --task hubert_pretraining \
+ --add-decoder --fine-tuning --random-crop \
+ --path ${model_path} --results-path /home/v-kunwei --scoring sacrebleu \
+ --max-len-a 0 --max-len-b 900 \
+ --beam 10 --single-target
+
+ tail -n 1 /home/v-kunwei/generate-*.txt
+ sleep 1s
+done
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..34d1594d8fda2954b8a70dbdfc059402571d70ee
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k.sh
@@ -0,0 +1,70 @@
+#####################################
+# Hubert mt model #
+#####################################
+[ $# -gt 3 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path>" && exit 0
+world_size=$1
+update_freq=$2
+w2v_path=$3
+Mount=""
+
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=1
+[ -z $w2v_path ] && w2v_path="/mnt/output/users/v-kunwei/data/s2s_data/model_wo_emb_32_1004.pt"
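+
+# Example (the pretrained model path is a placeholder):
+#   bash base_ReleaseIter2_text2unicode_from400k.sh 8 1 /path/to/pretrained_model.pt
+# fine-tunes the pretrained model into a letter-to-unit ("ltr" -> "kmu") translation model.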
+
+
+langs="ltr,kmu"
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_ROOT=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config/translation
+DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/en_asr_data/
+
+### set save-dir
+MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/exp/text2unicode_en"
+exp_name="base_pt400k_releaseiter2_${world_size}gpu_${update_freq}accum_lr1e-4_alll"
+MODEL_DIR=$MODEL_DIR/$exp_name
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_ROOT \
+ --config-name text2code \
+ +task.data=$DATA_DIR \
+ dataset.dataset_impl="raw" \
+ +task.source_lang="ltr" +task.target_lang="kmu" \
+ +task.normalize=false \
+ \
+ +criterion.label_smoothing=0.1 \
+ +criterion.report_accuracy=true \
+ optimizer.weight_decay=0.00001 \
+ +lr_scheduler.lr="[0.0001]" \
+ optimization.max_update=500000 \
+ \
+ +model.dropout=0.1 \
+ +model.attention_dropout=0.1 \
+ model.activation_dropout=0.1 \
+ model.decoder_layerdrop=0 \
+ model.layerdrop=0 \
+ model.w2v_path=$w2v_path \
+ +model.text_transformer_encoder_layers=6 \
+ \
+ dataset.train_subset="en_train" \
+ dataset.valid_subset="en_dev" \
+ optimization.update_freq=[${update_freq}] \
+ optimization.clip_norm=5 \
+ \
+ common.seed=222 \
+ common.log_interval=100 \
+ common.log_format="json" \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ distributed_training.nprocs_per_node=8 \
+ distributed_training.ddp_backend="legacy_ddp" \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name} \
+
+sleep 10s
+ # sleep infinity
+
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k_es.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k_es.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1caf2f97f4b01def88b91d8a8422588f4f7a26d5
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k_es.sh
@@ -0,0 +1,70 @@
+#####################################
+# Hubert mt model #
+#####################################
+[ $# -gt 3 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path>" && exit 0
+world_size=$1
+update_freq=$2
+w2v_path=$3
+Mount=""
+
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=1
+[ -z $w2v_path ] && w2v_path="/mnt/output/users/v-kunwei/data/s2s_data/model_es_emb_90_1004.pt"
+
+
+langs="ltr,kmu"
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_ROOT=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config/translation
+DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_no_data/
+
+### set save-dir
+MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/exp/text2unicode_es"
+exp_name="base_pt400k_releaseiter2_${world_size}gpu_${update_freq}accum_lr1e-4_no"
+MODEL_DIR=$MODEL_DIR/$exp_name
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_ROOT \
+ --config-name text2code \
+ +task.data=$DATA_DIR \
+ dataset.dataset_impl="raw" \
+ +task.source_lang="ltr" +task.target_lang="kmu" \
+ +task.normalize=false \
+ \
+ +criterion.label_smoothing=0.1 \
+ +criterion.report_accuracy=true \
+ optimizer.weight_decay=0.00001 \
+ +lr_scheduler.lr="[0.0001]" \
+ optimization.max_update=500000 \
+ \
+ +model.dropout=0.1 \
+ +model.attention_dropout=0.1 \
+ model.activation_dropout=0.1 \
+ model.decoder_layerdrop=0 \
+ model.layerdrop=0 \
+ model.w2v_path=$w2v_path \
+ +model.text_transformer_encoder_layers=6 \
+ \
+ dataset.train_subset="es_train" \
+ dataset.valid_subset="es_dev" \
+ optimization.update_freq=[${update_freq}] \
+ optimization.clip_norm=5 \
+ \
+ common.seed=222 \
+ common.log_interval=100 \
+ common.log_format="json" \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ distributed_training.nprocs_per_node=8 \
+ distributed_training.ddp_backend="legacy_ddp" \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name} \
+
+sleep 10s
+ # sleep infinity
+
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k_es2.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k_es2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..910a6f35e43a0451b241a2033236039f009f0f75
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/base_ReleaseIter2_text2unicode_from400k_es2.sh
@@ -0,0 +1,70 @@
+#####################################
+# Hubert mt model #
+#####################################
+[ $# -gt 3 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path>" && exit 0
+world_size=$1
+update_freq=$2
+w2v_path=$3
+Mount=""
+
+[ -z $world_size ] && world_size=8
+[ -z $update_freq ] && update_freq=1
+[ -z $w2v_path ] && w2v_path="/mnt/output/users/v-kunwei/data/s2s_data/model_es_emb_81_1004.pt"
+
+
+langs="ltr,kmu"
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+CONFIG_ROOT=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config/translation
+DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_asrl_data/
+
+### set save-dir
+MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/exp/text2unicode_es"
+exp_name="base_pt400k_releaseiter2_${world_size}gpu_${update_freq}accum_lr1e-4_ll"
+MODEL_DIR=$MODEL_DIR/$exp_name
+[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
+
+
+python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
+ --config-dir $CONFIG_ROOT \
+ --config-name text2code \
+ +task.data=$DATA_DIR \
+ dataset.dataset_impl="raw" \
+ +task.source_lang="ltr" +task.target_lang="kmu" \
+ +task.normalize=false \
+ \
+ +criterion.label_smoothing=0.1 \
+ +criterion.report_accuracy=true \
+ optimizer.weight_decay=0.00001 \
+ +lr_scheduler.lr="[0.0001]" \
+ optimization.max_update=500000 \
+ \
+ +model.dropout=0.1 \
+ +model.attention_dropout=0.1 \
+ model.activation_dropout=0.1 \
+ model.decoder_layerdrop=0 \
+ model.layerdrop=0 \
+ model.w2v_path=$w2v_path \
+ +model.text_transformer_encoder_layers=6 \
+ \
+ dataset.train_subset="es_train" \
+ dataset.valid_subset="es_dev" \
+ optimization.update_freq=[${update_freq}] \
+ optimization.clip_norm=5 \
+ \
+ common.seed=222 \
+ common.log_interval=100 \
+ common.log_format="json" \
+ \
+ distributed_training.distributed_world_size=${world_size} \
+ distributed_training.nprocs_per_node=8 \
+ distributed_training.ddp_backend="legacy_ddp" \
+ \
+ common.tensorboard_logdir=$MODEL_DIR \
+ checkpoint.save_dir=$MODEL_DIR \
+ hydra.run.dir=$MODEL_DIR \
+ hydra.job.name=${exp_name} \
+
+sleep 10s
+ # sleep infinity
+
+
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/decode_text2code.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/decode_text2code.sh
new file mode 100644
index 0000000000000000000000000000000000000000..866146d4a26cea23c4dc51d5f53c90f58bfadc21
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/decode_text2code.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+#####################################
+# Hubert ED model #
+#####################################
+[ $# -lt 1 ] && echo "Usage: $0 <model_path> [gen_set] [tgt] [max_tokens] [world_size] [rank] [outdir]" && exit 0
+#source /mnt/default/v-ziqzhang/.bashrc_sing
+
+model_path=$1
+gen_set=$2
+tgt=$3
+src="ltr"
+max_tokens=$4
+word_size=$5
+rank=$6
+outdir=$7
+
+[ -z $tgt ] && tgt="kmu"
+[ -z $gen_set ] && gen_set="dev_clean"
+[ -z $word_size ] && word_size=1
+[ -z $rank ] && rank=0
+[ -z $max_tokens ] && max_tokens=2000
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlst
+DATA_DIR=${gen_set%/*}
+gen_set=${gen_set##*/}
+[ $gen_set == "test" ] && DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/en_asr_data
+[ -z $outdir ] && outdir=$DATA_DIR
+
+
+results_path=$outdir/pseudo_${gen_set}_${rank}
+[ ! -d $results_path ] && mkdir -p $results_path
+
+for subset in $gen_set; do
+ python $FAIRSEQ_ROOT/fairseq_cli/generate_mt_label.py $DATA_DIR \
+ --path ${model_path} \
+ --task "translation_from_jst" \
+ --max-target-positions 3000 \
+ --gen-subset $subset \
+ -t $tgt -s "ltr" \
+ --max-tokens ${max_tokens} \
+ --dataset-impl "raw" \
+ --max-len-a 2 --max-len-b 100 \
+ --results-path $results_path \
+ --skip-invalid-size-inputs-valid-test \
+ --distributed-world-size $word_size --distributed-rank $rank \
+
+ echo "$model" > $results_path/model.record
+ sleep 1s
+done | tee $results_path/decode.log
+
+sleep 2s
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/decode_text2code_beam2.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/decode_text2code_beam2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9cad721b3dfcf0bbca8d82b57290dacb616b74b2
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/decode_text2code_beam2.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+#####################################
+# Hubert ED model #
+#####################################
+[ $# -lt 1 ] && echo "Usage: $0 <model_path> [gen_set] [tgt] [max_tokens] [world_size] [rank] [outdir]" && exit 0
+#source /mnt/default/v-ziqzhang/.bashrc_sing
+
+model_path=$1
+gen_set=$2
+tgt=$3
+src="ltr"
+max_tokens=$4
+word_size=$5
+rank=$6
+outdir=$7
+
+[ -z $tgt ] && tgt="kmu"
+[ -z $gen_set ] && gen_set="dev_clean"
+[ -z $word_size ] && word_size=1
+[ -z $rank ] && rank=0
+[ -z $max_tokens ] && max_tokens=2000
+
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+DATA_DIR=${gen_set%/*}
+gen_set=${gen_set##*/}
+[ $gen_set == "test" ] && DATA_DIR=/mnt/output/users/v-kunwei/code/fairseq_mlstku
+[ -z $outdir ] && outdir=$DATA_DIR
+
+
+results_path=$outdir/pseudo_${gen_set}_${rank}
+[ ! -d $results_path ] && mkdir -p $results_path
+
+for subset in $gen_set; do
+ python $FAIRSEQ_ROOT/fairseq_cli/generate_mt_label.py $DATA_DIR \
+ --path ${model_path} \
+ --task "translation_from_jst" \
+ --max-target-positions 3000 \
+ --gen-subset $subset \
+ -t $tgt -s "ltr" \
+ --dataset-impl "raw" \
+ --max-tokens ${max_tokens} \
+ --beam 2 \
+ --max-len-a 2 --max-len-b 100 \
+ --results-path $results_path \
+ --skip-invalid-size-inputs-valid-test \
+ --distributed-world-size $word_size --distributed-rank $rank \
+
+ echo "$model" > $results_path/model.record
+ sleep 1s
+done | tee $results_path/decode.log
+
+sleep 2s
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/inference_code_bleu.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/inference_code_bleu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..240d4874c02fb1b06c18af32382ae4aee3297113
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/inference_code_bleu.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+#####################################
+# Hubert ED model #
+#####################################
+[ $# -lt 1 ] && echo "Usage: $0 <model_path> [gen_set] [tgt] [outdir]" && exit 0
+
+model_path=$1
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+gen_set=$2
+tgt=$3
+outdir=$4
+src="ltr"
+[ -z $tgt ] && tgt="kmu"
+[ -z $gen_set ] && gen_set="es_dev"
+[ -z $outdir ] && outdir=$src_dir/decode_${cpt}
+
+DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_asr_data/
+# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_joint_splitenc_400k/ltr-$tgt
+# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_400k/ltr-$tgt
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlst
+
+langs="ltr,$tgt"
+
+for subset in $gen_set; do
+ results_path=$outdir/${subset}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $FAIRSEQ_ROOT/fairseq_cli/generate.py $DATA_DIR \
+ --path ${model_path} \
+ --task "translation_from_jst" \
+ --max-target-positions 3000 \
+ --gen-subset $subset \
+ -t $tgt -s "ltr" --dataset-impl "raw" \
+ --batch-size 16 \
+ --max-len-a 2 --max-len-b 400 \
+ --results-path $results_path \
+ --scoring sacrebleu $extra
+
+ echo $results_path
+ tail -n 1 $results_path/generate-*.txt
+ sleep 1s
+done
+
+# --distributed-world-size 1000 --distributed-rank 0 \
+
+sleep 2s
+
+# cat generate-newstest2020_enja.txt | grep "^D-" | cut -d'-' -f 2- | sort -n -k1 | cut -f3 > decode-newstest2020_enja.txt
+# sacrebleu -t wmt20 -l en-ja -i decode-newstest2020_enja.txt --tokenize char
diff --git a/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/inference_code_wer.sh b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/inference_code_wer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8fa9670ff8629ccc857d55c7c07983cc3d2c700b
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/stpretrain_scripts/train_text2code/inference_code_wer.sh
@@ -0,0 +1,53 @@
+
+#####################################
+# Hubert ED model #
+#####################################
+[ $# -lt 1 ] && echo "Usage: $0 <model_path> [gen_set] [tgt] [outdir]" && exit 0
+
+model_path=$1
+src_dir=${model_path%/*}
+cpt=${model_path##*/}
+cpt=${cpt%.*}
+
+gen_set=$2
+tgt=$3
+outdir=$4
+src="ltr"
+[ -z $tgt ] && tgt="kmu"
+[ -z $gen_set ] && gen_set="en_dev"
+[ -z $outdir ] && outdir=$src_dir/decode_${cpt}
+
+# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/hubert_release_iter2_layer9_kmeans/ltr-$tgt
+# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_joint_splitenc_400k/ltr-$tgt
+#DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_400k/ltr-$tgt
+DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_asr_data/
+FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlst
+
+langs="ltr,$tgt"
+
+for subset in $gen_set; do
+ results_path=$outdir/${subset}
+ [ ! -d $results_path ] && mkdir -p $results_path
+
+ python $FAIRSEQ_ROOT/fairseq_cli/generate.py $DATA_DIR \
+ --path ${model_path} \
+ --task "translation_from_jst" \
+ --max-target-positions 3000 \
+ --gen-subset $subset \
+ -t $tgt -s "ltr" --dataset-impl "raw" \
+ --batch-size 16 \
+ --max-len-a 2 --max-len-b 400 \
+ --results-path $results_path \
+ --scoring wer
+
+ echo $results_path
+ tail -n 1 $results_path/generate-*.txt
+ sleep 1s
+done
+
+# --distributed-world-size 1000 --distributed-rank 0 \
+
+sleep 2s
+
+# cat generate-newstest2020_enja.txt | grep "^D-" | cut -d'-' -f 2- | sort -n -k1 | cut -f3 > decode-newstest2020_enja.txt
+# sacrebleu -t wmt20 -l en-ja -i decode-newstest2020_enja.txt --tokenize char
diff --git a/SpeechT5/Speech2S/speech2s/tasks/joint_sc2t_pretrain.py b/SpeechT5/Speech2S/speech2s/tasks/joint_sc2t_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..db6e4e611f01d58f53ede5fd529fb9ceca44bcc8
--- /dev/null
+++ b/SpeechT5/Speech2S/speech2s/tasks/joint_sc2t_pretrain.py
@@ -0,0 +1,1004 @@
+# ----------------------------------------------------------------------------
+# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
+# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import logging
+import os
+import sys
+from typing import Dict, List, Optional, Tuple
+from pathlib import Path
+
+import numpy as np
+from argparse import Namespace
+from collections import OrderedDict
+
+import torch
+from dataclasses import dataclass, field
+from fairseq.data import (
+ Dictionary,
+ encoders,
+ data_utils,
+ StripTokenDataset,
+ PrependTokenDataset,
+ AppendTokenDataset,
+ DenoisingDataset,
+ ConcatDataset,
+ FairseqDataset,
+ iterators,
+ ResamplingDataset,
+ MaskTokensDataset,
+ LanguagePairDataset,
+)
+from fairseq.data.audio.speech_to_text_joint_dataset import S2TJointDataConfig
+from fairseq.data.shorten_dataset import maybe_shorten_dataset
+# from fairseq.data.encoders.utils import get_whole_word_mask
+from fairseq.dataclass.configs import FairseqDataclass
+from fairseq.tasks import register_task
+from fairseq.tasks.fairseq_task import FairseqTask
+from fairseq.dataclass.constants import ChoiceEnum
+from omegaconf import MISSING
+
+from speechut.data.multimodal_corpus_dataset import MultiCorpusDataset
+from speechut.data.load_langpair_dataset import load_langpair_dataset
+from speechut.data.language_trible_dataset import LanguageTripleDataset, load_langtriple_dataset
+from speechut.data.hubert_dataset import HubertDataset
+
+logger = logging.getLogger(__name__)
+
+TOKENIZER_CHOICES = ChoiceEnum(["sentencepiece", "hubert_letters", "none"])
+
+def _lang_token(lang: str):
+ return "".format(lang)
+
+def _lang_token_index(dic: Dictionary, lang: str):
+ """Return language token index."""
+ idx = dic.index(_lang_token(lang))
+ assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
+ return idx
+
+
+class LabelEncoder(object):
+ def __init__(self, dictionary: Dictionary) -> None:
+ self.dictionary = dictionary
+
+ def __call__(self, label: str) -> List[str]:
+ return self.dictionary.encode_line(
+ label, append_eos=False, add_if_not_exist=False,
+ )
+
+
+### wrap the initial get_whole_word_mask which needs bpe_tokenizer,
+### here we just assume words are split by "|" or "<SIL>"
+def get_whole_word_mask(args, dictionary):
+ def is_beginning_of_word(i):
+ if i < dictionary.nspecial:
+ # special elements are always considered beginnings
+ return True
+ tok = dictionary[i]
+ if tok.startswith("madeupword"):
+ return True
+ elif tok in ["", "", "", "", "|", ""]:
+ return True
+ else:
+ return False
+
+ mask_whole_words = torch.ByteTensor(
+ list(map(is_beginning_of_word, range(len(dictionary))))
+ )
+ return mask_whole_words
+
+def get_repeative_start(tokens):
+ """
+    tokens: 1-D torch.Tensor that may contain runs of repeated tokens.
+    Returns a bool mask that is True at the first token of each run.
+ """
+ length = len(tokens)
+ rep_start_id = tokens[:-1] != tokens[1:]
+ return torch.cat([torch.tensor([True]), rep_start_id])
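+# Illustrative example: get_repeative_start(torch.tensor([5, 5, 5, 7, 7, 9]))
+# returns tensor([True, False, False, True, False, True]) -- True marks the first token of each run.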
+
+@dataclass
+class TextPretrainingConfig(FairseqDataclass):
+ ### added for joint pretraining
+ text_data: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "if set, path to text data directory",
+ },
+ )
+ seed: Optional[int] = field(
+ default=1,
+ metadata={
+ "help": "for ordered_indices in MulticorpusDataset",
+ },
+ )
+ tokens_per_sample: Optional[int] = field(
+ default=512,
+ metadata={
+ "help": "max number of total tokens over all segments per sample for dataset",
+ },
+ )
+ tokens_per_sample_tgt: Optional[int] = field(
+ default=512,
+ metadata={
+ "help": "max number of total tokens over all segments per target sample for dataset",
+ },
+ )
+ sample_break_mode: Optional[str] = field(
+ default="eos",
+ metadata={
+ "help": "mode for breaking sentence",
+ },
+ )
+ mask: Optional[float] = field(
+ default=0.3,
+ metadata={
+ "help": "fraction of words/subwords that will be masked",
+ },
+ )
+ leave_unmasked_prob: float = field(
+ default=0.1,
+ metadata={"help": "probability that a masked token is unmasked"},
+ )
+ mask_random: Optional[float] = field(
+ default=0.1,
+ metadata={
+ "help": "instead of using [MASK], use random token this often",
+ },
+ )
+ freq_weighted_replacement: bool = field(
+ default=False,
+ metadata={"help": "sample random replacement words based on word frequencies"},
+ )
+ mask_whole_words: bool = field(
+ default=True,
+ metadata={"help": "mask whole words; you may also want to set --bpe"},
+ )
+ mask_repeative_tokens: bool = field(
+ default=True,
+ metadata={"help": "mask repeative_tokens; if mask_whole_words=False"},
+ )
+ mask_multiple_length: int = field(
+ default=1,
+ metadata={"help": "repeat the mask indices multiple times"},
+ )
+ mask_stdev: float = field(
+ default=0.0,
+ metadata={"help": "stdev of the mask length"},
+ )
+ shorten_method: Optional[str] = field(
+ default="none",
+ metadata={
+ "help": "if not none, shorten sequences that exceed tokens_per_sample",
+ "choices": "none/truncate/random_crop"
+ },
+ )
+ shorten_data_split_list: Optional[str] = field(
+ default="",
+ metadata={
+ "help": "comma_separated list of dataset splits to apply shortening to, e.g., train,valid (default: all dataset splits)",
+ },
+ )
+
+    ### the hyper-parameters below are used for BART-style denoising
+ insert: Optional[float] = field(
+ default=0.0,
+ metadata={
+ "help": "insert this percentage of additional random tokens",
+ },
+ )
+ permute: Optional[float] = field(
+ default=0.0,
+ metadata={
+ "help": "take this proportion of subwords and permute them",
+ },
+ )
+ rotate: Optional[float] = field(
+ default=0.0,
+ metadata={
+ "help": "rotate this proportion of inputs",
+ },
+ )
+ poisson_lambda: Optional[float] = field(
+ default=3.5,
+ metadata={
+ "help": "randomly shuffle sentences for this proportion of inputs",
+ },
+ )
+ permute_sentences: Optional[float] = field(
+ default=0.0,
+ metadata={
+ "help": "shuffle this proportion of sentences in all inputs",
+ },
+ )
+ mask_length: Optional[str] = field(
+ default="span-poisson",
+ metadata={
+ "help": "mask length to choose",
+ "choice": "subword/word/span-poisson"
+ },
+ )
+ replace_length: Optional[int] = field(
+ default=1,
+ metadata={
+ "help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
+ },
+ )
+ shuffle_instance: Optional[bool] = field(
+ default=False,
+ metadata={"help": "shuffle instance"},
+ )
+ max_source_positions: Optional[int] = field(
+ default=1024,
+ metadata={"help": "max number of tokens in the source sequence"},
+ )
+ max_target_positions: Optional[int] = field(
+ default=1024,
+ metadata={"help": "max number of tokens in the target sequence"},
+ )
+ bpe: Optional[str] = field(
+ default="",
+ metadata={
+ "help": "will wrapped by the text_data_config yaml",
+ },
+ )
+ data_config: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "a config yaml specify the bpe model of text data",
+ },
+ )
+ text_maxtokens_ratio: Optional[float] = field(
+ default=1.0,
+ metadata={
+ "help": "for text, max_tokens = max_tokens * text_maxtokens_ratio / 320 ",
+ },
+ )
+ prepend_tgt_lang_tag: bool = field(
+ default=False,
+ metadata={"help": "prepend tgt_lang_tag to replace "},
+ )
+ mask_text_ratio: Optional[float] = field(
+ default=0.0,
+ metadata={
+ "help": "mask_text_ratio, for paired data",
+ },
+ )
+ truncate_mono_source: bool = field(
+ default=True,
+ metadata={"help": "truncate mono source-side examples that exceed max-positions"},
+ )
+
+
+@dataclass
+class JointPretrainingConfig(FairseqDataclass):
+ data: str = field(
+ default=MISSING, metadata={"help": "path to speech data directory"}
+ )
+ fine_tuning: bool = field(
+ default=False, metadata={"help": "set to true if fine-tuning Hubert"}
+ )
+ labels: List[str] = field(
+ default_factory=lambda: ["ltr"],
+ metadata={
+ "help": (
+ "extension of the label files to load, frame-level labels for"
+ " pre-training, and sequence-level label for fine-tuning"
+ )
+ },
+ )
+ label_dir: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "if set, looks for labels in this directory instead",
+ },
+ )
+ label_rate: int = field(
+ default=-1,
+ metadata={"help": "label frame rate. -1 for sequence label"},
+ )
+ sample_rate: int = field(
+ default=16_000,
+ metadata={
+ "help": "target sample rate. audio files will be up/down "
+ "sampled to this rate"
+ },
+ )
+ normalize: bool = field(
+ default=False,
+ metadata={
+ "help": "if set, normalizes input to have 0 mean and unit variance"
+ },
+ )
+ enable_padding: bool = field(
+ default=False,
+ metadata={"help": "pad shorter samples instead of cropping"},
+ )
+ max_keep_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "exclude sample longer than this"},
+ )
+ max_sample_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "max sample size to crop to for batching"},
+ )
+ min_sample_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "min sample size to crop to for batching"},
+ )
+ single_target: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": "if set, AddTargetDatasets outputs same keys "
+ "as AddTargetDataset"
+ },
+ )
+ random_crop: Optional[bool] = field(
+ default=True,
+ metadata={"help": "always crop from the beginning if false"},
+ )
+ pad_audio: Optional[bool] = field(
+ default=False,
+ metadata={"help": "pad audio to the longest one in the batch if true"},
+ )
+ store_labels: Optional[bool] = field(
+ default=True,
+ metadata={"help": "store spm labels in memory, should be true when fine-tune with bpe"},
+ )
+ add_decoder_target: bool = field(
+ default=False,
+ metadata={"help": "contral the model architecture, if set True, load reduced unit as target"},
+ )
+ split_modality_batch: bool = field(
+ default=False,
+ metadata={"help": "whether create all samples of different modalities in a batch"},
+ )
+ speech_tgt_lang: str = field(
+ default="",
+ metadata={"help": "prepend to prev_output_tokens to replace , only used for decoder"},
+ )
+ speech_sampling_alpha: float = field(
+ default=0.2,
+ metadata={
+ "help": "Hyper-parameter alpha = 1/T for temperature-based speech resampling."
+ "(alpha = 1 for no resampling)"
+ },
+ )
+ text_sampling_alpha: float = field(
+ default=0.2,
+ metadata={
+ "help": "Hyper-parameter alpha = 1/T for temperature-based text resampling."
+ "(alpha = 1 for no resampling)"
+ },
+ )
+ hubert_tokenizer: Optional[TOKENIZER_CHOICES] = field(
+ default="none",
+ metadata={"help": "which tokenizer for processing text"},
+ )
+ sp_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "sentencepiece model path if using bpe tokenizer"},
+ )
+ text_cfg: TextPretrainingConfig = TextPretrainingConfig()
+ # For inference
+ ctc_weight: float = field(
+ default=0.0,
+ metadata={"help": "ctc weight during inference"},
+ )
+ lm_dict: Optional[str] = field(
+ default="dict.txt",
+ metadata={"help": "dict used for decoding with language model, should be in cfg.data/"},
+ )
+
+@register_task("joint_sc2t_pretraining", dataclass=JointPretrainingConfig)
+class Jsc2tPretrainingTask(FairseqTask):
+
+ cfg: JointPretrainingConfig
+
+ def __init__(
+ self,
+ cfg: JointPretrainingConfig,
+        load_local_states: bool = True,
+ ) -> None:
+ super().__init__(cfg)
+
+ logger.info(f"current directory is {os.getcwd()}")
+ logger.info(f"JSTPretrainingTask Config {cfg}")
+
+ self.cfg = cfg
+ self.fine_tuning = cfg.fine_tuning
+        self.blank_symbol = "<s>"
+
+ if load_local_states:
+ self.state.add_factory("hubert_tokenizer", self.build_tokenizer)
+ if self.cfg.text_cfg.text_data is not None and os.path.exists(self.cfg.text_cfg.text_data):
+ self.state.add_factory("text_dictionary", self.load_text_dictionary)
+ self.state.add_factory("text_src_dictionary", self.load_text_src_dictionary)
+ if cfg.fine_tuning:
+ self.state.add_factory("target_dictionary", self.load_dictionaries)
+ else:
+ self.state.add_factory("dictionaries", self.load_dictionaries)
+
+ if cfg.text_cfg.data_config is not None:
+ self.text_data_cfg = S2TJointDataConfig(Path(f"{cfg.text_cfg.text_data}/{cfg.text_cfg.data_config}"))
+ self.cfg.text_cfg.bpe = self.text_data_cfg.bpe_tokenizer["bpe"]
+ else:
+ self.text_data_cfg = None
+
+ @property
+ def source_dictionary(self) -> Optional[Dictionary]:
+ return None
+
+ @property
+ def target_dictionary(self) -> Optional[Dictionary]:
+ return self.state.target_dictionary
+
+ @property
+ def dictionaries(self) -> List[Dictionary]:
+ return self.state.dictionaries
+
+ @property
+ def text_dictionary(self) -> Optional[Dictionary]:
+ return self.state.text_dictionary
+
+ @property
+ def text_src_dictionary(self) -> Optional[Dictionary]:
+ return self.state.text_src_dictionary
+
+ @property
+ def hubert_tokenizer(self):
+ return self.state.hubert_tokenizer
+
+ def load_dictionaries(self):
+ label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
+ dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels]
+ if not self.cfg.fine_tuning:
+ for dictionary in dictionaries:
+ dictionary.add_symbol("")
+ return dictionaries[0] if self.cfg.fine_tuning else dictionaries
+
+ def load_text_dictionary(self):
+ tgt_dict_path = f"{self.cfg.text_cfg.text_data}/{self.text_data_cfg.vocab_filename if self.text_data_cfg is not None else 'dict.txt'}"
+ if not os.path.isfile(tgt_dict_path):
+ raise FileNotFoundError(f"Dict not found: {tgt_dict_path}")
+ text_dictionary = Dictionary.load(tgt_dict_path)
+        self.mask_idx = text_dictionary.add_symbol("<mask>")
+ return text_dictionary
+
+ def load_text_src_dictionary(self):
+ src_dict_path = f"{self.cfg.text_cfg.text_data}/{self.text_data_cfg.src_vocab_filename if self.text_data_cfg is not None else 'dict.txt'}"
+ if not os.path.isfile(src_dict_path):
+ raise FileNotFoundError(f"Dict not found: {src_dict_path}")
+ src_text_dictionary = Dictionary.load(src_dict_path)
+        self.mask_idx = src_text_dictionary.add_symbol("<mask>")
+ return src_text_dictionary
+
+ @classmethod
+ def setup_task(
+ cls, cfg: JointPretrainingConfig, **kwargs
+ ) -> "Jsc2tPretrainingTask":
+ load_local_states = kwargs.get("load_local_states", True)
+ return cls(cfg, load_local_states)
+
+ def get_label_dir(self) -> str:
+ if self.cfg.label_dir is None:
+ return self.cfg.data
+ return self.cfg.label_dir
+
+ def load_paired_dataset(self, text_split, truncate_source=False):
+ text_split, lp = text_split.rsplit('.', 1) # e.g. "libritext.ltr-ltr"
+ if len(lp.split("-")) == 2:
+ src, tgt = lp.split("-")
+ if src == tgt:
+                logger.warning(f"| trying to load monolingual dataset {text_split}.{lp}, please check that the task is configured correctly.")
+ paired_dataset = self.load_char_bart_dataset(f"{text_split}.{lp}.{tgt}")
+ return paired_dataset
+ paired_dataset = load_langpair_dataset(
+ self.cfg.text_cfg.text_data,
+ text_split,
+ src,
+ self.text_src_dictionary,
+ tgt,
+ self.text_dictionary,
+ combine=True,
+ dataset_impl=None,
+ upsample_primary=1,
+ left_pad_source=False,
+ left_pad_target=False,
+ max_source_positions=self.cfg.text_cfg.tokens_per_sample,
+ max_target_positions=self.cfg.text_cfg.tokens_per_sample,
+ truncate_source=truncate_source,
+ prepend_bos=False,
+ load_alignments=False,
+ append_source_id=True if self.cfg.text_cfg.prepend_tgt_lang_tag else False,
+ lang_format="" if self.cfg.text_cfg.prepend_tgt_lang_tag else "[{}]",
+ input_feeding=self.cfg.add_decoder_target,
+ )
+ if self.cfg.text_cfg.mask_text_ratio > 0:
+ # add mask
+            self.mask_idx = self.text_src_dictionary.index("<mask>")
+ mask_whole_words = None
+ if self.cfg.text_cfg.mask_whole_words:
+ mask_whole_words = get_whole_word_mask(self.cfg.text_cfg, self.text_src_dictionary)
+ elif self.cfg.text_cfg.mask_repeative_tokens:
+ mask_whole_words = get_repeative_start
+
+ src_dataset, src_unmasked_dataset = MaskTokensDataset.apply_mask(
+ paired_dataset.src,
+ self.text_src_dictionary,
+ pad_idx=self.text_src_dictionary.pad(),
+ mask_idx=self.mask_idx,
+ seed=self.cfg.text_cfg.seed,
+ mask_prob=self.cfg.text_cfg.mask_text_ratio,
+ leave_unmasked_prob=self.cfg.text_cfg.leave_unmasked_prob,
+ random_token_prob=self.cfg.text_cfg.mask_random,
+ freq_weighted_replacement=self.cfg.text_cfg.freq_weighted_replacement,
+ mask_whole_words=mask_whole_words,
+ mask_multiple_length=self.cfg.text_cfg.mask_multiple_length,
+ mask_stdev=self.cfg.text_cfg.mask_stdev,
+ )
+ tgt_dataset = paired_dataset.tgt if paired_dataset.tgt is not None else src_unmasked_dataset
+ paired_dataset = LanguageTripleDataset(
+ src_dataset,
+ src_dataset.sizes,
+ self.text_src_dictionary,
+ src_unmasked_dataset,
+ src_unmasked_dataset.sizes,
+ self.text_src_dictionary,
+ tgt_dataset,
+ tgt_dataset.sizes,
+ self.text_dictionary,
+ left_pad_source=False,
+ left_pad_target=False,
+ align_dataset=None,
+ eos=None,
+ num_buckets=0,
+ shuffle=True,
+ pad_to_multiple=1,
+ )
+ else:
+ src, ref, tgt = lp.split("-")
+ paired_dataset = load_langtriple_dataset(
+ self.cfg.text_cfg.text_data,
+ text_split,
+ src,
+ self.text_src_dictionary,
+ ref,
+ self.dictionaries[-1],
+ tgt,
+ self.text_dictionary,
+ combine=True,
+ dataset_impl=None,
+ upsample_primary=1,
+ left_pad_source=False,
+ left_pad_target=False,
+ max_source_positions=self.cfg.text_cfg.tokens_per_sample,
+ max_target_positions=self.cfg.text_cfg.tokens_per_sample,
+ truncate_source=truncate_source,
+ prepend_bos=False,
+ load_alignments=False,
+ append_source_id=True if self.cfg.text_cfg.prepend_tgt_lang_tag else False,
+ lang_format="" if self.cfg.text_cfg.prepend_tgt_lang_tag else "[{}]",
+ )
+ return paired_dataset
+
+ def load_dataset(self, split: str, epoch=1, **kwargs) -> None:
+ """
+        Create a waveform dataset for audio and indexed datasets for the (phonemized) text,
+        then combine them into a MultiCorpusDataset.
+        The split name consists of '+'-separated fields:
+        speech_splits[+paired_text_splits[+mono_text_splits[+supervised_split]]],
+        where each field is a comma-separated list of sub-splits.
+ """
+ speech_splits = split.split('+')[0].split(',')
+ ### 1st, create a speech dataset using STSpeechDataset (modified from HubertDataset)
+ dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
+ pad_list = [dict.pad() for dict in dicts]
+ eos_list = [dict.eos() for dict in dicts]
+ procs = [LabelEncoder(dict) for dict in dicts]
+ if self.cfg.speech_tgt_lang != "":
+ tgt_lang_idx = _lang_token_index(dicts[0], self.cfg.speech_tgt_lang)
+ logger.info(f"Will prepend <{tgt_lang_idx}> at the beginning of prev_output_tokens to replace ")
+ else:
+ tgt_lang_idx = None
+
+
+ # hubert v1: pad_audio=True, random_crop=False;
+ speech_datasets = []
+ for speech_split in speech_splits:
+ paths = [
+ f"{self.get_label_dir()}/{speech_split}.{l}" for l in self.cfg.labels
+ ]
+ speech_datasets.append(
+ HubertDataset(
+ f"{self.cfg.data}/{speech_split}.tsv",
+ sample_rate=self.cfg.sample_rate,
+ label_paths=paths,
+ label_rates=self.cfg.label_rate,
+ pad_list=pad_list,
+ eos_list=eos_list,
+ label_processors=procs,
+ max_keep_sample_size=self.cfg.max_keep_size,
+ min_keep_sample_size=self.cfg.min_sample_size,
+ max_sample_size=self.cfg.max_sample_size,
+ pad_audio=self.cfg.pad_audio,
+ normalize=self.cfg.normalize,
+ store_labels=self.cfg.store_labels,
+ random_crop=self.cfg.random_crop,
+ single_target=self.cfg.single_target,
+ tgt_dict=dicts[0],
+ add_decoder_target=self.cfg.add_decoder_target,
+ fine_tuning=self.cfg.fine_tuning,
+ tgt_lang_idx=tgt_lang_idx,
+ tokenizer=self.hubert_tokenizer,
+ )
+ )
+ if len(speech_datasets) > 1:
+ speech_dataset = ConcatDataset(speech_datasets)
+ else:
+ speech_dataset = speech_datasets[0]
+
+ has_text = len(split.split('+')) > 1
+ if not has_text:
+ assert speech_dataset is not None
+ self.datasets[split] = speech_dataset
+ return
+
+ ### 2nd, create paired/mono text datasets using Langpairdataset
+ if split.split('+')[1] != '':
+ paired_splits = [paired_split for paired_split in split.split('+')[1].split(',') if paired_split != '']
+ paired_datasets = [self.load_paired_dataset(paired_split) for paired_split in paired_splits]
+ else:
+ paired_splits, paired_datasets = [], []
+
+ if len(split.split('+')) > 2 and split.split('+')[2] != '':
+ mono_splits = [mono_split for mono_split in split.split('+')[2].split(',') if mono_split != '']
+ mono_datasets = [self.load_paired_dataset(mono_split, truncate_source=self.cfg.text_cfg.truncate_mono_source) for mono_split in mono_splits]
+ else:
+ mono_splits, mono_datasets = [], []
+
+        assert len(mono_datasets + paired_datasets) > 0, f"split {split} has no text data! Please check the split definition."
+
+ ### 3rd, if provided, create a supervised dataset with labeled data
+ if len(split.split('+')) > 3 and split.split('+')[3] != '':
+            assert len(paired_splits) > 0, "the supervised dataset cannot be loaded without a paired text dataset!"
+ tgt = paired_splits[0].rsplit('.', 1)[1].split("-")[1]
+ sup_split = split.split('+')[3]
+
+ sup_dataset = HubertDataset(
+ f"{self.cfg.data}/{sup_split}.tsv",
+ sample_rate=self.cfg.sample_rate,
+ label_paths=[f"{self.get_label_dir()}/{sup_split}.{tgt}"],
+ label_rates=[-1],
+ pad_list=[self.text_dictionary.pad()],
+ eos_list=[self.text_dictionary.eos()],
+ label_processors=[LabelEncoder(self.text_dictionary)],
+ max_keep_sample_size=self.cfg.max_keep_size,
+ min_keep_sample_size=None,
+ max_sample_size=None,
+ pad_audio=True,
+ normalize=self.cfg.normalize,
+ store_labels=self.cfg.store_labels,
+ random_crop=False,
+ single_target=True,
+ tgt_dict=self.text_dictionary,
+ add_decoder_target=self.cfg.add_decoder_target,
+ fine_tuning=True,
+ tgt_lang_idx=None,
+ tokenizer=None,
+ )
+ else:
+ sup_dataset = None
+
+ ### 4th, compose a MultiCorpusDataset
+ dataset_dict, max_positions_dict, distributions, max_tokens_ratios = self.resample_multi_modality_dataset(
+ speech_dataset, sup_dataset, mono_datasets, paired_datasets, mono_splits, paired_splits, epoch=epoch,
+ )
+ self.datasets[split] = MultiCorpusDataset(
+ dataset_dict,
+ max_positions=max_positions_dict,
+ distribution=distributions,
+ max_tokens_ratio=max_tokens_ratios,
+ seed=self.cfg.text_cfg.seed,
+ sort_indices=True,
+ )
+
+
+ def max_positions(self) -> Tuple[int, int]:
+ return (sys.maxsize, sys.maxsize)
+
+ def filter_indices_by_size(
+ self, indices: np.array, *args, **kwargs
+ ) -> np.array:
+ return indices
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ skip_remainder_batch=False,
+ grouped_shuffling=False,
+ update_epoch_batch_itr=False,
+ ):
+ """
+ Get an iterator that yields batches of data from the given dataset.
+ Args:
+ dataset (~fairseq.data.FairseqDataset): dataset to batch
+ max_tokens (int, optional): max number of tokens in each batch
+ (default: None).
+ max_sentences (int, optional): max number of sentences in each
+ batch (default: None).
+ max_positions (optional): max sentence length supported by the
+ model (default: None).
+ ignore_invalid_inputs (bool, optional): don't raise Exception for
+ sentences that are too long (default: False).
+ required_batch_size_multiple (int, optional): require batch size to
+ be a multiple of N (default: 1).
+ seed (int, optional): seed for random number generator for
+ reproducibility (default: 1).
+ num_shards (int, optional): shard the data iterator into N
+ shards (default: 1).
+ shard_id (int, optional): which shard of the data iterator to
+ return (default: 0).
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. 0 means the data will be loaded in the main process
+ (default: 0).
+ epoch (int, optional): the epoch to start the iterator from
+ (default: 1).
+ data_buffer_size (int, optional): number of batches to
+ preload (default: 0).
+ disable_iterator_cache (bool, optional): don't cache the
+ EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
+ (default: False).
+ skip_remainder_batch (bool, optional): if set, discard the last
+ batch in each training epoch, as the last batch is often smaller than
+ local_batch_size * distributed_word_size (default: ``True``).
+ grouped_shuffling (bool, optional): group batches with each groups
+ containing num_shards batches and shuffle groups. Reduces difference
+ between sequence lengths among workers for batches sorted by length.
+ update_epoch_batch_itr (bool optional): if true then donot use the cached
+ batch iterator for the epoch
+
+ Returns:
+ ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
+ given dataset split
+ """
+ if self.fine_tuning or not isinstance(dataset, MultiCorpusDataset):
+ return super().get_batch_iterator(
+ dataset,
+ max_tokens=max_tokens,
+ max_sentences=max_sentences,
+ max_positions=max_positions,
+ ignore_invalid_inputs=ignore_invalid_inputs,
+ required_batch_size_multiple=required_batch_size_multiple,
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=num_workers,
+ epoch=epoch,
+ data_buffer_size=data_buffer_size,
+ disable_iterator_cache=disable_iterator_cache,
+ skip_remainder_batch=skip_remainder_batch,
+ grouped_shuffling=grouped_shuffling,
+ update_epoch_batch_itr=update_epoch_batch_itr,
+ )
+
+ can_reuse_epoch_itr = (
+ not disable_iterator_cache
+ and not update_epoch_batch_itr
+ and self.can_reuse_epoch_itr(dataset)
+ )
+ if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
+ logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
+ return self.dataset_to_epoch_iter[dataset]
+
+ assert isinstance(dataset, FairseqDataset)
+
+ # initialize the dataset with the correct starting epoch
+ dataset.set_epoch(epoch)
+
+ # get indices ordered by example size
+ with data_utils.numpy_seed(seed):
+ indices = dataset.ordered_indices()
+
+ # filter examples that are too large
+ if max_positions is not None:
+ indices = self.filter_indices_by_size(
+ indices, dataset, max_positions, ignore_invalid_inputs
+ )
+
+ # create mini-batches with given size constraints
+ batch_sampler = dataset.get_batch_sampler(
+ indices,
+ num_shards,
+ seed,
+ max_tokens=max_tokens,
+ max_sentences=max_sentences,
+ required_batch_size_multiple=required_batch_size_multiple,
+ split_modality_batch=self.cfg.split_modality_batch,
+ )
+
+ # return a reusable, sharded iterator
+ epoch_iter = iterators.EpochBatchIterator(
+ dataset=dataset,
+ collate_fn=dataset.collater,
+ batch_sampler=batch_sampler,
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=num_workers,
+ epoch=epoch,
+ buffer_size=data_buffer_size,
+ skip_remainder_batch=skip_remainder_batch,
+ disable_shuffling=True,
+ grouped_shuffling=grouped_shuffling,
+ )
+
+ if can_reuse_epoch_itr:
+ self.dataset_to_epoch_iter[dataset] = epoch_iter
+
+ return epoch_iter
+
+ def build_generator(
+ self,
+ models,
+ args,
+ seq_gen_cls=None,
+ extra_gen_cls_kwargs=None,
+ ):
+ """Build ED-CTC generator for finet-tuned ASR model"""
+ from speechut.squence_generator import SequenceGenerator
+ extra_gen_cls_kwargs = {
+ "ctc_weight": self.cfg.ctc_weight,
+ "lm_dict": Dictionary.load(os.path.join(self.cfg.data, self.cfg.lm_dict)),
+ **extra_gen_cls_kwargs
+ }
+ return super().build_generator(
+ models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
+ )
+
+ @classmethod
+ def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0):
+ """Size ratios for temperature-based sampling
+ (https://arxiv.org/abs/1907.05019)"""
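+        # Worked example (illustrative): sizes=[100, 900], alpha=0.5
+        #   prob          = [0.10, 0.90]
+        #   smoothed_prob = [0.10**0.5, 0.90**0.5] / sum -> [0.25, 0.75]
+        #   size_ratio    = [0.25*1000/100, 0.75*1000/900] = [2.5, 0.83]
+        # i.e. the smaller corpus is up-sampled and the larger one is down-sampled.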
+ _sizes = np.array(sizes)
+ prob = _sizes / _sizes.sum()
+ smoothed_prob = prob ** alpha
+ smoothed_prob = smoothed_prob / smoothed_prob.sum()
+ size_ratio = (smoothed_prob * _sizes.sum()) / _sizes
+
+ o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
+ logger.info(f"original sampling probability: {o_str}")
+ p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
+ logger.info(f"balanced sampling probability: {p_str}")
+ sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
+ logger.info(f"balanced sampling size ratio: {sr_str}")
+ return size_ratio.tolist()
+
+ def resample_multi_modality_dataset(self, speech_dataset, sup_dataset, mono_datasets, paired_datasets, mono_splits, paired_splits, epoch=1, train=True):
+ assert len(mono_datasets+paired_datasets) > 0, f"No text data loaded!"
+
+ if len(mono_datasets) > 1 and self.cfg.text_sampling_alpha != 1.0:
+ size_ratios = self._get_size_ratios(
+ mono_splits, [len(s) for s in mono_datasets], alpha=self.cfg.text_sampling_alpha
+ )
+ mono_datasets = [
+ ResamplingDataset(
+ d, size_ratio=r, seed=0, epoch=epoch, replace=(r >= 1.0)
+ ) for d, r in zip(mono_datasets, size_ratios)
+ ]
+
+ if len(paired_datasets) > 1 and self.cfg.text_sampling_alpha != 1.0:
+ size_ratios = self._get_size_ratios(
+ paired_splits, [len(s) for s in paired_datasets], alpha=self.cfg.text_sampling_alpha
+ )
+ paired_datasets = [
+ ResamplingDataset(
+ d, size_ratio=r, seed=0, epoch=epoch, replace=(r >= 1.0)
+ ) for d, r in zip(paired_datasets, size_ratios)
+ ]
+
+ dataset_list = [speech_dataset, sup_dataset]
+ for datasets in [mono_datasets, paired_datasets]:
+ if len(datasets) > 1:
+ dataset_list.append(ConcatDataset(datasets))
+ elif len(datasets) == 1:
+ dataset_list.append(datasets[0])
+ else:
+ dataset_list.append(None)
+
+ ### match speech/text datasets according to modality
+ dataset_dict = OrderedDict((name, d) for name, d in zip(["speech", "speech_sup", "text_mono", "text_paired"], dataset_list) if d is not None)
+ max_positions_dict = {
+ "speech": None,
+ "speech_sup": None,
+ "text_mono": (self.cfg.text_cfg.tokens_per_sample, self.cfg.text_cfg.tokens_per_sample),
+ "text_paired": (self.cfg.text_cfg.tokens_per_sample, self.cfg.text_cfg.tokens_per_sample),
+ }
+ max_positions_dict = OrderedDict((name, max_positions_dict[name]) for name in dataset_dict.keys())
+ max_tokens_ratios_dict = {
+ "speech": 1.0,
+ "speech_sup": 1.0,
+ "text_mono": 1.0 / 320 / self.cfg.text_cfg.text_maxtokens_ratio,
+ "text_paired": 1.0 / 320 / self.cfg.text_cfg.text_maxtokens_ratio,
+ }
+ max_tokens_ratios = [max_tokens_ratios_dict[name] for name in dataset_dict.keys()]
+ dataset_lens = np.array([len(dataset) for dataset in dataset_dict.values()])
+ dataset_avg_sample_lens = np.array([
+ sum([dataset.num_tokens(i) for i in np.random.randint(low=0, high=len(dataset), size=10000)]) / 10000.0
+ for dataset in dataset_dict.values()
+ ])
+
+ if not "speech" in dataset_dict:
+ distributions = [l / sum(dataset_lens) for l in dataset_lens]
+ else:
+            ## keep the number of speech and non-speech batches roughly the same; expand_coef ensures there are no more speech batches than non-speech batches
+ first_ratio = dataset_lens[0] / sum(dataset_lens)
+ expand_coef = 1.2 if sup_dataset is None else 1.1 * sum(dataset_lens[0:2]) / dataset_lens[0]
+ distributions = [expand_coef * max_tokens_ratios[i] * dataset_avg_sample_lens[0] / l for (i, l) in enumerate(dataset_avg_sample_lens)]
+ distributions[0] = 1.0
+ if sup_dataset is not None:
+ distributions[1] = dataset_lens[1] / dataset_lens[0]
+ distributions = [first_ratio * d for d in distributions]
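+            # Illustrative summary of the heuristic above: after scaling by first_ratio, speech
+            # keeps weight len(speech)/len(all); the supervised set keeps a weight proportional
+            # to its size; each text corpus gets a weight scaled by its max-tokens ratio and the
+            # speech/text average-sample-length ratio, so that roughly as many (slightly more,
+            # via expand_coef) text batches as speech batches are drawn.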
+
+ logging.info(f"Number samples of datasets is {dataset_lens}")
+ logging.info(f"Avg sample length of datasets is {dataset_avg_sample_lens}")
+ logging.info(f"Sampling distributions is {distributions}")
+ logging.info(f"Maxtokens ratio is {max_tokens_ratios}")
+ return dataset_dict, max_positions_dict, distributions, max_tokens_ratios
+
+ def build_tokenizer(self, cfg=None):
+ logger.info(f"tokenizer: {self.cfg.hubert_tokenizer}")
+ if self.cfg.hubert_tokenizer != "none":
+ return encoders.build_bpe(Namespace(**{"bpe": self.cfg.hubert_tokenizer, "sentencepiece_model": self.cfg.sp_path}))
+ else:
+ return None
+
+ def load_char_bart_dataset(self, split):
+ mono_dataset = data_utils.load_indexed_dataset(
+ f"{self.cfg.text_cfg.text_data}/{split}",
+ self.text_dictionary,
+ )
+ mono_dataset = StripTokenDataset(mono_dataset, self.text_dictionary.eos())
+ mono_dataset = maybe_shorten_dataset(
+ mono_dataset,
+ split,
+ self.cfg.text_cfg.shorten_data_split_list,
+ self.cfg.text_cfg.shorten_method,
+ self.cfg.text_cfg.tokens_per_sample - 2,
+ self.cfg.text_cfg.seed,
+ )
+ logger.info("loaded {} samples from: {}".format(len(mono_dataset), mono_dataset))
+ ### prepend bos and eos to dataset
+ mono_dataset = PrependTokenDataset(mono_dataset, self.text_dictionary.bos())
+ mono_dataset = AppendTokenDataset(mono_dataset, self.text_dictionary.eos())
+ mask_whole_words = (
+ get_whole_word_mask(None, self.text_dictionary)
+ if self.cfg.text_cfg.mask_whole_words
+ else None
+ )
+ lang=self.cfg.speech_tgt_lang
+ mono_dataset = DenoisingDataset(
+ mono_dataset,
+ mono_dataset.sizes,
+ self.text_dictionary,
+ self.mask_idx,
+ mask_whole_words,
+ shuffle=self.cfg.text_cfg.shuffle_instance,
+ seed=self.cfg.text_cfg.seed,
+ args=self.cfg.text_cfg,
+ tgt_lang_idx=_lang_token_index(self.text_dictionary, lang) if self.cfg.text_cfg.prepend_tgt_lang_tag else None,
+ )
+
+ return mono_dataset
diff --git a/SpeechT5/SpeechLM/README.md b/SpeechT5/SpeechLM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..11923ca3022332fd6f9e02b634ec871dde7b164b
--- /dev/null
+++ b/SpeechT5/SpeechLM/README.md
@@ -0,0 +1,268 @@
+# SpeechLM
+
+
+
+ [**SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data**](https://arxiv.org/abs/2209.15329)
+
+- June 2023: We have corrected the errors in the pre-training data for SpeechLM-P Base models, and new results are updated.
+
+- April 2023: We discovered some errors about the data in the pre-training experiments, which will affect all the results about SpeechLM-P Base models. We are re-conducting the related experiments and will update the paper with the new results.
+
+- (Done) Oct 2022: release the code and models
+- Oct 2022: release preprint in [arXiv](https://arxiv.org/abs/2209.15329)
+
+## Pre-Trained and Fine-tuned Models
+
+| Model | Pre-training Dataset | Fine-tuning Dataset | Model |
+| :------: | :----------------------------------------------: | :-----------------: | :-----: |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/update/speechlm_checkpoint_298_400000.pt?sv=2020-04-08&st=2023-06-19T10%3A35%3A37Z&se=2033-06-20T10%3A35%3A00Z&sr=b&sp=r&sig=xPzDV3Zm7l7Mp4dgMxAYMOcoZfVJjlbBglqD7uw2XW0%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/update/checkpoint_best_asr_ft.pt?sv=2020-04-08&st=2023-06-19T10%3A36%3A39Z&se=2033-06-20T10%3A36%3A00Z&sr=b&sp=r&sig=xbS2hGAlTr7K6JJdBN0nKrPtITZE62eT%2FoEK3MBsnZs%3D) |
+| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1eblW8U8f9t-NTuCNRrNHwr-8BeLAUAmQ/view?usp=sharing) |
+| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1vXyO5DolbiWiTYZ6pkkKQsu2wJetaPlv/view?usp=sharing) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/update/checkpoint_best_ende_ft.pt?sv=2020-04-08&st=2023-06-19T10%3A37%3A23Z&se=2033-06-20T10%3A37%3A00Z&sr=b&sp=r&sig=bNET3bF240rQg%2B%2F87WC%2FJ1cMojI0WEIoqwEfM7PyQUE%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/update/checkpoint_best_enca_ft.pt?sv=2020-04-08&st=2023-06-19T10%3A37%3A46Z&se=2033-06-20T10%3A37%3A00Z&sr=b&sp=r&sig=9H1XMRiAU8tz%2B9Ri4sUGP0kZFiiQ5cSVqAqShZAhIzY%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/update/checkpoint_best_enar_ft.pt?sv=2020-04-08&st=2023-06-19T10%3A38%3A05Z&se=2033-06-20T10%3A38%3A00Z&sr=b&sp=r&sig=mvlF1vmbW9mr66dP3wW9M%2BiU7ASluD4xqCbxblYPCOw%3D) |
+| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/update/checkpoint_best_entr_ft.pt?sv=2020-04-08&st=2023-06-19T10%3A38%3A29Z&se=2033-06-20T10%3A38%3A00Z&sr=b&sp=r&sig=Wda6nh9AVlcJAI6PamiEuHeeCwi4Yudva060qGORbSc%3D) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1QjLIgTJKIylVIp5hUkfSjGPtz8Xo7Lky/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1YZQDVv096o8Opt0RBnkRiZXYPRDqKZnP/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1qYygNWSc11TQbBI1OzC4ChlR-dNh8t9S/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/162U88mwso2aVfzzPkEM2nP_vwTpcb57T/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1lbTSRXewEeb2t45URunD6EiJcbniyjWW/view?usp=sharing) |
+| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1Er4I_jHS175pQQph223yKtiiLQ378VvH/view?usp=sharing) |
+
+
+## Extract features using pre-trained models
+For easier use of our pre-trained models, we merge all inference-related code to [`SpeechLM.py`](SpeechLM.py) and make cleaned checkpoints [~~`SpeechLM-P Base`~~](https://valle.blob.core.windows.net/share/speechlm/speechlmp_base_checkpoint_clean.pt?sv=2020-04-08&st=2023-04-04T05%3A42%3A17Z&se=2033-04-05T05%3A42%3A00Z&sr=b&sp=r&sig=DN7VwaEWhrhRPiyuT84mJpohrMeJsEPq4o6qRr8BNsk%3D) [`SpeechLM-H Base`](https://valle.blob.core.windows.net/share/speechlm/speechlmh_base_checkpoint_clean.pt?sv=2020-04-08&st=2023-04-04T05%3A43%3A07Z&se=2033-04-05T05%3A43%3A00Z&sr=b&sp=r&sig=T9oaIvrb3z3Wo5GTZp8eP2x7B7yuQ%2B80Ff1KhuWrrKg%3D) [`SpeechLM-P Large`](https://valle.blob.core.windows.net/share/speechlm/speechlmp_large_checkpoint_clean.pt?sv=2020-04-08&st=2023-04-04T05%3A43%3A33Z&se=2033-04-05T05%3A43%3A00Z&sr=b&sp=r&sig=qfWBNdiIGuDgkgUiHXaWnPiVbUHm3VSp%2FHTlWrCghRk%3D) by removing non-required modules. Now you can directly use the following script to extract your speech features:
+```python
+import torch
+import torch.nn.functional as F
+from SpeechLM import SpeechLMConfig, SpeechLM
+
+checkpoint = torch.load('path/to/the/cleaned/checkpoint.pt')
+cfg = SpeechLMConfig(checkpoint['cfg']['model'])
+model = SpeechLM(cfg)
+model.load_state_dict(checkpoint['model'])
+model.eval()
+
+wav_input_16khz = torch.randn(1,10000)
+normalize = checkpoint['cfg']['task']['normalize'] # False for base model, True for large model
+if normalize:
+ wav_input_16khz = F.layer_norm(wav_input_16khz[0], wav_input_16khz[0].shape).unsqueeze(0)
+
+# extract the representation of last layer
+rep = model.extract_features(wav_input_16khz)[0]
+
+# extract the representation of each layer
+output_layer = model.cfg.encoder_layers + model.cfg.text_transformer.encoder.layers
+rep, layer_results = model.extract_features(wav_input_16khz, output_layer=output_layer, ret_layer_results=True)[0]
+layer_reps = [x.transpose(0, 1) for x in layer_results]
+```
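+
+As a minimal follow-up sketch (not part of the checkpoints' official API; it only assumes `rep` has shape `(batch, frames, dim)` as in the snippet above), you can mean-pool the frame-level features into a single utterance-level embedding:
+
+```python
+# mean-pool frame-level features into one L2-normalized utterance vector
+# (rep and F come from the snippet above; F is torch.nn.functional)
+utt_emb = rep.mean(dim=1)               # (batch, frames, dim) -> (batch, dim)
+utt_emb = F.normalize(utt_emb, dim=-1)  # unit-length embedding, e.g. for utterance retrieval
+```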
+
+
+## Setup
+To fine-tune or pre-train more models, please follow the instructions below.
+
+```bash
+git submodule update --init SpeechLM/fairseq
+cd SpeechLM/
+pip install --editable fairseq/
+pip install sacrebleu==1.5.1
+```
+
+## ASR on LibriSpeech
+### Data preparation
+Please follow the steps of wav2vec 2.0 manifest [here](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#prepare-training-data-manifest) to prepare `train.tsv` and `train.ltr`. You should make sure the vocabulary [`dict.ltr.txt`](dataset/LibriSpeech/asr/dict.ltr.txt) is the same as that used for the pre-trained model.
+
+Put your prepared data into `$data_dir`; we provide examples in [`dataset/LibriSpeech/asr`](dataset/LibriSpeech/asr/). A sketch of the expected manifest format is shown below.
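+
+For orientation, a wav2vec-style manifest pair looks roughly like the following (paths and sample counts are placeholders, not real data):
+
+```
+# train.tsv: first line is the audio root, then <relative-path><TAB><num-samples>
+/path/to/LibriSpeech
+train-clean-100/103/1240/103-1240-0000.flac	225360
+
+# train.ltr: space-separated letters with "|" as the word boundary
+C H A P T E R | O N E |
+```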
+
+### Fine-tune a CTC model
+- Fine-tune the base model
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_asr/finetune_base_ctc.sh [mount=$PWD] [world_size=8] [update_freq=1]
+ model_path=path/to/your/pre-trained/model
+ data_dir=dataset/LibriSpeech/asr
+ bash speechlm/scripts/tune_speechlm_asr/finetune_base_ctc.sh $model_path $data_dir 'tag400k'
+ ```
+- Fine-tune the large model
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_asr/finetune_large_ctc.sh [mount=$PWD] [world_size=8] [update_freq=4]
+ model_path=path/to/your/pre-trained/model
+ data_dir=dataset/LibriSpeech/asr
+ bash speechlm/scripts/tune_speechlm_asr/finetune_large_ctc.sh $model_path $data_dir 'tag400k'
+ ```
+### Decode
+- Directly decode a CTC model.
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_asr/inference_ctc.sh [gen-set=dev_clean,dev_other,test_clean,test_other]
+ model_path=path/to/your/fine-tuned/model
+ data_dir=dataset/LibriSpeech/asr
+ bash speechlm/scripts/tune_speechlm_asr/inference_ctc.sh $model_path $data_dir
+ # for large models
+ # bash speechlm/scripts/tune_speechlm_asr/inference_ctc_large.sh $model_path $data_dir
+ ```
+- Decode with 4-gram language model using [flashlight](https://github.com/flashlight/flashlight/tree/main/bindings/python) and [kenlm](https://github.com/kpu/kenlm).
+ > Please put [4-gram.arpa](https://www.openslr.org/resources/11/4-gram.arpa.gz) and the word-to-letter lexicon [librispeech_lexicon.lst](https://drive.google.com/file/d/1q7IbNGqtwXnctjvuvpviQ4ZmepFHQmTO/view?usp=sharing) into `$data_dir`.
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_asr/inference_ctc_kenlm.sh [gen-set=dev_clean,dev_other,test_clean,test_other]
+ model_path=path/to/your/fine-tuned/model
+ data_dir=dataset/LibriSpeech/asr
+ bash speechlm/scripts/tune_speechlm_asr/inference_ctc_kenlm.sh $model_path $data_dir
+ ```
+- Decode large models with fairseq-lm using [flashlight](https://github.com/flashlight/flashlight/tree/main/bindings/python).
+  > Please put [lm_librispeech_word_transformer.pt](https://dl.fbaipublicfiles.com/wav2letter/sota/2019/lm/lm_librispeech_word_transformer.pt) and its vocabulary [`dict.txt`](https://dl.fbaipublicfiles.com/wav2letter/sota/2019/lm/lm_librispeech_word_transformer.dict) into `$data_dir/fairseq_word_lm`, and the word-to-letter lexicon [librispeech_lexicon.lst](https://drive.google.com/file/d/1q7IbNGqtwXnctjvuvpviQ4ZmepFHQmTO/view?usp=sharing) into `$data_dir`. Capitalize the `dict.txt` to make it compatible with the word-to-letter lexicon.
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_asr/inference_ctc_large_fsqlm.sh [gen-set=dev_clean,dev_other,test_clean,test_other]
+ model_path=path/to/your/fine-tuned/model
+ data_dir=dataset/LibriSpeech/asr
+ bash speechlm/scripts/tune_speechlm_asr/inference_ctc_large_fsqlm.sh $model_path $data_dir dev_other
+ ```
+
+## ST on CoVoST-2
+### Data Preparation
+1. Download [Common Voice audio clips](https://commonvoice.mozilla.org/en/datasets) (version 4) for English into `$cv_root/en`.
+2. Get the data manifest. The following script converts the mp3 files to waveforms, creates tsv files containing speech/translation pairs, and creates the data config files.
+ ```bash
+ lang=de # ca,ar,tr
+ cv_root=dataset/CommonVoice/v4
+ bash speechlm/data_process/prepare_covost2_enxx.sh $lang $cv_root
+ ```
+   We provide examples in [`dataset/CommonVoice/v4/en/en-de`](dataset/CommonVoice/v4/en/en-de).
+
+### Fine-tune an encoder-decoder model
+- Fine-tune the Base model (fine-tuned models will be stored in `$mount/exp/finetune_covost`).
+
+ ```bash
+ model_path=path/to/your/pre-trained/model
+ lang=de # ca,ar,tr
+ data_dir=dataset/CommonVoice/v4/en/en-${lang}
+ # Usage (Base model): speechlm/scripts/tune_speechlm_st/ft_base_covost_enxx.sh [mount=$PWD] [world_size=8] [update_freq=2]
+ bash speechlm/scripts/tune_speechlm_st/ft_base_covost_enxx.sh $model_path $data_dir $lang 'tag400k'
+ ```
+- Fine-tune the Large model (fine-tuned models will be stored in `$mount/exp/finetune_covost`).
+ ```bash
+ # Usage (Large model): speechlm/scripts/tune_speechlm_st/ft_large_covost_enxx.sh [mount=$PWD] [world_size=8] [update_freq=4]
+ bash speechlm/scripts/tune_speechlm_st/ft_large_covost_enxx.sh $model_path $data_dir $lang 'tag400k'
+ ```
+
+### Decode
+- Decode the base model
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_st/inference_base.sh [gen-set=dev] [beam_size=5]
+ model_path=path/to/your/fine-tuned/model
+ lang=de # ca,ar,tr
+ data_dir=dataset/CommonVoice/v4/en/en-${lang}
+ bash speechlm/scripts/tune_speechlm_st/inference_base.sh $model_path $data_dir $lang dev
+ ```
+- Decode the large model
+ ```bash
+ # Usage: speechlm/scripts/tune_speechlm_st/inference_large.sh [gen-set=dev] [beam_size=5]
+ bash speechlm/scripts/tune_speechlm_st/inference_large.sh $model_path $data_dir $lang dev
+ ```
+
+## Universal Representation Evaluation on SUPERB
+
+Please refer to [**SUPERB**](https://superbbenchmark.org/) for the downstreaming tasks.
+
+## Pre-train
+Please follow the instructions in [Tokenizers](README.md#Tokenizers) to prepare the pre-training data. We provide examples in [`dataset`](dataset).
+- SpeechLM-P Base model
+
+ Models will be stored in `$mount/pretrain`.
+ ```bash
+ data_dir=dataset/LibriSpeech/phone_unit # should contain train_960.{tsv,phn}
+ text_data_dir=dataset/LibriLM/phone_unit/bin-idx # should contain train_text.phn-ltr.{phn,ltr}.{bin,idx}
+ # Usage: speechlm/scripts/pretrain_speechlm/base_speechlmp.sh [mount=$PWD] [world_size=32] [update_freq=1]
+ bash speechlm/scripts/pretrain_speechlm/base_speechlmp.sh $data_dir $text_data_dir
+ ```
+- SpeechLM-H Base model
+ ```bash
+ data_dir=dataset/LibriSpeech/hidden_unit # should contain train_960.{tsv,phn}
+ text_data_dir=dataset/LibriLM/km-ltr/bin-idx # should contain train_text.km-ltr.{km,ltr}.{bin,idx}
+ # Usage: speechlm/scripts/pretrain_speechlm/base_speechlmh.sh [mount=$PWD] [world_size=32] [update_freq=1]
+  bash speechlm/scripts/pretrain_speechlm/base_speechlmh.sh $data_dir $text_data_dir
+ ```
+- SpeechLM-P Large model
+ ```bash
+ data_dir=dataset/LibriSpeech/phone_unit # should contain train_960.{tsv,phn}
+ text_data_dir=dataset/LibriLM/phone_unit/bin-idx # should contain train_text.phn-ltr.{phn,ltr}.{bin,idx}
+  # Usage: speechlm/scripts/pretrain_speechlm/large_speechlmp.sh [mount=$PWD] [world_size=32] [update_freq=1]
+ bash speechlm/scripts/pretrain_speechlm/large_speechlmp.sh $data_dir $text_data_dir
+ ```
+
+
+## Tokenizers
+### Phoneme-unit Tokenizer for Speech
+This tokenizer produces frame-aligned phonemes for unlabeled speech; it is in fact a hybrid HMM ASR model.
+
+In the Base setting, we use 100 hours of labeled LibriSpeech data to train the HMM model with a Kaldi recipe, then decode the unpaired speech and extract the aligned phonemes from the lattices.
+We provide the processed phonemes of the 960h speech here: [`train_960.tsv`](https://drive.google.com/file/d/1rxlikMglL2kEsF4NfqekZRoA02klY7CE/view?usp=sharing), [`train_960.phn`](), [`dev_clean.tsv`](https://drive.google.com/file/d/1NuVwe687jLBFkDLRy1EV2A2uXyV_kBo2/view?usp=sharing), [`dev_clean.phn`](https://drive.google.com/file/d/1cq_gbS-UgCALOoaE5QmhWrhkTdXuc_Uc/view?usp=sharing). Note that the label rate is 100 Hz (10 ms).
+
+> The phoneme inventory is 300+ word-position-dependent phones including silence phones.
+
+### Phoneme-unit Tokenizer for Text
+This tokenizer phonemizes the unpaired text data into (phonemes, letters) paired data, following a `words -> phonemes -> upsampled phones` pipeline (a toy sketch of this pipeline follows the script below).
+
+The following script downloads the LibriSpeech LM corpus and produces the required data: `train_text.phn-ltr.phn.{idx,bin}` and `train_text.phn-ltr.ltr.{idx,bin}`.
+> Before running it, make sure you have our provided [`dict.phn.txt`](dataset/LibriLM/phone_unit/bin-idx/dict.phn.txt) and [`dict.ltr.txt`](dataset/LibriLM/phone_unit/bin-idx/dict.ltr.txt) in the output dir `dataset/LibriLM/phone_unit/bin-idx/`.
+
+> The phoneme inventory is 300+ word-position-dependent phones including silence phones.
+
+```bash
+# data will be in dataset/LibriLM/phone_unit/
+bash speechlm/data_process/prepare_phn2ltr_librilm.sh
+```
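+
+To make the `words -> phonemes -> upsampled phones` idea concrete, here is a toy, self-contained sketch (the lexicon entries and the fixed repetition factor are made up for illustration only; the real pipeline uses the lexicon and the script above):
+
+```python
+# toy illustration of words -> phonemes -> upsampled (repeated) phones
+toy_lexicon = {"hello": ["HH", "AH", "L", "OW"], "world": ["W", "ER", "L", "D"]}  # hypothetical entries
+
+def phonemize(sentence, lexicon, upsample=2):
+    phones = []
+    for word in sentence.lower().split():
+        phones.extend(lexicon.get(word, ["<UNK>"]))
+    # crude fixed-rate upsampling so the phone sequence roughly matches frame-level label lengths
+    return [p for p in phones for _ in range(upsample)]
+
+print(phonemize("hello world", toy_lexicon))
+# ['HH', 'HH', 'AH', 'AH', 'L', 'L', 'OW', 'OW', 'W', 'W', 'ER', 'ER', 'L', 'L', 'D', 'D']
+```
+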
+### Hidden-unit Tokenizer for Speech
+Please follow the steps of data preparation for HuBERT [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#data-preparation) to prepare 1) wav recordings [`train.tsv`](dataset/LibriSpeech/hidden_unit/train_sample100.tsv) and 2) corresponding hidden-units [`train.km`](dataset/LibriSpeech/hidden_unit/train_sample100.km), and 3) unit vocabulary [`dict.km.txt`](dataset/LibriSpeech/hidden_unit/dict.km.txt).
+
+### Hidden-unit Tokenizer for Text
+This tokenizer is used to produce the speech-style hidden units from unpaired text.
+We train a [FastSpeech](https://arxiv.org/abs/2006.04558)-like model (instead of generating a continuous spectrogram as in the original paper, here we generate discrete units) on a small amount of ASR data ([100 hrs LibriSpeech](http://www.openslr.org/12)) as the tokenizer.
+
+Train:
+1. Convert ASR transcripts to phoneme sequences with duration information.
+2. Extract hidden-units from speech, using the [Hidden-unit Tokenizer for Speech](#hidden-unit-tokenizer-for-speech).
+3. Train the [model](speechlm/models/fasttext2unit.py) on the paired data:
+ ```bash
+ data_dir=dataset/LibriSpeech/fast_phone2unit
+ bash speechlm/scripts/tokenizer_fastT2U/train_s_5e-4.sh $data_dir
+ ```
+> The phoneme inventory is 41 mono phones including silence phones.
+
+Inference:
+
+4. Convert the text data to phoneme sequences using the [`lexicon`](https://drive.google.com/file/d/1dh9NEx_cCF9_Aa0UcKyl9j00GXs6LmLQ/view?usp=sharing).
+5. [Generate](speechlm/scripts/tokenizer_fastT2U/generate.sh) hidden units for a large text corpus:
+ ```bash
+ gen_set=dataset/LibriSpeech/fast_phone2unit/genset_examples
+ bash speechlm/scripts/tokenizer_fastT2U/generate.sh $model_path $gen_set
+ ```
+We provide training/generation data examples in [`dataset/LibriSpeech/fast_phone2unit`](dataset/LibriSpeech/fast_phone2unit), and the model checkpoint [here](https://drive.google.com/file/d/1e-aYf8hPXuly8DEvNg5SISOlcUxsgED0/view?usp=sharing).
+
+## License
+
+This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
+Portions of the source code are based on [FAIRSEQ](https://github.com/pytorch/fairseq).
+
+[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
+
+## Reference
+
+If you find our work useful in your research, please cite the following paper:
+
+```bibtex
+@article{zhang2022speechlm,
+ title = {SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data},
+ author = {Zhang, Ziqiang and Chen, Sanyuan and Zhou, Long and Wu, Yu and Ren, Shuo and Liu, Shujie and Yao, Zhuoyuan and Gong, Xun and Dai, Lirong and Li, Jinyu and Wei, Furu},
+ eprint={2209.15329},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ year={2022}
+}
+```
+
+### Contact Information
+
+For help or issues using SpeechLM models, please submit a GitHub issue.
+
+For other communications related to SpeechLM, please contact Long Zhou (`lozhou@microsoft.com`).
+
diff --git a/SpeechT5/SpeechLM/SpeechLM.py b/SpeechT5/SpeechLM/SpeechLM.py
new file mode 100644
index 0000000000000000000000000000000000000000..b242dde083e272f96e80791f13803c44b438991d
--- /dev/null
+++ b/SpeechT5/SpeechLM/SpeechLM.py
@@ -0,0 +1,667 @@
+# ----------------------------------------------------------------------------
+# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
+# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
+# Code based on fairseq: https://github.com/facebookresearch/fairseq
+#
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# ----------------------------------------------------------------------------
+
+import copy
+import logging
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from modules import (
+ compute_mask_indices,
+ LayerNorm,
+ ConvFeatureExtractionModel,
+ GradMultiply,
+ TransformerEncoder,
+    TransformerEncoderBase,
+)
+
+# from fairseq.models.transformer import TransformerConfig
+
+logger = logging.getLogger(__name__)
+
+class DictConfig:
+ def __init__(self, cfg=None):
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ self.__dict__.update(cfg)
+
+
+class TransformerConfig:
+ def __init__(self, cfg=None):
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ if 'encoder' in cfg:
+ self.encoder = DictConfig(cfg['encoder'])
+ del cfg['encoder']
+ if 'quant_noise' in cfg:
+ self.quant_noise = DictConfig(cfg['quant_noise'])
+ del cfg['quant_noise']
+ if 'decoder' in cfg:
+ del cfg['decoder']
+ self.__dict__.update(cfg)
+
+
+class SpeechLMConfig:
+ def __init__(self, cfg=None):
+ self.label_rate: int = 50
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
+        self.encoder_embed_dim: int = 768 # encoder embedding dimension
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
+ self.activation_fn: str = "gelu" # activation function to use
+ self.layer_type: str = "transformer" # layer type in encoder
+
+ # dropouts
+ self.dropout: float = 0.1 # dropout probability for the transformer
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
+        self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
+
+ self.final_dim: int = 256 # project final representations and targets to this many dimensions
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
+ self.conv_bias: bool = False # include bias in conv encoder
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
+
+ # masking
+ self.mask_length: int = 10 # mask length
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
+ self.mask_selection: str = "static" # how to choose mask length
+        self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
+
+
+ # channel masking
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
+
+ # positional embeddings
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
+
+ # loss computation
+ self.skip_masked: bool = False # skip computing losses over masked frames
+ self.skip_nomask: bool = False # skip computing losses over unmasked frames
+ self.checkpoint_activations: bool = False # recompute activations and save memory for extra compute
+
+ # FP16 optimization
+ self.required_seq_len_multiple: int = 2 # pad the input to encoder such that the sequence length is divisible by multiple
+
+ # Custom
+ self.use_rel_pos_enc: bool = False # whether to use relative positional encoding
+        self.scaling_for_att: float = 1.0 # scaling for attention weights to prevent overflow issues (for large models)
+
+ # unit encoder-decoder
+ self.add_unit_encoder: bool = False # add unit encoder
+
+ # embedding mixing
+ self.mix_with_unit: bool = True # mix with the unit embeddings
+ self.use_pred_unit: bool = False # use the embeddings of predicted units
+ self.l2_embedding: bool = False # compute l2 loss between unit embedding and unit hidden state
+
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ model_cfg = copy.deepcopy(cfg)
+ self.text_transformer = TransformerConfig(model_cfg['text_transformer'])
+ del model_cfg['text_transformer']
+ self.__dict__.update(model_cfg)
+
+class SpeechLM(nn.Module):
+ def __init__(
+ self,
+ cfg: SpeechLMConfig,
+ ) -> None:
+ super().__init__()
+ self.cfg = cfg
+
+ feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+ sample_rate = 16000
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / sample_rate
+
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+ self.logit_temp = cfg.logit_temp
+ self.skip_masked = cfg.skip_masked
+ self.skip_nomask = cfg.skip_nomask
+
+ self.final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+ self.final_proj_list = nn.ModuleList([
+ nn.Linear(cfg.encoder_embed_dim, self.final_dim) for _ in range(2)
+ ])
+
+ self.mask_emb = nn.Parameter(
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+ )
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ ### build unit encoder:
+ self.mask_u2t = cfg.mask_u2t
+ self.compute_mum = cfg.compute_mum
+ self.add_text_ctc = cfg.add_text_ctc
+ self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
+ self.padding_idx = 1
+
+ self.add_unit_encoder = cfg.add_unit_encoder
+ self.mix_with_unit = cfg.mix_with_unit
+ self.use_pred_unit = cfg.use_pred_unit
+ self.l2_embedding = cfg.l2_embedding
+ if self.add_unit_encoder:
+ self.unit_embed_tokens = None
+ ### build unit encoder
+ self.unit_encoder = TransformerEncoderBase(
+ cfg.text_transformer,
+ dictionary=None,
+ embed_tokens=self.unit_embed_tokens,
+ use_rel_pos_enc=cfg.use_rel_pos_enc,
+ scaling_for_att=cfg.scaling_for_att,
+ )
+
+ ### build unit2text decoder, not available for now
+ self.add_decoder = cfg.add_decoder
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        """Upgrade a (possibly old) state dict for new versions."""
+        # nn.Module does not define upgrade_state_dict_named, so there is no
+        # parent implementation to call; return the state dict unchanged.
+        return state_dict
+
+ def apply_mask(self, x, padding_mask, target_list):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x[mask_indices] = self.mask_emb
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def forward_features(self, source: torch.Tensor) -> torch.Tensor:
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+ return features
+
+ def forward_targets(
+ self,
+ features: torch.Tensor,
+ target_list: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Trim features to ensure labels exist and then get aligned labels
+ feat_tsz = features.size(2)
+ targ_tsz = min([t.size(1) for t in target_list])
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+ features = features[..., :feat_tsz]
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
+ target_inds += np.random.choice(int(self.feat2tar_ratio))
+ target_list = [t[:, target_inds.long()] for t in target_list]
+ return features, target_list
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def get_normalized_probs(
+ self,
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+ log_probs: bool,
+ sample: Optional[Dict[str, Tensor]] = None,
+ ):
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+ lprobs.batch_first = True
+ return lprobs
+
+ def downsample_ctc_padding_mask(self, padding_mask):
+ """
+ padding_mask: (B, T)
+ """
+ stride = self.text_ctc_conv_kernel // 2
+ return padding_mask[:, ::stride]
+
+ def compute_pred(self, proj_x, label_embs):
+ if self.target_glu:
+ label_embs = self.target_glu(label_embs)
+ x = F.normalize(proj_x.float(), dim=-1) # (S, D)
+ label_embs = F.normalize(label_embs.float(), dim=-1) # (C, D)
+ logits = torch.matmul(x, label_embs.T).type_as(proj_x) # (S, C)
+ logits /= self.logit_temp
+ return logits
+
+ def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = proj(x[masked_indices])
+ logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
+ else:
+ logit_m_list = [None]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = proj(x[nomask_indices])
+ logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
+ else:
+ logit_u_list = [None]
+
+ return logit_m_list, logit_u_list
+
+ def convert_embeddings(self,
+ x,
+ padding_mask,
+ target=None,
+ mask_indices=None,
+ mix_with_unit=False,
+ use_pred_unit=False,
+ l2_embedding=False,
+ remask=False
+ ):
+ """
+ 1. Mix with units if needed (default: True)
+ 2. Prepare for unit_encoder inputs
+ Inputs:
+ x, (B, T, D)
+ Return:
+ src_tokens, (B, T)
+ soft_embeddings, (B, T, D)
+ l2_loss, a loss
+ """
+ soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
+ if padding_mask is None:
+ padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
+ if use_pred_unit:
+ src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
+ src_tokens[padding_mask] = self.padding_idx
+ elif target is not None:
+ src_tokens = target
+ else:
+ src_tokens = padding_mask.long()
+
+ if l2_embedding | mix_with_unit:
+ unit_embeddings = self.unit_embed_tokens(src_tokens) # (B, T, D)
+
+ l2_loss = 0
+ if l2_embedding:
+ if mask_indices is not None:
+ l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
+ scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
+ else:
+ l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
+ scale = unit_embeddings.float().pow(2).sum(dim=-1)
+ l2_loss = (l2_loss / scale).mean()
+
+ if mix_with_unit:
+ B, T, D = x.shape
+ selected_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob / 2,
+ self.mask_length // 2,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ selected_indices = torch.from_numpy(selected_indices).to(x.device)
+ if mask_indices is not None:
+ if remask:
+ remask_indices = torch.logical_and(selected_indices, mask_indices)
+ soft_embeddings[remask_indices] = self.mask_emb
+ swap_indices = torch.logical_and(selected_indices, ~mask_indices)
+ else:
+ swap_indices = selected_indices
+ soft_embeddings[swap_indices] = unit_embeddings[swap_indices]
+
+ soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
+ return src_tokens, soft_embeddings, l2_loss
+
+ def forward(
+ self,
+ source: torch.Tensor = None,
+ src_tokens: torch.Tensor = None,
+ src_lengths: torch.Tensor = None,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+ assert source is not None or src_tokens is not None
+ if source is not None:
+ return self.forward_speech(
+ source=source,
+ target_list=target_list,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=features_only,
+ output_layer=output_layer,
+ )
+ else:
+ return self.forward_text(
+ src_tokens=src_tokens,
+ src_lengths=src_lengths,
+ mask=self.mask_u2t,
+ output_layer=output_layer,
+ )
+
+ def forward_speech(
+ self,
+ source: torch.Tensor = None,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+ """output layer is 1-based"""
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+
+ features_pen = features.float().pow(2).mean()
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ unmasked_features = features.clone()
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x = features
+ mask_indices = None
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1,
+ )
+
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
+
+ logit_m_list, logit_u_list = self.compute_hubert_logits(
+ x,
+ target_list[0],
+ self.final_proj_list[0],
+ self.label_embs_list[0],
+ padding_mask,
+ mask_indices,
+ )
+
+ result = {
+ "logit_m_list": logit_m_list,
+ "logit_u_list": logit_u_list,
+ "padding_mask": padding_mask,
+ "features_pen": features_pen,
+ }
+
+ if self.add_unit_encoder:
+ src_tokens, x_emb, l2_loss = self.convert_embeddings(
+ x,
+ padding_mask, target_list[0],
+ mask_indices=mask_indices,
+ mix_with_unit=self.mix_with_unit,
+ use_pred_unit=self.use_pred_unit,
+ l2_embedding=self.l2_embedding,
+ )
+ encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)
+
+ result['encoder_out'] = encoder_out['encoder_out'] # [(T, B, D)]
+ result['encoder_padding_mask'] = encoder_out['encoder_padding_mask'] # [(B, T)]
+ if self.l2_embedding:
+ result['embedding_l2_loss'] = l2_loss
+
+ code_logit_m_list, code_logit_u_list = self.compute_hubert_logits(
+ encoder_out['encoder_out'][0].transpose(0, 1),
+ target_list[-1],
+ self.final_proj_list[-1],
+ self.label_embs_list[-1],
+ padding_mask,
+ mask_indices,
+ )
+ result['logit_m_list'] += code_logit_m_list
+ result['logit_u_list'] += code_logit_u_list
+ return result
+
+ def forward_text(
+ self,
+ src_tokens: torch.Tensor = None,
+ src_lengths: torch.Tensor = None,
+ target_list: Optional[List[torch.Tensor]] = None,
+ mask: bool = True,
+ output_layer: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+        assert self.add_unit_encoder, "Cannot forward the unit-text branch without unit_encoder!"
+
+ padding_mask = src_tokens == self.padding_idx
+ unit_embeddings = self.unit_embed_tokens(src_tokens)
+ if mask:
+ unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])
+ else:
+            ### If the mask was already applied to src_tokens, the target_list should contain many padding_idx entries
+ mask_indices = target_list[-1] != self.padding_idx
+ unit_embeddings[mask_indices] = self.mask_emb
+
+ encoder_out = self.unit_encoder(
+ src_tokens,
+ token_embeddings=unit_embeddings,
+ return_all_hiddens=output_layer is not None,
+ )
+
+ result = {}
+ result["encoder_out"] = encoder_out["encoder_out"]
+ result["encoder_states"] = encoder_out["encoder_states"]
+ result["padding_mask"] = padding_mask
+
+ if self.compute_mum:
+ code_logit_m_list, code_logit_u_list = self.compute_hubert_logits(
+ encoder_out["encoder_out"].transpose(0, 1),
+ target_list[-1],
+ self.final_proj_list[-1],
+ self.label_embs_list[-1],
+ padding_mask,
+ mask_indices,
+ )
+ result["logit_m_list"] = code_logit_m_list
+ result["logit_u_list"] = code_logit_u_list
+
+ if self.add_text_ctc:
+ result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
+ result["encoder_padding_mask"] = [
+ self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
+ ]
+ return result
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ ret_layer_results: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Extract features for only speech input"""
+ with torch.no_grad():
+ res = self.forward(
+ source,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ output_layer=output_layer,
+ )
+ # {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
+
+ x = res["x"] # B x T x D
+ padding_mask = res["padding_mask"]
+ if self.add_unit_encoder and (output_layer is None or output_layer > self.cfg.encoder_layers):
+ src_tokens, x, _ = self.convert_embeddings(
+ x,
+ padding_mask,
+ mix_with_unit=False,
+ use_pred_unit=False,
+ )
+            return_all_hiddens = output_layer is not None and output_layer > self.cfg.encoder_layers
+ encoder_out = self.unit_encoder(
+ src_tokens,
+ token_embeddings=x,
+ return_all_hiddens=return_all_hiddens,
+ )
+ res["x"] = encoder_out['encoder_out'][0].transpose(0, 1) # (B, T, D)
+ if return_all_hiddens:
+ res["layer_results"] += encoder_out['encoder_states'][1:1+output_layer-len(res["layer_results"])]
+
+ feature = res["features"] if ret_conv else res["x"]
+ if ret_layer_results:
+ feature = (feature, res["layer_results"])
+
+ return feature, padding_mask
+
+ def get_logits(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ logits_list = [x[0].float() for x in logits_list if x is not None]
+ return logits_list
+
+ def get_targets(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ targets_list = [x[1].long() for x in logits_list if x is not None]
+ return targets_list
+
+ def get_extra_losses(self, net_output):
+ extra_losses = []
+ names = []
+
+ if "features_pen" in net_output:
+ extra_losses.append(net_output["features_pen"])
+ names.append("features_pen")
+
+ if "embedding_l2_loss" in net_output:
+ extra_losses.append(net_output["embedding_l2_loss"])
+ names.append("embedding_l2_loss")
+
+ return extra_losses, names
+
+ def remove_pretraining_modules(self, step2=False):
+ self.target_glu = None
+
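+# ----------------------------------------------------------------------------
+# Usage sketch (added for illustration; not part of the original training code).
+# It assumes a SpeechLM checkpoint whose saved model config is compatible with
+# SpeechLMConfig above; the "cfg"/"model" key names below are assumptions and
+# may differ depending on how the checkpoint was written.
+#
+#   import torch
+#   ckpt = torch.load("path/to/speechlm_checkpoint.pt", map_location="cpu")
+#   cfg = SpeechLMConfig(ckpt["cfg"]["model"])
+#   model = SpeechLM(cfg)
+#   model.remove_pretraining_modules()
+#   model.load_state_dict(ckpt["model"], strict=False)
+#   model.eval()
+#
+#   wav = torch.randn(1, 16000)  # one second of 16 kHz audio, batch size 1
+#   with torch.no_grad():
+#       feats, padding_mask = model.extract_features(wav)
+# ----------------------------------------------------------------------------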
diff --git a/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/config_base_ende.yaml b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/config_base_ende.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..50733b2740c6f02f3adfc1d536a3a4005ffa7d6a
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/config_base_ende.yaml
@@ -0,0 +1,14 @@
+bpe_tokenizer:
+ bpe: sentencepiece
+ sentencepiece_model: spm_char_st_en_de.model
+
+shuffle: false
+use_audio_input: true
+use_sample_rate: 16000
+standardize_audio: false
+vocab_filename: spm_char_st_en_de.txt
+
+# required by speech_to_text task but never used
+input_channels: 1
+input_feat_per_channel: 1
+
diff --git a/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/config_large_ende.yaml b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/config_large_ende.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d3424a3c55f0e48e8197d98cd3e724baa08c834f
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/config_large_ende.yaml
@@ -0,0 +1,14 @@
+bpe_tokenizer:
+ bpe: sentencepiece
+ sentencepiece_model: spm_char_st_en_de.model
+
+shuffle: false
+use_audio_input: true
+use_sample_rate: 16000
+standardize_audio: true
+vocab_filename: spm_char_st_en_de.txt
+
+# required by speech_to_text task but never used
+input_channels: 1
+input_feat_per_channel: 1
+
diff --git a/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/dev-sample100_st_en_de_local.tsv b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/dev-sample100_st_en_de_local.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..c4251fa8a24f33e2ebd44ad90899c0778e24aaf8
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/dev-sample100_st_en_de_local.tsv
@@ -0,0 +1,100 @@
+id audio n_frames tgt_text
+common_voice_en_18540003 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18540003.wav 90624 Wenn Wasser knapp ist, verschwenden Sie es nicht.
+common_voice_en_18540005 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18540005.wav 57984 Du fährst mit ihr bis zu ihrer Tür.
+common_voice_en_18540006 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18540006.wav 63744 Celia schreckte zurück und zitterte.
+common_voice_en_65557 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_65557.wav 40704 Haben Sie einen Ring?
+common_voice_en_65559 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_65559.wav 44160 Ich habe ihn nicht einmal gefragt.
+common_voice_en_19594267 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19594267.wav 110208 Der größte See nach Fläche in der Mongolei, der Uvs-See, ist in der Great Lakes Depression.
+common_voice_en_19594268 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19594268.wav 91392 Die darauffolgende Wiedervereinigung mit Rom hat bis heute ununterbrochen angedauert.
+common_voice_en_19594269 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19594269.wav 64128 Die Saiten könnten aus Messing oder Stahl sein.
+common_voice_en_18282099 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18282099.wav 67584 Andrew rollte sich in der Box zusammen.
+common_voice_en_2518264 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_2518264.wav 61824 Säure ätzt Locher in Wollstoff.
+common_voice_en_18909686 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18909686.wav 147072 Dies wurde später von Herny Seebohm beschrieben und Riesen-Fischuhu genannt.
+common_voice_en_18909688 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18909688.wav 114048 Er ist auch dazu in der Lage, über kurze Distanzen zu schweben.
+common_voice_en_18909689 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18909689.wav 85248 So konnte Letta seine große Koalition fortsetzen.
+common_voice_en_18460666 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18460666.wav 56064 Es nicht gekostet wegschieben?
+common_voice_en_18460690 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18460690.wav 68736 Ich bin verzweifelt, und damit hatte es sich.
+common_voice_en_18460692 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18460692.wav 54912 Ich folge dir nicht, Jeeves.
+common_voice_en_485640 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_485640.wav 70272 Ordentliche Pläne scheitern ohne Glück.
+common_voice_en_89833 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_89833.wav 128256 Das ist ein super Armband, das du trägst.
+common_voice_en_19001715 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19001715.wav 148224 Der Buddhismus in Afghanistan wurde von den Saffariden, Ghaznawiden und Ghuriden erfolgreich beseitigt.
+common_voice_en_19001716 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19001716.wav 80256 Das System sieht einen frei schwebenden Lauf vor.
+common_voice_en_19001719 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19001719.wav 109056 Diese bekannten Murderabilia-Händler finden Sie auf den folgenden Websites.
+common_voice_en_9774 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_9774.wav 119040 Sie liest unheimlich gern, weiß jedoch nicht so genau, wie das Lesen zu einer Steigerung der Kreativität beitragen kann.
+common_voice_en_26370 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_26370.wav 62208 Danke, dass Sie uns an Ihrer Geschichte haben teilhaben lassen. Alles Gute für die Hochzeitsreise.
+common_voice_en_26372 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_26372.wav 59904 Sie kennen die Uhrzeit doch. Warum fragen Sie mich danach?
+common_voice_en_17260994 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_17260994.wav 155520 Der Fuchs sprang über den Rand der Farm. Dort fand er einen Safari-Reisenden vor, der eine Vivaldi Opera zum Besten gab.
+common_voice_en_18881599 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18881599.wav 108672 "Express" sollte das Gebiet untersuchen und fand dort nichts.
+common_voice_en_18881604 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18881604.wav 92544 Dadurch werden die Probleme gemildert, die durch einen Mangel an Hämoglobin verursacht werden.
+common_voice_en_18881605 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18881605.wav 109056 Diese Behauptungen werden von der Mainstream-Archäologie kategorisch zurückgewiesen.
+common_voice_en_180278 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_180278.wav 48768 Sie sollte eigentlich herunterkommen und Sie abholen.
+common_voice_en_180279 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_180279.wav 54912 Ich werde dort nicht als Geist leben.
+common_voice_en_696251 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_696251.wav 98304 Der Junge hat bemerkt, dass der Engländer nervös war und seine Bücher vergessen hat.
+common_voice_en_19049974 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19049974.wav 73344 Durch eine Augenverletzung fand seine Karriere ein vorzeitiges Ende.
+common_voice_en_19049975 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19049975.wav 126336 Supermatrixes ähnlicher Größe können genauso wie normale Matrixes hinzugefügt und vervielfacht werden.
+common_voice_en_19049976 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19049976.wav 94464 Es liegt annäherungsweise südlich von Denali, der höchsten Erhebung in Nordamerika.
+common_voice_en_19765134 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19765134.wav 110208 Kleinstädte in Vietnam unterstehen der regionalen Regierung.
+common_voice_en_19765136 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19765136.wav 61440 Fünf Jahre später nahm er ihn nach Dresden mit.
+common_voice_en_19765138 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19765138.wav 130176 Der Croma ist standardmäßig mit Anti-Blockier-System (ABS) und Elektronischer Bremskraftverteilung (EBD) ausgestattet.
+common_voice_en_19688061 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19688061.wav 104448 Carter hat zwei Kinder, die Tochter Taleya und den Sohn Tamere.
+common_voice_en_19688062 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19688062.wav 148992 Wenn der Gehalt an gelöstem Sauerstoff zu hypoxischen Bedingungen übergeht, ersticken Fische und andere Meerestiere.
+common_voice_en_19688064 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19688064.wav 48768 Adams hatte ein Leben mit vielen Tiefen.
+common_voice_en_19690060 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19690060.wav 136320 Er hat die Dudley Middle Comprehensive School besucht.
+common_voice_en_19690063 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19690063.wav 116352 Der ursprüngliche Name der Schule lautet "School of Commerce and Domestic Science" (Handels- und Hauswirtschaftsschule).
+common_voice_en_19690064 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19690064.wav 124032 Bei dem Unfall, bei dem er am Steuer saß, befand sich auch Anna, seine Tochter, im Auto. Sie hat den Unfall überlebt.
+common_voice_en_18260377 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18260377.wav 98304 Jeder möchte gemocht werden. Das liegt in der Natur des Menschen.
+common_voice_en_18260378 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18260378.wav 85248 Jeder sollte Zugang zu medizinischer Grundversorgung haben.
+common_voice_en_18260379 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18260379.wav 77184 Während wir älter werden, sind wir in unserem Leben gefangen.
+common_voice_en_100764 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_100764.wav 73344 Sie sollten das in einem Wahrnehmungsexperiment untersuchen.
+common_voice_en_100765 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_100765.wav 70656 Sie haben mich vom ersten Moment an abgelehnt.
+common_voice_en_626029 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_626029.wav 104448 Hanf ist ein Gras, das in Teilen der Tropen vorgefunden wird.
+common_voice_en_19703984 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19703984.wav 142848 Sowohl Federation als auch Neo-Zeon Forces sehen dabei zu als die Axis beim Wiedereintritt von der Bahn abkommen.
+common_voice_en_19703985 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19703985.wav 165120 Das Mutterhaus in Loretto befindet sich in Nerinx, Marion County, Kentucky.
+common_voice_en_19703987 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19703987.wav 114048 Der Umfang von Matildas militärischer Ausbildung wird diskutiert.
+common_voice_en_19676540 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19676540.wav 105984 Zu Lernzwecken wurden Stifte nach und nach durch Schreibtafeln ersetzt.
+common_voice_en_19676541 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19676541.wav 131712 Die extrem hügelige Landschaft zeichnet sich durch eine Art Erhabenheit aus und bietet einen atemberaubenden Ausblick.
+common_voice_en_19676542 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19676542.wav 93696 Die beiden Tierbilder wurden zu einem Bild kombiniert.
+common_voice_en_19678470 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19678470.wav 145920 Sie und Gatti-Casazza haben sich im darauffolgenden Jahr getrennt und sich dann scheiden lassen.
+common_voice_en_19678471 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19678471.wav 74112 Es zeigt allerdings niemand Interesse.
+common_voice_en_19678476 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19678476.wav 98688 Er hat keine sinnvollen Aussagen gemacht. Es war nur Kauderwelsch.
+common_voice_en_17730605 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_17730605.wav 57984 Wer im Glashaus sitzt, sollte nicht mit Steinen werfen.
+common_voice_en_19768089 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19768089.wav 66432 Der Rahmen kippt den Motor leicht nach hinten.
+common_voice_en_19768197 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19768197.wav 58752 Bevor er hauptberuflich Politiker wurde, war er Landwirt.
+common_voice_en_19768200 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19768200.wav 73344 Er hat auch als Karikaturist und Comiczeichner gearbeitet.
+common_voice_en_19699188 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19699188.wav 106368 Das Schiff war zwei von vier Lebensjahren aufgelegt.
+common_voice_en_19699189 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19699189.wav 133632 Boucher hat sich von Künstlern wie Peter Pauls Rubens und Antoine Watteau inspirieren lassen.
+common_voice_en_19699190 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19699190.wav 108288 Zwei Tracks wurden als Auszüge auf einer Single herausgebracht.
+common_voice_en_512711 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_512711.wav 84096 Gruppe von Menschen, von sanftem Licht einer Öllaterne angestrahlt.
+common_voice_en_512712 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_512712.wav 103296 Frau mit hellem Haar und Mann mit einem Lächeln, die nebeneinander sitzen.
+common_voice_en_512713 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_512713.wav 98304 Ein Mann fährt die Straße entlang und passiert Blumenkübel. Er hält dabei ein zweites Fahrrad.
+common_voice_en_19678686 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19678686.wav 114816 Computertische werden normalerweise in der Massenproduktion gefertigt und müssen teilweise in Selbstmontage montiert werden.
+common_voice_en_19678689 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19678689.wav 97536 Aufgrund der geringen Auflage gilt es jetzt als Sammlerstück.
+common_voice_en_19678692 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19678692.wav 139776 Die Songs von Thrussel haben regelmäßig Themen zum Gegenstand, in denen er sich gegen Konsum und Überwachung durch den Staat ausspricht.
+common_voice_en_648128 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_648128.wav 77184 Ein Mann und ein Kind auf einem Campingplatz, die ein Frühstück zubereiten.
+common_voice_en_648129 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_648129.wav 72576 Militärangehörige bereiten sich auf ihren Dienst vor.
+common_voice_en_648130 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_648130.wav 109824 Ein Baseballspieler, der ein blaues T-Shirt trägt, läuft auf eine Base zu.
+common_voice_en_34182 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_34182.wav 82560 Ihr Büro hat mich angerufen, um ihn zurückzuhalten.
+common_voice_en_34184 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_34184.wav 67968 Dieser Satz macht überhaupt keinen Sinn.
+common_voice_en_92676 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_92676.wav 62976 Eine Gruppe von Leuten läuft durch eine Karnevalsgruppe.
+common_voice_en_92677 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_92677.wav 62976 Ein älteres Paar, das singt und Gitarre spielt.
+common_voice_en_92678 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_92678.wav 86016 Zwei Männer in roten Hosen vollführen akrobatische Kunststücke mit einer Leiter.
+common_voice_en_570502 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_570502.wav 82944 Künstliche neuronale Netzwerke können etwas ganz ähnliches ausführen.
+common_voice_en_141246 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_141246.wav 63744 Schalte die Laterne aus, die uns Licht spendet.
+common_voice_en_141247 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_141247.wav 62592 Brian reist heute ab.
+common_voice_en_19047441 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19047441.wav 84096 Die Bewohner haben im Namen des Dauphin eine Silbermine betrieben.
+common_voice_en_19047442 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19047442.wav 74496 Die Statue wurde durch den Millenium Lottery Fund teilfinanziert.
+common_voice_en_19047443 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19047443.wav 117504 Das Henderson House in Elmhurst, Illinois, USA; hat einen ähnlichen Grundriss.
+common_voice_en_567705 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_567705.wav 54144 Hängen Sie an beide Zweige Lametta.
+common_voice_en_17283658 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_17283658.wav 59520 Unter den Linden.
+common_voice_en_17283659 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_17283659.wav 119040 Das höchste Gebäude der Welt ist 829,8 m hoch.
+common_voice_en_18707930 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18707930.wav 107136 Die Stadt liegt in Harris County, in Südost Texas.
+common_voice_en_18707931 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18707931.wav 155904 Das steht im Gegensatz zum Potential des Pacemakers oder dem Strom, der die rhythmische Modulierung der Impulsfrequenz antreibt.
+common_voice_en_18707933 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18707933.wav 90624 Die Stadt wird durch eine Stadtverwaltung regiert.
+common_voice_en_18524588 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18524588.wav 88704 Genehmigen Sie den Ausdruck meiner vorzüglichen Hochachtung.
+common_voice_en_18524590 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18524590.wav 67584 Anhand der Laufzeit kann man ablesen, dass dieser Computer nie neu gestartet wurde.
+common_voice_en_18524592 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_18524592.wav 89856 Celia stand dort, war offenbar nicht betroffen und konnte den Vorkommnissen nicht folgen.
+common_voice_en_19254317 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19254317.wav 134784 Der Unterricht wird extra abends abgehalten, damit die Studenten von High Schools daran teilnehmen können.
+common_voice_en_19254318 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19254318.wav 119424 Dieses Fett ist Halacha und wird auch Chelev oder Talg genannt.
+common_voice_en_19254320 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_19254320.wav 97536 Die Patienten und das Krankenhauspersonal haben sie für den Preis nominiert.
+common_voice_en_542826 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_542826.wav 117120 Jeder Teilbereich des Bildschirms gehört zu einer bestimmten Reihe und Spalte.
+common_voice_en_542828 /LocalData/dataset/CommonVoice/v4/en/wav/common_voice_en_542828.wav 108672 Die internationale Raumstation ist ein großartiges Projekt.
diff --git a/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.model b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.model
new file mode 100644
index 0000000000000000000000000000000000000000..b9418a61f6cb120e16d0b64c67886203b5e95da2
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a0b188591cd0d1e9d713fe1f3a9cfbe23a72b6bf73346ba11a2a70ab1a3a025
+size 239480
diff --git a/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.txt b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1a1a2f6420331fb0efef9fe87631b10fa493dba7
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.txt
@@ -0,0 +1,164 @@
+▁ 1
+e 1
+n 1
+i 1
+r 1
+t 1
+s 1
+a 1
+d 1
+h 1
+u 1
+l 1
+o 1
+c 1
+g 1
+m 1
+. 1
+b 1
+f 1
+w 1
+k 1
+z 1
+S 1
+v 1
+p 1
+, 1
+D 1
+ü 1
+E 1
+ä 1
+A 1
+B 1
+M 1
+G 1
+" 1
+F 1
+K 1
+P 1
+W 1
+T 1
+y 1
+H 1
+ö 1
+I 1
+R 1
+L 1
+- 1
+C 1
+V 1
+N 1
+ß 1
+Z 1
+J 1
+U 1
+j 1
+O 1
+x 1
+? 1
+! 1
+' 1
+q 1
+Y 1
+Ü 1
+: 1
+Q 1
+Ä 1
+Ö 1
+; 1
+( 1
+) 1
+X 1
+0 1
+1 1
+[ 1
+] 1
+é 1
+2 1
+& 1
+3 1
+5 1
+4 1
+7 1
+9 1
+8 1
+6 1
+/ 1
+á 1
+ō 1
+ó 1
+ñ 1
+ú 1
+í 1
+ā 1
+è 1
+* 1
+ć 1
+à 1
+ê 1
+ë 1
+¡ 1
+ç 1
+ð 1
+ã 1
+č 1
+ū 1
+% 1
+É 1
+â 1
+ø 1
+š 1
+å 1
+ô 1
+ł 1
+œ 1
+ş 1
+Š 1
+_ 1
+Î 1
+Ó 1
+æ 1
+ï 1
+ă 1
+ě 1
+ī 1
+ı 1
+ʻ 1
+ʿ 1
+π 1
+и 1
+к 1
+= 1
+Ã 1
+Ø 1
+î 1
+û 1
+þ 1
+ċ 1
+Č 1
+ę 1
+ğ 1
+ń 1
+Ō 1
+ő 1
+ř 1
+ž 1
+ǎ 1
+α 1
+В 1
+е 1
+з 1
+й 1
+л 1
+н 1
+ь 1
+я 1
+ṃ 1
+ạ 1
+ụ 1
+→ 1
+≡ 1
+京 1
+大 1
+都 1
+阪 1
diff --git a/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.vocab b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.vocab
new file mode 100644
index 0000000000000000000000000000000000000000..dcaf02c4610abddbef943bb81b8df7807ca6d7ca
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/CommonVoice/v4/en/en-de/spm_char_st_en_de.vocab
@@ -0,0 +1,168 @@
+ 0
+ 0
+ 0
+ 0
+▁ -1.94346
+e -2.0247
+n -2.52771
+i -2.69095
+r -2.81179
+t -2.99429
+s -3.07457
+a -3.08727
+d -3.37853
+h -3.41543
+u -3.52845
+l -3.53925
+o -3.76429
+c -3.83672
+g -3.89086
+m -4.03425
+. -4.27171
+b -4.34078
+f -4.45167
+w -4.51255
+k -4.68054
+z -4.81542
+S -4.96966
+v -5.01738
+p -5.09819
+, -5.11371
+D -5.22687
+ü -5.34517
+E -5.43072
+ä -5.43483
+A -5.61389
+B -5.67037
+M -5.68285
+G -5.93387
+" -5.94796
+F -5.95252
+K -5.99114
+P -6.03568
+W -6.0592
+T -6.08128
+y -6.08834
+H -6.14664
+ö -6.17763
+I -6.18576
+R -6.22513
+L -6.30172
+- -6.34074
+C -6.41901
+V -6.44441
+N -6.48507
+ß -6.60475
+Z -6.78851
+J -6.81489
+U -7.04154
+j -7.07161
+O -7.13538
+x -7.50985
+? -7.66957
+! -8.34983
+' -8.62779
+q -8.7511
+Y -8.80869
+Ü -9.0344
+: -9.03696
+Q -9.11993
+Ä -9.61997
+Ö -9.9612
+; -10.0729
+( -10.0826
+) -10.0839
+X -10.6277
+0 -11.1096
+1 -11.1164
+[ -11.296
+] -11.296
+é -11.3293
+2 -11.4413
+& -12.1488
+3 -12.188
+5 -12.3864
+4 -12.4237
+7 -12.4891
+9 -12.6035
+8 -12.6343
+6 -12.666
+/ -12.9645
+á -13.1043
+ō -13.392
+ó -13.5351
+ñ -13.6151
+ú -13.9028
+í -14.1541
+ā -14.1541
+è -14.2282
+* -14.3953
+ć -14.7137
+à -14.8472
+ê -14.8472
+ë -14.8472
+¡ -15.0014
+ç -15.0014
+ð -15.0014
+ã -15.1837
+č -15.1837
+ū -15.1837
+% -15.4069
+É -15.4069
+â -15.4069
+ø -15.4069
+š -15.4069
+å -15.6945
+ô -15.6945
+ł -15.6945
+œ -15.6945
+ş -15.6945
+Š -15.6945
+_ -16.1
+Î -16.1
+Ó -16.1
+æ -16.1
+ï -16.1
+ă -16.1
+ě -16.1
+ī -16.1
+ı -16.1
+ʻ -16.1
+ʿ -16.1
+π -16.1
+и -16.1
+к -16.1
+= -16.7932
+Ã -16.7932
+Ø -16.7932
+î -16.7932
+û -16.7932
+þ -16.7932
+ċ -16.7932
+Č -16.7932
+ę -16.7932
+ğ -16.7932
+ń -16.7932
+Ō -16.7932
+ő -16.7932
+ř -16.7932
+ž -16.7932
+ǎ -16.7932
+α -16.7932
+В -16.7932
+е -16.7932
+з -16.7932
+й -16.7932
+л -16.7932
+н -16.7932
+ь -16.7932
+я -16.7932
+ṃ -16.7932
+ạ -16.7932
+ụ -16.7932
+→ -16.7932
+≡ -16.7932
+京 -16.7932
+大 -16.7932
+都 -16.7932
+阪 -16.7932
diff --git a/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/config.yaml b/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..97f25d9780d99813e322fbbf24c5b916525ede94
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/config.yaml
@@ -0,0 +1,3 @@
+vocab_filename: dict.ltr.txt
+src_vocab_filename: dict.km.txt
+
diff --git a/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/dict.km.txt b/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/dict.km.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bbfe59e554d6234f3631d8d09d9281c2160f4675
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/dict.km.txt
@@ -0,0 +1,500 @@
+0 0
+1 1
+2 2
+3 3
+4 4
+5 5
+6 6
+7 7
+8 8
+9 9
+10 10
+11 11
+12 12
+13 13
+14 14
+15 15
+16 16
+17 17
+18 18
+19 19
+20 20
+21 21
+22 22
+23 23
+24 24
+25 25
+26 26
+27 27
+28 28
+29 29
+30 30
+31 31
+32 32
+33 33
+34 34
+35 35
+36 36
+37 37
+38 38
+39 39
+40 40
+41 41
+42 42
+43 43
+44 44
+45 45
+46 46
+47 47
+48 48
+49 49
+50 50
+51 51
+52 52
+53 53
+54 54
+55 55
+56 56
+57 57
+58 58
+59 59
+60 60
+61 61
+62 62
+63 63
+64 64
+65 65
+66 66
+67 67
+68 68
+69 69
+70 70
+71 71
+72 72
+73 73
+74 74
+75 75
+76 76
+77 77
+78 78
+79 79
+80 80
+81 81
+82 82
+83 83
+84 84
+85 85
+86 86
+87 87
+88 88
+89 89
+90 90
+91 91
+92 92
+93 93
+94 94
+95 95
+96 96
+97 97
+98 98
+99 99
+100 100
+101 101
+102 102
+103 103
+104 104
+105 105
+106 106
+107 107
+108 108
+109 109
+110 110
+111 111
+112 112
+113 113
+114 114
+115 115
+116 116
+117 117
+118 118
+119 119
+120 120
+121 121
+122 122
+123 123
+124 124
+125 125
+126 126
+127 127
+128 128
+129 129
+130 130
+131 131
+132 132
+133 133
+134 134
+135 135
+136 136
+137 137
+138 138
+139 139
+140 140
+141 141
+142 142
+143 143
+144 144
+145 145
+146 146
+147 147
+148 148
+149 149
+150 150
+151 151
+152 152
+153 153
+154 154
+155 155
+156 156
+157 157
+158 158
+159 159
+160 160
+161 161
+162 162
+163 163
+164 164
+165 165
+166 166
+167 167
+168 168
+169 169
+170 170
+171 171
+172 172
+173 173
+174 174
+175 175
+176 176
+177 177
+178 178
+179 179
+180 180
+181 181
+182 182
+183 183
+184 184
+185 185
+186 186
+187 187
+188 188
+189 189
+190 190
+191 191
+192 192
+193 193
+194 194
+195 195
+196 196
+197 197
+198 198
+199 199
+200 200
+201 201
+202 202
+203 203
+204 204
+205 205
+206 206
+207 207
+208 208
+209 209
+210 210
+211 211
+212 212
+213 213
+214 214
+215 215
+216 216
+217 217
+218 218
+219 219
+220 220
+221 221
+222 222
+223 223
+224 224
+225 225
+226 226
+227 227
+228 228
+229 229
+230 230
+231 231
+232 232
+233 233
+234 234
+235 235
+236 236
+237 237
+238 238
+239 239
+240 240
+241 241
+242 242
+243 243
+244 244
+245 245
+246 246
+247 247
+248 248
+249 249
+250 250
+251 251
+252 252
+253 253
+254 254
+255 255
+256 256
+257 257
+258 258
+259 259
+260 260
+261 261
+262 262
+263 263
+264 264
+265 265
+266 266
+267 267
+268 268
+269 269
+270 270
+271 271
+272 272
+273 273
+274 274
+275 275
+276 276
+277 277
+278 278
+279 279
+280 280
+281 281
+282 282
+283 283
+284 284
+285 285
+286 286
+287 287
+288 288
+289 289
+290 290
+291 291
+292 292
+293 293
+294 294
+295 295
+296 296
+297 297
+298 298
+299 299
+300 300
+301 301
+302 302
+303 303
+304 304
+305 305
+306 306
+307 307
+308 308
+309 309
+310 310
+311 311
+312 312
+313 313
+314 314
+315 315
+316 316
+317 317
+318 318
+319 319
+320 320
+321 321
+322 322
+323 323
+324 324
+325 325
+326 326
+327 327
+328 328
+329 329
+330 330
+331 331
+332 332
+333 333
+334 334
+335 335
+336 336
+337 337
+338 338
+339 339
+340 340
+341 341
+342 342
+343 343
+344 344
+345 345
+346 346
+347 347
+348 348
+349 349
+350 350
+351 351
+352 352
+353 353
+354 354
+355 355
+356 356
+357 357
+358 358
+359 359
+360 360
+361 361
+362 362
+363 363
+364 364
+365 365
+366 366
+367 367
+368 368
+369 369
+370 370
+371 371
+372 372
+373 373
+374 374
+375 375
+376 376
+377 377
+378 378
+379 379
+380 380
+381 381
+382 382
+383 383
+384 384
+385 385
+386 386
+387 387
+388 388
+389 389
+390 390
+391 391
+392 392
+393 393
+394 394
+395 395
+396 396
+397 397
+398 398
+399 399
+400 400
+401 401
+402 402
+403 403
+404 404
+405 405
+406 406
+407 407
+408 408
+409 409
+410 410
+411 411
+412 412
+413 413
+414 414
+415 415
+416 416
+417 417
+418 418
+419 419
+420 420
+421 421
+422 422
+423 423
+424 424
+425 425
+426 426
+427 427
+428 428
+429 429
+430 430
+431 431
+432 432
+433 433
+434 434
+435 435
+436 436
+437 437
+438 438
+439 439
+440 440
+441 441
+442 442
+443 443
+444 444
+445 445
+446 446
+447 447
+448 448
+449 449
+450 450
+451 451
+452 452
+453 453
+454 454
+455 455
+456 456
+457 457
+458 458
+459 459
+460 460
+461 461
+462 462
+463 463
+464 464
+465 465
+466 466
+467 467
+468 468
+469 469
+470 470
+471 471
+472 472
+473 473
+474 474
+475 475
+476 476
+477 477
+478 478
+479 479
+480 480
+481 481
+482 482
+483 483
+484 484
+485 485
+486 486
+487 487
+488 488
+489 489
+490 490
+491 491
+492 492
+493 493
+494 494
+495 495
+496 496
+497 497
+498 498
+499 499
diff --git a/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/dict.ltr.txt b/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/dict.ltr.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26a7e6ba309998c3868db7ecab5d7afa52a68e52
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriLM/hidden_unit/bin-idx/dict.ltr.txt
@@ -0,0 +1,29 @@
+| 803288730
+E 439294199
+T 319071758
+A 277306732
+O 263784364
+N 239361162
+I 237353011
+H 223346762
+S 220175453
+R 203352500
+D 152198685
+L 141597450
+U 98913389
+M 87138757
+C 84680142
+W 81375101
+F 80240665
+G 70642902
+Y 68388038
+P 58436929
+B 52538531
+V 33250231
+K 26906609
+' 9162896
+X 5075632
+J 4746771
+Q 3401794
+Z 2186971
+ 1
diff --git a/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/config.yaml b/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6fd3d8c13f92f3ef5796e4c93adb4fe3161a38b
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/config.yaml
@@ -0,0 +1,3 @@
+vocab_filename: dict.ltr.txt
+src_vocab_filename: dict.phn.txt
+
diff --git a/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/dict.ltr.txt b/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/dict.ltr.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26a7e6ba309998c3868db7ecab5d7afa52a68e52
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/dict.ltr.txt
@@ -0,0 +1,29 @@
+| 803288730
+E 439294199
+T 319071758
+A 277306732
+O 263784364
+N 239361162
+I 237353011
+H 223346762
+S 220175453
+R 203352500
+D 152198685
+L 141597450
+U 98913389
+M 87138757
+C 84680142
+W 81375101
+F 80240665
+G 70642902
+Y 68388038
+P 58436929
+B 52538531
+V 33250231
+K 26906609
+' 9162896
+X 5075632
+J 4746771
+Q 3401794
+Z 2186971
+ 1
diff --git a/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/dict.phn.txt b/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/dict.phn.txt
new file mode 100644
index 0000000000000000000000000000000000000000..812e4b06e13b30fda420034927f6f877e2d54f56
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriLM/phone_unit/bin-idx/dict.phn.txt
@@ -0,0 +1,364 @@
+ 0
+SIL 1
+SIL_B 2
+SIL_E 3
+SIL_I 4
+SIL_S 5
+SPN 6
+SPN_B 7
+SPN_E 8
+SPN_I 9
+SPN_S 10
+AA_B 11
+AA_E 12
+AA_I 13
+AA_S 14
+AA0_B 15
+AA0_E 16
+AA0_I 17
+AA0_S 18
+AA1_B 19
+AA1_E 20
+AA1_I 21
+AA1_S 22
+AA2_B 23
+AA2_E 24
+AA2_I 25
+AA2_S 26
+AE_B 27
+AE_E 28
+AE_I 29
+AE_S 30
+AE0_B 31
+AE0_E 32
+AE0_I 33
+AE0_S 34
+AE1_B 35
+AE1_E 36
+AE1_I 37
+AE1_S 38
+AE2_B 39
+AE2_E 40
+AE2_I 41
+AE2_S 42
+AH_B 43
+AH_E 44
+AH_I 45
+AH_S 46
+AH0_B 47
+AH0_E 48
+AH0_I 49
+AH0_S 50
+AH1_B 51
+AH1_E 52
+AH1_I 53
+AH1_S 54
+AH2_B 55
+AH2_E 56
+AH2_I 57
+AH2_S 58
+AO_B 59
+AO_E 60
+AO_I 61
+AO_S 62
+AO0_B 63
+AO0_E 64
+AO0_I 65
+AO0_S 66
+AO1_B 67
+AO1_E 68
+AO1_I 69
+AO1_S 70
+AO2_B 71
+AO2_E 72
+AO2_I 73
+AO2_S 74
+AW_B 75
+AW_E 76
+AW_I 77
+AW_S 78
+AW0_B 79
+AW0_E 80
+AW0_I 81
+AW0_S 82
+AW1_B 83
+AW1_E 84
+AW1_I 85
+AW1_S 86
+AW2_B 87
+AW2_E 88
+AW2_I 89
+AW2_S 90
+AY_B 91
+AY_E 92
+AY_I 93
+AY_S 94
+AY0_B 95
+AY0_E 96
+AY0_I 97
+AY0_S 98
+AY1_B 99
+AY1_E 100
+AY1_I 101
+AY1_S 102
+AY2_B 103
+AY2_E 104
+AY2_I 105
+AY2_S 106
+B_B 107
+B_E 108
+B_I 109
+B_S 110
+CH_B 111
+CH_E 112
+CH_I 113
+CH_S 114
+D_B 115
+D_E 116
+D_I 117
+D_S 118
+DH_B 119
+DH_E 120
+DH_I 121
+DH_S 122
+EH_B 123
+EH_E 124
+EH_I 125
+EH_S 126
+EH0_B 127
+EH0_E 128
+EH0_I 129
+EH0_S 130
+EH1_B 131
+EH1_E 132
+EH1_I 133
+EH1_S 134
+EH2_B 135
+EH2_E 136
+EH2_I 137
+EH2_S 138
+ER_B 139
+ER_E 140
+ER_I 141
+ER_S 142
+ER0_B 143
+ER0_E 144
+ER0_I 145
+ER0_S 146
+ER1_B 147
+ER1_E 148
+ER1_I 149
+ER1_S 150
+ER2_B 151
+ER2_E 152
+ER2_I 153
+ER2_S 154
+EY_B 155
+EY_E 156
+EY_I 157
+EY_S 158
+EY0_B 159
+EY0_E 160
+EY0_I 161
+EY0_S 162
+EY1_B 163
+EY1_E 164
+EY1_I 165
+EY1_S 166
+EY2_B 167
+EY2_E 168
+EY2_I 169
+EY2_S 170
+F_B 171
+F_E 172
+F_I 173
+F_S 174
+G_B 175
+G_E 176
+G_I 177
+G_S 178
+HH_B 179
+HH_E 180
+HH_I 181
+HH_S 182
+IH_B 183
+IH_E 184
+IH_I 185
+IH_S 186
+IH0_B 187
+IH0_E 188
+IH0_I 189
+IH0_S 190
+IH1_B 191
+IH1_E 192
+IH1_I 193
+IH1_S 194
+IH2_B 195
+IH2_E 196
+IH2_I 197
+IH2_S 198
+IY_B 199
+IY_E 200
+IY_I 201
+IY_S 202
+IY0_B 203
+IY0_E 204
+IY0_I 205
+IY0_S 206
+IY1_B 207
+IY1_E 208
+IY1_I 209
+IY1_S 210
+IY2_B 211
+IY2_E 212
+IY2_I 213
+IY2_S 214
+JH_B 215
+JH_E 216
+JH_I 217
+JH_S 218
+K_B 219
+K_E 220
+K_I 221
+K_S 222
+L_B 223
+L_E 224
+L_I 225
+L_S 226
+M_B 227
+M_E 228
+M_I 229
+M_S 230
+N_B 231
+N_E 232
+N_I 233
+N_S 234
+NG_B 235
+NG_E 236
+NG_I 237
+NG_S 238
+OW_B 239
+OW_E 240
+OW_I 241
+OW_S 242
+OW0_B 243
+OW0_E 244
+OW0_I 245
+OW0_S 246
+OW1_B 247
+OW1_E 248
+OW1_I 249
+OW1_S 250
+OW2_B 251
+OW2_E 252
+OW2_I 253
+OW2_S 254
+OY_B 255
+OY_E 256
+OY_I 257
+OY_S 258
+OY0_B 259
+OY0_E 260
+OY0_I 261
+OY0_S 262
+OY1_B 263
+OY1_E 264
+OY1_I 265
+OY1_S 266
+OY2_B 267
+OY2_E 268
+OY2_I 269
+OY2_S 270
+P_B 271
+P_E 272
+P_I 273
+P_S 274
+R_B 275
+R_E 276
+R_I 277
+R_S 278
+S_B 279
+S_E 280
+S_I 281
+S_S 282
+SH_B 283
+SH_E 284
+SH_I 285
+SH_S 286
+T_B 287
+T_E 288
+T_I 289
+T_S 290
+TH_B 291
+TH_E 292
+TH_I 293
+TH_S 294
+UH_B 295
+UH_E 296
+UH_I 297
+UH_S 298
+UH0_B 299
+UH0_E 300
+UH0_I 301
+UH0_S 302
+UH1_B 303
+UH1_E 304
+UH1_I 305
+UH1_S 306
+UH2_B 307
+UH2_E 308
+UH2_I 309
+UH2_S 310
+UW_B 311
+UW_E 312
+UW_I 313
+UW_S 314
+UW0_B 315
+UW0_E 316
+UW0_I 317
+UW0_S 318
+UW1_B 319
+UW1_E 320
+UW1_I 321
+UW1_S 322
+UW2_B 323
+UW2_E 324
+UW2_I 325
+UW2_S 326
+V_B 327
+V_E 328
+V_I 329
+V_S 330
+W_B 331
+W_E 332
+W_I 333
+W_S 334
+Y_B 335
+Y_E 336
+Y_I 337
+Y_S 338
+Z_B 339
+Z_E 340
+Z_I 341
+Z_S 342
+ZH_B 343
+ZH_E 344
+ZH_I 345
+ZH_S 346
+#0 347
+#1 348
+#2 349
+#3 350
+#4 351
+#5 352
+#6 353
+#7 354
+#8 355
+#9 356
+#10 357
+#11 358
+#12 359
+#13 360
+#14 361
+#15 362
+#16 363
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/dict.ltr.txt b/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/dict.ltr.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26a7e6ba309998c3868db7ecab5d7afa52a68e52
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/dict.ltr.txt
@@ -0,0 +1,29 @@
+| 803288730
+E 439294199
+T 319071758
+A 277306732
+O 263784364
+N 239361162
+I 237353011
+H 223346762
+S 220175453
+R 203352500
+D 152198685
+L 141597450
+U 98913389
+M 87138757
+C 84680142
+W 81375101
+F 80240665
+G 70642902
+Y 68388038
+P 58436929
+B 52538531
+V 33250231
+K 26906609
+' 9162896
+X 5075632
+J 4746771
+Q 3401794
+Z 2186971
+ 1
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/train_sample100.ltr b/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/train_sample100.ltr
new file mode 100644
index 0000000000000000000000000000000000000000..ab9ab39e823eba89897e7763155c77d6f2be38a4
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/train_sample100.ltr
@@ -0,0 +1,100 @@
+C H A P T E R | O N E | M I S S U S | R A C H E L | L Y N D E | I S | S U R P R I S E D | M I S S U S | R A C H E L | L Y N D E | L I V E D | J U S T | W H E R E | T H E | A V O N L E A | M A I N | R O A D | D I P P E D | D O W N | I N T O | A | L I T T L E | H O L L O W | F R I N G E D | W I T H | A L D E R S | A N D | L A D I E S | E A R D R O P S | A N D | T R A V E R S E D | B Y | A | B R O O K |
+T H A T | H A D | I T S | S O U R C E | A W A Y | B A C K | I N | T H E | W O O D S | O F | T H E | O L D | C U T H B E R T | P L A C E | I T | W A S | R E P U T E D | T O | B E | A N | I N T R I C A T E | H E A D L O N G | B R O O K | I N | I T S | E A R L I E R | C O U R S E | T H R O U G H | T H O S E | W O O D S | W I T H | D A R K | S E C R E T S | O F | P O O L | A N D | C A S C A D E | B U T | B Y | T H E | T I M E | I T | R E A C H E D | L Y N D E ' S | H O L L O W | I T | W A S | A | Q U I E T | W E L L | C O N D U C T E D | L I T T L E | S T R E A M |
+F O R | N O T | E V E N | A | B R O O K | C O U L D | R U N | P A S T | M I S S U S | R A C H E L | L Y N D E ' S | D O O R | W I T H O U T | D U E | R E G A R D | F O R | D E C E N C Y | A N D | D E C O R U M | I T | P R O B A B L Y | W A S | C O N S C I O U S | T H A T | M I S S U S | R A C H E L | W A S | S I T T I N G | A T | H E R | W I N D O W | K E E P I N G | A | S H A R P | E Y E | O N | E V E R Y T H I N G | T H A T | P A S S E D | F R O M | B R O O K S | A N D | C H I L D R E N | U P |
+A N D | T H A T | I F | S H E | N O T I C E D | A N Y T H I N G | O D D | O R | O U T | O F | P L A C E | S H E | W O U L D | N E V E R | R E S T | U N T I L | S H E | H A D | F E R R E T E D | O U T | T H E | W H Y S | A N D | W H E R E F O R E S | T H E R E O F | T H E R E | A R E | P L E N T Y | O F | P E O P L E | I N | A V O N L E A | A N D | O U T | O F | I T | W H O | C A N | A T T E N D | C L O S E L Y | T O | T H E I R | N E I G H B O R ' S | B U S I N E S S | B Y | D I N T | O F | N E G L E C T I N G | T H E I R | O W N |
+B U T | M I S S U S | R A C H E L | L Y N D E | W A S | O N E | O F | T H O S E | C A P A B L E | C R E A T U R E S | W H O | C A N | M A N A G E | T H E I R | O W N | C O N C E R N S | A N D | T H O S E | O F | O T H E R | F O L K S | I N T O | T H E | B A R G A I N | S H E | W A S | A | N O T A B L E | H O U S E W I F E | H E R | W O R K | W A S | A L W A Y S | D O N E | A N D | W E L L | D O N E | S H E | R A N | T H E | S E W I N G | C I R C L E |
+H E L P E D | R U N | T H E | S U N D A Y | S C H O O L | A N D | W A S | T H E | S T R O N G E S T | P R O P | O F | T H E | C H U R C H | A I D | S O C I E T Y | A N D | F O R E I G N | M I S S I O N S | A U X I L I A R Y | Y E T | W I T H | A L L | T H I S | M I S S U S | R A C H E L | F O U N D | A B U N D A N T | T I M E | T O | S I T | F O R | H O U R S | A T | H E R | K I T C H E N | W I N D O W | K N I T T I N G | C O T T O N | W A R P | Q U I L T S | S H E | H A D | K N I T T E D | S I X T E E N | O F | T H E M |
+A S | A V O N L E A | H O U S E K E E P E R S | W E R E | W O N T | T O | T E L L | I N | A W E D | V O I C E S | A N D | K E E P I N G | A | S H A R P | E Y E | O N | T H E | M A I N | R O A D | T H A T | C R O S S E D | T H E | H O L L O W | A N D | W O U N D | U P | T H E | S T E E P | R E D | H I L L | B E Y O N D |
+A N Y B O D Y | W H O | W E N T | O U T | O F | I T | O R | I N T O | I T | H A D | T O | P A S S | O V E R | T H A T | H I L L | R O A D | A N D | S O | R U N | T H E | U N S E E N | G A U N T L E T | O F | M I S S U S | R A C H E L ' S | A L L | S E E I N G | E Y E | S H E | W A S | S I T T I N G | T H E R E | O N E | A F T E R N O O N | I N | E A R L Y | J U N E | T H E | S U N | W A S | C O M I N G | I N | A T | T H E | W I N D O W | W A R M | A N D | B R I G H T |
+T H E | O R C H A R D | O N | T H E | S L O P E | B E L O W | T H E | H O U S E | W A S | I N | A | B R I D A L | F L U S H | O F | P I N K Y | W H I T E | B L O O M | H U M M E D | O V E R | B Y | A | M Y R I A D | O F | B E E S | T H O M A S | L Y N D E | A | M E E K | L I T T L E | M A N | W H O M | A V O N L E A | P E O P L E | C A L L E D | R A C H E L | L Y N D E ' S | H U S B A N D | W A S | S O W I N G | H I S | L A T E | T U R N I P | S E E D | O N | T H E | H I L L | F I E L D | B E Y O N D | T H E | B A R N |
+M I S S U S | R A C H E L | K N E W | T H A T | H E | O U G H T | B E C A U S E | S H E | H A D | H E A R D | H I M | T E L L | P E T E R | M O R R I S O N | T H E | E V E N I N G | B E F O R E | I N | W I L L I A M | J | B L A I R ' S | S T O R E | O V E R | A T | C A R M O D Y | T H A T | H E | M E A N T | T O | S O W | H I S | T U R N I P | S E E D | T H E | N E X T | A F T E R N O O N |
+P E T E R | H A D | A S K E D | H I M | O F | C O U R S E | F O R | M A T T H E W | C U T H B E R T | H A D | N E V E R | B E E N | K N O W N | T O | V O L U N T E E R | I N F O R M A T I O N | A B O U T | A N Y T H I N G | I N | H I S | W H O L E | L I F E | A N D | Y E T | H E R E | W A S | M A T T H E W | C U T H B E R T | A T | H A L F | P A S T | T H R E E | O N | T H E | A F T E R N O O N | O F | A | B U S Y | D A Y | P L A C I D L Y | D R I V I N G | O V E R | T H E | H O L L O W | A N D | U P | T H E | H I L L |
+A N D | H I S | B E S T | S U I T | O F | C L O T H E S | W H I C H | W A S | P L A I N | P R O O F | T H A T | H E | W A S | G O I N G | O U T | O F | A V O N L E A | A N D | H E | H A D | T H E | B U G G Y | A N D | T H E | S O R R E L | M A R E | W H I C H | B E T O K E N E D | T H A T | H E | W A S | G O I N G | A | C O N S I D E R A B L E | D I S T A N C E | N O W | W H E R E | W A S | M A T T H E W | C U T H B E R T | G O I N G | A N D | W H Y | W A S | H E | G O I N G | T H E R E |
+H A D | I T | B E E N | A N Y | O T H E R | M A N | I N | A V O N L E A | M I S S U S | R A C H E L | D E F T L Y | P U T T I N G | T H I S | A N D | T H A T | T O G E T H E R | M I G H T | H A V E | G I V E N | A | P R E T T Y | G O O D | G U E S S | A S | T O | B O T H | Q U E S T I O N S | B U T | M A T T H E W | S O | R A R E L Y | W E N T | F R O M | H O M E | T H A T | I T | M U S T | B E | S O M E T H I N G | P R E S S I N G | A N D | U N U S U A L | W H I C H | W A S | T A K I N G | H I M |
+H E | W A S | T H E | S H Y E S T | M A N | A L I V E | A N D | H A T E D | T O | H A V E | T O | G O | A M O N G | S T R A N G E R S | O R | T O | A N Y | P L A C E | W H E R E | H E | M I G H T | H A V E | T O | T A L K | M A T T H E W | D R E S S E D | U P | W I T H | A | W H I T E | C O L L A R | A N D | D R I V I N G | I N | A | B U G G Y | W A S | S O M E T H I N G | T H A T | D I D N ' T | H A P P E N | O F T E N | M I S S U S | R A C H E L | P O N D E R | A S | S H E | M I G H T | C O U L D | M A K E | N O T H I N G | O F | I T |
+A N D | H E R | A F T E R N O O N ' S | E N J O Y M E N T | W A S | S P O I L E D | I ' L L | J U S T | S T E P | O V E R | T O | G R E E N | G A B L E S | A F T E R | T E A | A N D | F I N D | O U T | F R O M | M A R I L L A | W H E R E | H E ' S | G O N E | A N D | W H Y | T H E | W O R T H Y | W O M A N | F I N A L L Y | C O N C L U D E D | H E | D O E S N ' T | G E N E R A L L Y | G O | T O | T O W N | T H I S | T I M E | O F | Y E A R | A N D | H E | N E V E R | V I S I T S |
+I F | H E ' D | R U N | O U T | O F | T U R N I P | S E E D | H E | W O U L D N ' T | D R E S S | U P | A N D | T A K E | T H E | B U G G Y | T O | G O | F O R | M O R E |
+Y E T | S O M E T H I N G | M U S T | H A V E | H A P P E N E D | S I N C E | L A S T | N I G H T | T O | S T A R T | H I M | O F F | I ' M | C L E A N | P U Z Z L E D | T H A T ' S | W H A T | A N D | I | W O N ' T | K N O W | A | M I N U T E ' S | P E A C E | O F | M I N D | O R | C O N S C I E N C E | U N T I L | I | K N O W | W H A T | H A S | T A K E N | M A T T H E W | C U T H B E R T | O U T | O F | A V O N L E A | T O D A Y | A C C O R D I N G L Y | A F T E R | T E A | M I S S U S | R A C H E L | S E T | O U T | S H E | H A D | N O T | F A R | T O | G O |
+T H E | B I G | R A M B L I N G | O R C H A R D | E M B O W E R E D | H O U S E | W H E R E | T H E | C U T H B E R T S | L I V E D | W A S | A | S C A N T | Q U A R T E R | O F | A | M I L E | U P | T H E | R O A D | F R O M | L Y N D E ' S | H O L L O W | T O | B E | S U R E | T H E | L O N G | L A N E | M A D E | I T | A | G O O D | D E A L | F U R T H E R | M A T T H E W | C U T H B E R T ' S | F A T H E R | A S | S H Y | A N D | S I L E N T | A S | H I S | S O N | A F T E R | H I M |
+H A D | G O T | A S | F A R | A W A Y | A S | H E | P O S S I B L Y | C O U L D | F R O M | H I S | F E L L O W | M E N | W I T H O U T | A C T U A L L Y | R E T R E A T I N G | I N T O | T H E | W O O D S | W H E N | H E | F O U N D E D | H I S | H O M E S T E A D | G R E E N | G A B L E S | W A S | B U I L T | A T | T H E | F U R T H E S T | E D G E | O F | H I S | C L E A R E D | L A N D | A N D | T H E R E | I T | W A S | T O | T H I S | D A Y |
+B A R E L Y | V I S I B L E | F R O M | T H E | M A I N | R O A D | A L O N G | W H I C H | A L L | T H E | O T H E R | A V O N L E A | H O U S E S | W E R E | S O | S O C I A B L Y | S I T U A T E D | M I S S U S | R A C H E L | L Y N D E | D I D | N O T | C A L L | L I V I N G | I N | S U C H | A | P L A C E | L I V I N G | A T | A L L | I T ' S | J U S T | S T A Y I N G | T H A T ' S | W H A T | S H E | S A I D | A S | S H E | S T E P P E D | A L O N G | T H E | D E E P | R U T T E D | G R A S S Y | L A N E |
+B O R D E R E D | W I T H | W I L D | R O S E | B U S H E S | I T ' S | N O | W O N D E R | M A T T H E W | A N D | M A R I L L A | A R E | B O T H | A | L I T T L E | O D D | L I V I N G | A W A Y | B A C K | H E R E | B Y | T H E M S E L V E S | T R E E S | A R E N ' T | M U C H | C O M P A N Y | T H O U G H | D E A R | K N O W S | I F | T H E Y | W E R E | T H E R E ' D | B E | E N O U G H | O F | T H E M | I ' D | R U T H E R | L O O K | A T | P E O P L E | T O | B E | S U R E |
+T H E Y | S E E M | C O N T E N T E D | E N O U G H | B U T | T H E N | I | S U P P O S E | T H E Y ' R E | U S E D | T O | I T | A | B O D Y | C A N | G E T | U S E D | T O | A N Y T H I N G | E V E N | T O | B E I N G | H A N G E D | A S | T H E | I R I S H M A N | S A I D | W I T H | T H I S | M I S S U S | R A C H E L | S T E P P E D | O U T | O F | T H E | L A N E | I N T O | T H E | B A C K Y A R D | O F | G R E E N | G A B L E S | V E R Y | G R E E N | A N D | N E A T | A N D | P R E C I S E | W A S | T H A T | Y A R D |
+S E T | A B O U T | O N | O N E | S I D E | W I T H | G R E A T | P A T R I A R C H A L | W I L L O W S | A N D | T H E | O T H E R | W I T H | P R I M | L O M B A R D I E S | N O T | A | S T R A Y | S T I C K | N O R | S T O N E | W A S | T O | B E | S E E N | F O R | M I S S U S | R A C H E L | W O U L D | H A V E | S E E N | I T | I F | T H E R E | H A D | B E E N | P R I V A T E L Y | S H E | W A S | O F | T H E | O P I N I O N | T H A T | M A R I L L A | C U T H B E R T | S W E P T | T H A T | Y A R D | O V E R | A S | O F T E N | A S | S H E | S W E P T | H E R | H O U S E |
+O N E | C O U L D | H A V E | E A T E N | A | M E A L | O F F | T H E | G R O U N D | W I T H O U T | O V E R B R I M M I N G | T H E | P R O V E R B I A L | P E C K | O F | D I R T | M I S S U S | R A C H E L | R A P P E D | S M A R T L Y | A T | T H E | K I T C H E N | D O O R | A N D | S T E P P E D | I N | W H E N | B I D D E N | T O | D O | S O | T H E | K I T C H E N | A T | G R E E N | G A B L E S | W A S | A | C H E E R F U L | A P A R T M E N T |
+O R | W O U L D | H A V E | B E E N | C H E E R F U L | I F | I T | H A D | N O T | B E E N | S O | P A I N F U L L Y | C L E A N | A S | T O | G I V E | I T | S O M E T H I N G | O F | T H E | A P P E A R A N C E | O F | A N | U N U S E D | P A R L O R | I T S | W I N D O W S | L O O K E D | E A S T | A N D | W E S T | T H R O U G H | T H E | W E S T | O N E | L O O K I N G | O U T | O N | T H E | B A C K | Y A R D | C A M E | A | F L O O D | O F | M E L L O W | J U N E | S U N L I G H T | B U T | T H E | E A S T | O N E |
+W H E N C E | Y O U | G O T | A | G L I M P S E | O F | T H E | B L O O M | W H I T E | C H E R R Y | T R E E S | I N | T H E | L E F T | O R C H A R D | A N D | N O D D I N G | S L E N D E R | B I R C H E S | D O W N | I N | T H E | H O L L O W | B Y | T H E | B R O O K | W A S | G R E E N E D | O V E R | B Y | A | T A N G L E | O F | V I N E S | H E R E | S A T | M A R I L L A | C U T H B E R T | W H E N | S H E | S A T | A T | A L L | A L W A Y S | S L I G H T L Y | D I S T R U S T F U L | O F | S U N S H I N E |
+A N D | H E R E | S H E | S A T | N O W | K N I T T I N G | A N D | T H E | T A B L E | B E H I N D | H E R | W A S | L A I D | F O R | S U P P E R | M I S S U S | R A C H E L | B E F O R E | S H E | H A D | F A I R L Y | C L O S E D | T H E | D O O R |
+T H E R E | W E R E | T H R E E | P L A T E S | L A I D | S O | T H A T | M A R I L L A | M U S T | B E | E X P E C T I N G | S O M E | O N E | H O M E | W I T H | M A T T H E W | T O | T E A | B U T | T H E | D I S H E S | W E R E | E V E R Y D A Y | D I S H E S | A N D | T H E R E | W A S | O N L Y | C R A B | A P P L E | P R E S E R V E S | A N D | O N E | K I N D | O F | C A K E | S O | T H A T | T H E | E X P E C T E D | C O M P A N Y | C O U L D | N O T | B E | A N Y | P A R T I C U L A R | C O M P A N Y |
+Y E T | W H A T | O F | M A T T H E W ' S | W H I T E | C O L L A R | A N D | T H E | S O R R E L | M A R E | M I S S U S | R A C H E L | W A S | G E T T I N G | F A I R L Y | D I Z Z Y | W I T H | T H I S | U N U S U A L | M Y S T E R Y | A B O U T | Q U I E T | U N M Y S T E R I O U S | G R E E N | G A B L E S | G O O D | E V E N I N G | R A C H E L | M A R I L L A | S A I D | B R I S K L Y | T H I S | I S | A | R E A L | F I N E | E V E N I N G | I S N ' T | I T | W O N ' T | Y O U | S I T | D O W N |
+H O W | A R E | A L L | Y O U R | F O L K S | S O M E T H I N G | T H A T | F O R | L A C K | O F | A N Y | O T H E R | N A M E | M I G H T | B E | C A L L E D | F R I E N D S H I P | E X I S T E D | A N D | A L W A Y S | H A D | E X I S T E D | B E T W E E N | M A R I L L A | C U T H B E R T | A N D | M I S S U S | R A C H E L | I N | S P I T E | O F | O R | P E R H A P S | B E C A U S E | O F | T H E I R | D I S S I M I L A R I T Y | M A R I L L A | W A S | A | T A L L |
+T H I N | W O M A N | W I T H | A N G L E S | A N D | W I T H O U T | C U R V E S | H E R | D A R K | H A I R | S H O W E D | S O M E | G R A Y | S T R E A K S | A N D | W A S | A L W A Y S | T W I S T E D | U P | I N | A | H A R D | L I T T L E | K N O T | B E H I N D | W I T H | T W O | W I R E | H A I R P I N S | S T U C K | A G G R E S S I V E L Y | T H R O U G H | I T | S H E | L O O K E D | L I K E | A | W O M A N | O F | N A R R O W | E X P E R I E N C E | A N D | R I G I D | C O N S C I E N C E | W H I C H | S H E | W A S |
+B U T | T H E R E | W A S | A | S A V I N G | S O M E T H I N G | A B O U T | H E R | M O U T H | W H I C H | I F | I T | H A D | B E E N | E V E R | S O | S L I G H T L Y | D E V E L O P E D | M I G H T | H A V E | B E E N | C O N S I D E R E D | I N D I C A T I V E | O F | A | S E N S E | O F | H U M O R | W E ' R E | A L L | P R E T T Y | W E L L | S A I D | M I S S U S | R A C H E L | I | W A S | K I N D | O F | A F R A I D | Y O U | W E R E N ' T | T H O U G H | W H E N | I | S A W | M A T T H E W | S T A R T I N G | O F F | T O D A Y | I | T H O U G H T | M A Y B E | H E | W A S | G O I N G | T O | T H E | D O C T O R ' S |
+M A R I L L A ' S | L I P S | T W I T C H E D | U N D E R S T A N D I N G L Y | S H E | H A D | E X P E C T E D | M I S S U S | R A C H E L | U P | S H E | H A D | K N O W N | T H A T | T H E | S I G H T | O F | M A T T H E W | J A U N T I N G | O F F | S O | U N A C C O U N T A B L Y | W O U L D | B E | T O O | M U C H | F O R | H E R | N E I G H B O R ' S | C U R I O S I T Y | O H | N O | I ' M | Q U I T E | W E L L | A L T H O U G H | I | H A D | A | B A D | H E A D A C H E | Y E S T E R D A Y | S H E | S A I D |
+M A T T H E W | W E N T | T O | B R I G H T | R I V E R | W E ' R E | G E T T I N G | A | L I T T L E | B O Y | F R O M | A N | O R P H A N | A S Y L U M | I N | N O V A | S C O T I A | A N D | H E ' S | C O M I N G | O N | T H E | T R A I N | T O N I G H T | I F | M A R I L L A | H A D | S A I D | T H A T | M A T T H E W | H A D | G O N E | T O | B R I G H T | R I V E R | T O | M E E T | A | K A N G A R O O | F R O M | A U S T R A L I A | M I S S U S | R A C H E L | C O U L D | N O T | H A V E | B E E N | M O R E | A S T O N I S H E D |
+S H E | W A S | A C T U A L L Y | S T R I C K E N | D U M B | F O R | F I V E | S E C O N D S | I T | W A S | U N S U P P O S A B L E | T H A T | M A R I L L A | W A S | M A K I N G | F U N | O F | H E R | B U T | M I S S U S | R A C H E L | W A S | A L M O S T | F O R C E D | T O | S U P P O S E | I T | A R E | Y O U | I N | E A R N E S T | M A R I L L A | S H E | D E M A N D E D | W H E N | V O I C E | R E T U R N E D | T O | H E R | Y E S | O F | C O U R S E |
+S A I D | M A R I L L A | A S | I F | G E T T I N G | B O Y S | F R O M | O R P H A N | A S Y L U M S | I N | N O V A | S C O T I A | W E R E | P A R T | O F | T H E | U S U A L | S P R I N G | W O R K | O N | A N Y | W E L L | R E G U L A T E D | A V O N L E A | F A R M | I N S T E A D | O F | B E I N G | A N | U N H E A R D | O F | I N N O V A T I O N | M I S S U S | R A C H E L | F E L T | T H A T | S H E | H A D | R E C E I V E D | A | S E V E R E | M E N T A L | J O L T | S H E | T H O U G H T | I N | E X C L A M A T I O N | P O I N T S |
+M A R I L L A | A N D | M A T T H E W | C U T H B E R T | O F | A L L | P E O P L E | A D O P T I N G | A | B O Y | F R O M | A N | O R P H A N | A S Y L U M | W E L L | T H E | W O R L D | W A S | C E R T A I N L Y | T U R N I N G | U P S I D E | D O W N | S H E | W O U L D | B E | S U R P R I S E D | A T | N O T H I N G | A F T E R | T H I S | N O T H I N G |
+W H A T | O N | E A R T H | P U T | S U C H | A | N O T I O N | I N T O | Y O U R | H E A D | S H E | D E M A N D E D | D I S A P P R O V I N G L Y | T H I S | H A D | B E E N | D O N E | W I T H O U T | H E R | A D V I C E | B E I N G | A S K E D | A N D | M U S T | P E R F O R C E | B E | D I S A P P R O V E D | W E L L | W E ' V E | B E E N | T H I N K I N G | A B O U T | I T | F O R | S O M E | T I M E | A L L | W I N T E R | I N | F A C T | R E T U R N E D | M A R I L L A |
+M I S S U S | A L E X A N D E R | S P E N C E R | W A S | U P | H E R E | O N E | D A Y | B E F O R E | C H R I S T M A S | A N D | S H E | S A I D | S H E | W A S | G O I N G | T O | G E T | A | L I T T L E | G I R L | F R O M | T H E | A S Y L U M | O V E R | I N | H O P E T O N | I N | T H E | S P R I N G |
+S O | M A T T H E W | A N D | I | H A V E | T A L K E D | I T | O V E R | O F F | A N D | O N | E V E R | S I N C E | W E | T H O U G H T | W E ' D | G E T | A | B O Y | M A T T H E W | I S | G E T T I N G | U P | I N | Y E A R S | Y O U | K N O W | H E ' S | S I X T Y | A N D | H E | I S N ' T | S O | S P R Y | A S | H E | O N C E | W A S | H I S | H E A R T | T R O U B L E S | H I M | A | G O O D | D E A L | A N D | Y O U | K N O W | H O W | D E S P E R A T E | H A R D | I T ' S | G O T | T O | B E | T O | G E T | H I R E D | H E L P |
+T H E R E ' S | N E V E R | A N Y B O D Y | T O | B E | H A D | B U T | T H O S E | S T U P I D | H A L F | G R O W N | L I T T L E | F R E N C H | B O Y S | A N D | A S | S O O N | A S | Y O U | D O | G E T | O N E | B R O K E | I N T O | Y O U R | W A Y S | A N D | T A U G H T | S O M E T H I N G | H E ' S | U P | A N D | O F F | T O | T H E | L O B S T E R | C A N N E R I E S | O R | T H E | S T A T E S | A T | F I R S T | M A T T H E W | S U G G E S T E D | G E T T I N G | A | H O M E | B O Y | B U T | I | S A I D | N O | F L A T | T O | T H A T |
+T H E Y | M A Y | B E | A L L | R I G H T | I ' M | N O T | S A Y I N G | T H E Y ' R E | N O T | B U T | N O | L O N D O N | S T R E E T | A R A B S | F O R | M E | I | S A I D | G I V E | M E | A | N A T I V E | B O R N | A T | L E A S T | T H E R E ' L L | B E | A | R I S K | N O | M A T T E R | W H O | W E | G E T | B U T | I ' L L | F E E L | E A S I E R | I N | M Y | M I N D | A N D | S L E E P | S O U N D E R | A T | N I G H T S | I F | W E | G E T | A | B O R N | C A N A D I A N |
+S O | I N | T H E | E N D | W E | D E C I D E D | T O | A S K | M I S S U S | S P E N C E R | T O | P I C K | U S | O U T | O N E | W H E N | S H E | W E N T | O V E R | T O | G E T | H E R | L I T T L E | G I R L | W E | H E A R D | L A S T | W E E K | S H E | W A S | G O I N G | S O | W E | S E N T | H E R | W O R D | B Y | R I C H A R D | S P E N C E R ' S | F O L K S | A T | C A R M O D Y | T O | B R I N G | U S | A | S M A R T | L I K E L Y | B O Y | O F | A B O U T | T E N | O R | E L E V E N | W E | D E C I D E D | T H A T | W O U L D | B E | T H E | B E S T | A G E |
+O L D | E N O U G H | T O | B E | O F | S O M E | U S E | I N | D O I N G | C H O R E S | R I G H T | O F F | A N D | Y O U N G | E N O U G H | T O | B E | T R A I N E D | U P | P R O P E R | W E | M E A N | T O | G I V E | H I M | A | G O O D | H O M E | A N D | S C H O O L I N G | W E | H A D | A | T E L E G R A M | F R O M | M I S S U S | A L E X A N D E R | S P E N C E R | T O D A Y | T H E | M A I L | M A N | B R O U G H T | I T | F R O M | T H E | S T A T I O N | S A Y I N G | T H E Y | W E R E | C O M I N G | O N | T H E | F I V E | T H I R T Y | T R A I N | T O N I G H T |
+S O | M A T T H E W | W E N T | T O | B R I G H T | R I V E R | T O | M E E T | H I M | M I S S U S | S P E N C E R | W I L L | D R O P | H I M | O F F | T H E R E | O F | C O U R S E | S H E | G O E S | O N | T O | W H I T E | S A N D S | S T A T I O N | H E R S E L F | M I S S U S | R A C H E L | P R I D E D | H E R S E L F | O N | A L W A Y S | S P E A K I N G | H E R | M I N D | S H E | P R O C E E D E D | T O | S P E A K | I T | N O W | H A V I N G | A D J U S T E D | H E R | M E N T A L | A T T I T U D E | T O | T H I S | A M A Z I N G | P I E C E | O F | N E W S |
+W E L L | M A R I L L A | I ' L L | J U S T | T E L L | Y O U | P L A I N | T H A T | I | T H I N K | Y O U ' R E | D O I N G | A | M I G H T Y | F O O L I S H | T H I N G | A | R I S K Y | T H I N G | T H A T ' S | W H A T | Y O U | D O N ' T | K N O W | W H A T | Y O U ' R E | G E T T I N G | Y O U ' R E | B R I N G I N G | A | S T R A N G E | C H I L D | I N T O | Y O U R | H O U S E | A N D | H O M E | A N D | Y O U | D O N ' T | K N O W | A | S I N G L E | T H I N G | A B O U T | H I M | N O R | W H A T | H I S | D I S P O S I T I O N | I S | L I K E | N O R | W H A T | S O R T | O F | P A R E N T S | H E | H A D |
+N O R | H O W | H E ' S | L I K E L Y | T O | T U R N | O U T | W H Y | I T | W A S | O N L Y | L A S T | W E E K | I | R E A D | I N | T H E | P A P E R | H O W | A | M A N | A N D | H I S | W I F E | U P | W E S T | O F | T H E | I S L A N D | T O O K | A | B O Y | O U T | O F | A N | O R P H A N | A S Y L U M | A N D | H E | S E T | F I R E | T O | T H E | H O U S E | A T | N I G H T | S E T | I T | O N | P U R P O S E | M A R I L L A | A N D | N E A R L Y | B U R N T | T H E M | T O | A | C R I S P | I N | T H E I R | B E D S |
+A N D | I | K N O W | A N O T H E R | C A S E | W H E R E | A N | A D O P T E D | B O Y | U S E D | T O | S U C K | T H E | E G G S | T H E Y | C O U L D N ' T | B R E A K | H I M | O F | I T | I F | Y O U | H A D | A S K E D | M Y | A D V I C E | I N | T H E | M A T T E R | W H I C H | Y O U | D I D N ' T | D O | M A R I L L A | I ' D | H A V E | S A I D | F O R | M E R C Y ' S | S A K E | N O T | T O | T H I N K | O F | S U C H | A | T H I N G | T H A T ' S | W H A T |
+T H I S | J O B ' S | C O M F O R T I N G | S E E M E D | N E I T H E R | T O | O F F E N D | N O R | T O | A L A R M | M A R I L L A | S H E | K N I T T E D | S T E A D I L Y | O N | I | D O N ' T | D E N Y | T H E R E ' S | S O M E T H I N G | I N | W H A T | Y O U | S A Y | R A C H E L | I ' V E | H A D | S O M E | Q U A L M S | M Y S E L F | B U T | M A T T H E W | W A S | T E R R I B L E | S E T | O N | I T | I | C O U L D | S E E | T H A T | S O | I | G A V E | I N |
+I T ' S | S O | S E L D O M | M A T T H E W | S E T S | H I S | M I N D | O N | A N Y T H I N G | T H A T | W H E N | H E | D O E S | I | A L W A Y S | F E E L | I T ' S | M Y | D U T Y | T O | G I V E | I N | A N D | A S | F O R | T H E | R I S K | T H E R E ' S | R I S K S | I N | P R E T T Y | N E A R | E V E R Y T H I N G | A | B O D Y | D O E S | I N | T H I S | W O R L D | T H E R E ' S | R I S K S | I N | P E O P L E ' S | H A V I N G | C H I L D R E N | O F | T H E I R | O W N | I F | I T | C O M E S | T O | T H A T | T H E Y | D O N ' T | A L W A Y S | T U R N | O U T | W E L L |
+A N D | T H E N | N O V A | S C O T I A | I S | R I G H T | C L O S E | T O | T H E | I S L A N D | I T | I S N ' T | A S | I F | W E | W E R E | G E T T I N G | H I M | F R O M | E N G L A N D | O R | T H E | S T A T E S | H E | C A N ' T | B E | M U C H | D I F F E R E N T | F R O M | O U R S E L V E S | W E L L | I | H O P E | I T | W I L L | T U R N | O U T | A L L | R I G H T | S A I D | M I S S U S | R A C H E L | I N | A | T O N E | T H A T | P L A I N L Y | I N D I C A T E D | H E R | P A I N F U L | D O U B T S |
+O N L Y | D O N ' T | S A Y | I | D I D N ' T | W A R N | Y O U | I F | H E | B U R N S | G R E E N | G A B L E S | D O W N | O R | P U T S | S T R Y C H N I N E | I N | T H E | W E L L | I | H E A R D | O F | A | C A S E | O V E R | I N | N E W | B R U N S W I C K | W H E R E | A N | O R P H A N | A S Y L U M | C H I L D | D I D | T H A T | A N D | T H E | W H O L E | F A M I L Y | D I E D | I N | F E A R F U L | A G O N I E S | O N L Y | I T | W A S | A | G I R L | I N | T H A T | I N S T A N C E | W E L L | W E ' R E | N O T | G E T T I N G | A | G I R L | S A I D | M A R I L L A |
+A S | I F | P O I S O N I N G | W E L L S | W E R E | A | P U R E L Y | F E M I N I N E | A C C O M P L I S H M E N T | A N D | N O T | T O | B E | D R E A D E D | I N | T H E | C A S E | O F | A | B O Y | I ' D | N E V E R | D R E A M | O F | T A K I N G | A | G I R L | T O | B R I N G | U P | I | W O N D E R | A T | M I S S U S | A L E X A N D E R | S P E N C E R | F O R | D O I N G | I T | B U T | T H E R E | S H E | W O U L D N ' T | S H R I N K | F R O M | A D O P T I N G | A | W H O L E | O R P H A N | A S Y L U M | I F | S H E | T O O K | I T | I N T O | H E R | H E A D |
+M I S S U S | R A C H E L | W O U L D | H A V E | L I K E D | T O | S T A Y | U N T I L | M A T T H E W | C A M E | H O M E | W I T H | H I S | I M P O R T E D | O R P H A N | B U T | R E F L E C T I N G | T H A T | I T | W O U L D | B E | A | G O O D | T W O | H O U R S | A T | L E A S T | B E F O R E | H I S | A R R I V A L | S H E | C O N C L U D E D | T O | G O | U P | T H E | R O A D | T O | R O B E R T | B E L L ' S | A N D | T E L L | T H E | N E W S | I T | W O U L D | C E R T A I N L Y | M A K E | A | S E N S A T I O N | S E C O N D | T O | N O N E |
+A N D | M I S S U S | R A C H E L | D E A R L Y | L O V E D | T O | M A K E | A | S E N S A T I O N | S O | S H E | T O O K | H E R S E L F | A W A Y | S O M E W H A T | T O | M A R I L L A ' S | R E L I E F | F O R | T H E | L A T T E R | F E L T | H E R | D O U B T S | A N D | F E A R S | R E V I V I N G | U N D E R | T H E | I N F L U E N C E | O F | M I S S U S | R A C H E L ' S | P E S S I M I S M | W E L L | O F | A L L | T H I N G S | T H A T | E V E R | W E R E | O R | W I L L | B E | E J A C U L A T E D | M I S S U S | R A C H E L | W H E N | S H E | W A S | S A F E L Y | O U T | I N | T H E | L A N E |
+I T | D O E S | R E A L L Y | S E E M | A S | I F | I | M U S T | B E | D R E A M I N G | W E L L | I ' M | S O R R Y | F O R | T H A T | P O O R | Y O U N G | O N E | A N D | N O | M I S T A K E | M A T T H E W | A N D | M A R I L L A | D O N ' T | K N O W | A N Y T H I N G | A B O U T | C H I L D R E N | A N D | T H E Y ' L L | E X P E C T | H I M | T O | B E | W I S E R | A N D | S T E A D I E R | T H A T | H I S | O W N | G R A N D F A T H E R |
+I T | S E E M S | U N C A N N Y | T O | T H I N K | O F | A | C H I L D | A T | G R E E N | G A B L E S | S O M E H O W | T H E R E ' S | N E V E R | B E E N | O N E | T H E R E | F O R | M A T T H E W | A N D | M A R I L L A | W E R E | G R O W N | U P | W H E N | T H E | N E W | H O U S E | W A S | B U I L T | I F | T H E Y | E V E R | W E R E | C H I L D R E N | W H I C H | I S | H A R D | T O | B E L I E V E | W H E N | O N E | L O O K S | A T | T H E M | I | W O U L D N ' T | B E | I N | T H A T | O R P H A N ' S | S H O E S | F O R | A N Y T H I N G |
+M Y | B U T | I | P I T Y | H I M | T H A T ' S | W H A T | S O | S A I D | M I S S U S | R A C H E L | T O | T H E | W I L D | R O S E | B U S H E S | O U T | O F | T H E | F U L N E S S | O F | H E R | H E A R T |
+C H A P T E R | T W O | M A T T H E W | C U T H B E R T | I S | S U R P R I S E D | M A T T H E W | C U T H B E R T | A N D | T H E | S O R R E L | M A R E | J O G G E D | C O M F O R T A B L Y | O V E R | T H E | E I G H T | M I L E S | T O | B R I G H T | R I V E R | I T | W A S | A | P R E T T Y | R O A D | R U N N I N G | A L O N G | B E T W E E N | S N U G | F A R M S T E A D S | W I T H | N O W | A N D | A G A I N | A | B I T | O F | B A L S A M Y | F I R | W O O D | T O | D R I V E | T H R O U G H |
+O R | A | H O L L O W | W H E R E | W I L D | P L U M S | H U N G | O U T | T H E I R | F I L M Y | B L O O M | T H E | A I R | W A S | S W E E T | W I T H | T H E | B R E A T H | O F | M A N Y | A P P L E | O R C H A R D S | A N D | T H E | M E A D O W S | S L O P E D | A W A Y | I N | T H E | D I S T A N C E | T O | H O R I Z O N | M I S T S | O F | P E A R L | A N D | P U R P L E | W H I L E | T H E | L I T T L E | B I R D S | S A N G | A S | I F | I T | W E R E | T H E | O N E | D A Y | O F | S U M M E R | I N | A L L | T H E | Y E A R |
+M A T T H E W | E N J O Y E D | T H E | D R I V E | A F T E R | H I S | O W N | F A S H I O N | E X C E P T | D U R I N G | T H E | M O M E N T S | W H E N | H E | M E T | W O M E N | A N D | H A D | T O | N O D | T O | T H E M | F O R | I N | P R I N C E | E D W A R D | I S L A N D | Y O U | A R E | S U P P O S E D | T O | N O D | T O | A L L | A N D | S U N D R Y | Y O U | M E E T | O N | T H E | R O A D | W H E T H E R | Y O U | K N O W | T H E M | O R | N O T | M A T T H E W | D R E A D E D | A L L | W O M E N | E X C E P T | M A R I L L A | A N D | M I S S U S | R A C H E L |
+H E | H A D | A N | U N C O M F O R T A B L E | F E E L I N G | T H A T | T H E | M Y S T E R I O U S | C R E A T U R E S | W E R E | S E C R E T L Y | L A U G H I N G | A T | H I M | H E | M A Y | H A V E | B E E N | Q U I T E | R I G H T | I N | T H I N K I N G | S O | F O R | H E | W A S | A N | O D D | L O O K I N G | P E R S O N A G E | W I T H | A N | U N G A I N L Y | F I G U R E | A N D | L O N G | I R O N | G R A Y | H A I R | T H A T | T O U C H E D | H I S | S T O O P I N G | S H O U L D E R S |
+A N D | A | F U L L | S O F T | B R O W N | B E A R D | W H I C H | H E | H A D | W O R N | E V E R | S I N C E | H E | W A S | T W E N T Y | I N | F A C T | H E | H A D | L O O K E D | A T | T W E N T Y | V E R Y | M U C H | A S | H E | L O O K E D | A T | S I X T Y | L A C K I N G | A | L I T T L E | O F | T H E | G R A Y N E S S | W H E N | H E | R E A C H E D | B R I G H T | R I V E R | T H E R E | W A S | N O | S I G N | O F | A N Y | T R A I N |
+H E | T H O U G H T | H E | W A S | T O O | E A R L Y | S O | H E | T I E D | H I S | H O R S E | I N | T H E | Y A R D | O F | T H E | S M A L L | B R I G H T | R I V E R | H O T E L | A N D | W E N T | O V E R | T O | T H E | S T A T I O N | H O U S E | T H E | L O N G | P L A T F O R M | W A S | A L M O S T | D E S E R T E D | T H E | O N L Y | L I V I N G | C R E A T U R E | I N | S I G H T | B E I N G | A | G I R L | W H O | W A S | S I T T I N G | O N | A | P I L E | O F | S H I N G L E S | A T | T H E | E X T R E M E | E N D |
+M A T T H E W | B A R E L Y | N O T I N G | T H A T | I T | W A S | A | G I R L | S I D L E D | P A S T | H E R | A S | Q U I C K L Y | A S | P O S S I B L E | W I T H O U T | L O O K I N G | A T | H E R | H A D | H E | L O O K E D | H E | C O U L D | H A R D L Y | H A V E | F A I L E D | T O | N O T I C E | T H E | T E N S E | R I G I D I T Y | A N D | E X P E C T A T I O N | O F | H E R | A T T I T U D E | A N D | E X P R E S S I O N | S H E | W A S | S I T T I N G | T H E R E | W A I T I N G | F O R | S O M E T H I N G | O R | S O M E B O D Y |
+A N D | S I N C E | S I T T I N G | A N D | W A I T I N G | W A S | T H E | O N L Y | T H I N G | T O | D O | J U S T | T H E N | S H E | S A T | A N D | W A I T E D | W I T H | A L L | H E R | M I G H T | A N D | M A I N | M A T T H E W | E N C O U N T E R E D | T H E | S T A T I O N M A S T E R | L O C K I N G | U P | T H E | T I C K E T | O F F I C E | P R E P A R A T O R Y | T O | G O I N G | H O M E | F O R | S U P P E R | A N D | A S K E D | H I M | I F | T H E | F I V E | T H I R T Y | T R A I N | W O U L D | S O O N | B E | A L O N G |
+T H E | F I V E | T H I R T Y | T R A I N | H A S | B E E N | I N | A N D | G O N E | H A L F | A N | H O U R | A G O | A N S W E R E D | T H A T | B R I S K | O F F I C I A L | B U T | T H E R E | W A S | A | P A S S E N G E R | D R O P P E D | O F F | F O R | Y O U | A | L I T T L E | G I R L | S H E ' S | S I T T I N G | O U T | T H E R E | O N | T H E | S H I N G L E S | I | A S K E D | H E R | T O | G O | I N T O | T H E | L A D I E S | W A I T I N G | R O O M | B U T | S H E | I N F O R M E D | M E | G R A V E L Y | T H A T | S H E | P R E F E R R E D | T O | S T A Y | O U T S I D E |
+S H E ' S | A | C A S E | I | S H O U L D | S A Y | I ' M | N O T | E X P E C T I N G | A | G I R L | S A I D | M A T T H E W | B L A N K L Y | I T ' S | A | B O Y | I ' V E | C O M E | F O R | H E | S H O U L D | B E | H E R E | M I S S U S | A L E X A N D E R | S P E N C E R | W A S | T O | B R I N G | H I M | O V E R | F R O M | N O V A | S C O T I A | F O R | M E | T H E | S T A T I O N M A S T E R | W H I S T L E D |
+G U E S S | T H E R E ' S | S O M E | M I S T A K E | H E | S A I D | M I S S U S | S P E N C E R | C A M E | O F F | T H E | T R A I N | W I T H | T H A T | G I R L | A N D | G A V E | H E R | I N T O | M Y | C H A R G E | S A I D | Y O U | A N D | Y O U R | S I S T E R | W E R E | A D O P T I N G | H E R | F R O M | A N | O R P H A N | A S Y L U M | A N D | T H A T | Y O U | W O U L D | B E | A L O N G | F O R | H E R | P R E S E N T L Y | T H A T ' S | A L L | I | K N O W | A B O U T | I T | A N D | I | H A V E N ' T | G O T | A N Y | M O R E | O R P H A N S | C O N C E A L E D | H E R E A B O U T S |
+I | D O N ' T | U N D E R S T A N D | S A I D | M A T T H E W | H E L P L E S S L Y | W I S H I N G | T H A T | M A R I L L A | W A S | A T | H A N D | T O | C O P E | W I T H | T H E | S I T U A T I O N | W E L L | Y O U ' D | B E T T E R | Q U E S T I O N | T H E | G I R L | S A I D | T H E | S T A T I O N | M A S T E R | C A R E L E S S L Y | I | D A R E | S A Y | S H E ' L L | B E | A B L E | T O | E X P L A I N | S H E ' S | G O T | A | T O N G U E | O F | H E R | O W N | T H A T ' S | C E R T A I N |
+M A Y B E | T H E Y | W E R E | O U T | O F | B O Y S | O F | T H E | B R A N D | Y O U | W A N T E D | H E | W A L K E D | J A U N T I L Y | A W A Y | B E I N G | H U N G R Y | A N D | T H E | U N F O R T U N A T E | M A T T H E W | W A S | L E F T | T O | D O | T H A T | W H I C H | W A S | H A R D E R | F O R | H I M | T H A N | B E A R D I N G | A | L I O N | I N | I T S | D E N | W A L K | U P | T O | A | G I R L | A | S T R A N G E | G I R L | A N | O R P H A N | G I R L |
+A N D | D E M A N D | O F | H E R | W H Y | S H E | W A S N ' T | A | B O Y | M A T T H E W | G R O A N E D | I N | S P I R I T | A S | H E | T U R N E D | A B O U T | A N D | S H U F F L E D | G E N T L Y | D O W N | T H E | P L A T F O R M | T O W A R D S | H E R | S H E | H A D | B E E N | W A T C H I N G | H I M | E V E R | S I N C E | H E | H A D | P A S S E D | H E R | A N D | S H E | H A D | H E R | E Y E S | O N | H I M | N O W | M A T T H E W | W A S | N O T | L O O K I N G | A T | H E R |
+A | C H I L D | O F | A B O U T | E L E V E N | G A R B E D | I N | A | V E R Y | S H O R T | V E R Y | T I G H T | V E R Y | U G L Y | D R E S S | O F | Y E L L O W I S H | G R A Y | W I N C E Y | S H E | W O R E | A | F A D E D | B R O W N | S A I L O R | H A T | A N D | B E N E A T H | T H E | H A T | E X T E N D I N G | D O W N | H E R | B A C K | W E R E | T W O | B R A I D S | O F | V E R Y | T H I C K | D E C I D E D L Y | R E D | H A I R |
+H E R | F A C E | W A S | S M A L L | W H I T E | A N D | T H I N | A L S O | M U C H | F R E C K L E D | H E R | M O U T H | W A S | L A R G E | A N D | S O | W E R E | H E R | E Y E S | W H I C H | L O O K E D | G R E E N | I N | S O M E | L I G H T S | A N D | M O O D S | A N D | G R A Y | I N | O T H E R S | S O | F A R | T H E | O R D I N A R Y | O B S E R V E R | A N | E X T R A O R D I N A R Y | O B S E R V E R |
+M I G H T | H A V E | S E E N | T H A T | T H E | C H I N | W A S | V E R Y | P O I N T E D | A N D | P R O N O U N C E D | T H A T | T H E | B I G | E Y E S | W E R E | F U L L | O F | S P I R I T | A N D | V I V A C I T Y | T H A T | T H E | M O U T H | W A S | S W E E T | L I P P E D | A N D | E X P R E S S I V E | T H A T | T H E | F O R E H E A D | W A S | B R O A D | A N D | F U L L | I N | S H O R T | O U R | D I S C E R N I N G | E X T R A O R D I N A R Y | O B S E R V E R | M I G H T | H A V E | C O N C L U D E D |
+W A S | S O | L U D I C R O U S L Y | A F R A I D | M A T T H E W | H O W E V E R | W A S | S P A R E D | T H E | O R D E A L | O F | S P E A K I N G | F I R S T | F O R | A S | S O O N | A S | S H E | C O N C L U D E D | T H A T | H E | W A S | C O M I N G | T O | H E R | S H E | S T O O D | U P | G R A S P I N G | W I T H | O N E | T H I N | B R O W N | H A N D | T H E | H A N D L E | O F | A | S H A B B Y | O L D | F A S H I O N E D | C A R P E T | B A G | T H E | O T H E R | S H E | H E L D | O U T | T O | H I M |
+I | S U P P O S E | Y O U | A R E | M I S T E R | M A T T H E W | C U T H B E R T | O F | G R E E N | G A B L E S | S H E | S A I D | I N | A | P E C U L I A R L Y | C L E A R | S W E E T | V O I C E | I ' M | V E R Y | G L A D | T O | S E E | Y O U | I | W A S | B E G I N N I N G | T O | B E | A F R A I D | Y O U | W E R E N ' T | C O M I N G | F O R | M E |
+I | H A D | M A D E | U P | M Y | M I N D | T H A T | I F | Y O U | D I D N ' T | C O M E | F O R | M E | T O | N I G H T |
+I | W O U L D N ' T | B E | A | B I T | A F R A I D | A N D | I T | W O U L D | B E | L O V E L Y | T O | S L E E P | I N | A | W I L D | C H E R R Y | T R E E | A L L | W H I T E | W I T H | B L O O M | I N | T H E | M O O N S H I N E | D O N ' T | Y O U | T H I N K | Y O U | C O U L D | I M A G I N E | Y O U | W E R E | D W E L L I N G | I N | M A R B L E | H A L L S | C O U L D N ' T | Y O U |
+M A T T H E W | H A D | T A K E N | T H E | S C R A W N Y | L I T T L E | H A N D | A W K W A R D L Y | I N | H I S | T H E N | A N D | T H E R E | H E | D E C I D E D | W H A T | T O | D O | H E | C O U L D | N O T | T E L L | T H I S | C H I L D | W I T H | T H E | G L O W I N G | E Y E S | T H A T | T H E R E | H A D | B E E N | A | M I S T A K E | H E | W O U L D | T A K E | H E R | H O M E | A N D | L E T | M A R I L L A | D O | T H A T | S H E | C O U L D N ' T | B E | L E F T | A T | B R I G H T | R I V E R | A N Y H O W |
+N O | M A T T E R | W H A T | M I S T A K E | H A D | B E E N | M A D E | S O | A L L | Q U E S T I O N S | A N D | E X P L A N A T I O N S | M I G H T | A S | W E L L | B E | D E F E R R E D | U N T I L | H E | W A S | S A F E L Y | B A C K | A T | G R E E N | G A B L E S | I ' M | S O R R Y | I | W A S | L A T E | H E | S A I D | S H Y L Y | C O M E | A L O N G | T H E | H O R S E | I S | O V E R | I N | T H E | Y A R D | G I V E | M E | Y O U R | B A G | O H | I | C A N | C A R R Y | I T | T H E | C H I L D | R E S P O N D E D | C H E E R F U L L Y |
+I T | I S N ' T | H E A V Y | I ' V E | G O T | A L L | M Y | W O R L D L Y | G O O D S | I N | I T | B U T | I T | I S N ' T | H E A V Y | A N D | I F | I T | I S N ' T | C A R R I E D | I N | J U S T | A | C E R T A I N | W A Y | T H E | H A N D L E | P U L L S | O U T | S O | I ' D | B E T T E R | K E E P | I T | B E C A U S E | I | K N O W | T H E | E X A C T | K N A C K | O F | I T | I T ' S | A N | E X T R E M E L Y | O L D | C A R P E T | B A G | O H | I ' M | V E R Y | G L A D | Y O U ' V E | C O M E | E V E N | I F | I T | W O U L D | H A V E | B E E N | N I C E | T O | S L E E P | I N | A | W I L D | C H E R R Y | T R E E |
+W E ' V E | G O T | T O | D R I V E | A | L O N G | P I E C E | H A V E N ' T | W E | M I S S U S | S P E N C E R | S A I D | I T | W A S | E I G H T | M I L E S | I ' M | G L A D | B E C A U S E | I | L O V E | D R I V I N G | O H | I T | S E E M S | S O | W O N D E R F U L | T H A T | I ' M | G O I N G | T O | L I V E | W I T H | Y O U | A N D | B E L O N G | T O | Y O U | I ' V E | N E V E R | B E L O N G E D | T O | A N Y B O D Y | N O T | R E A L L Y | B U T | T H E | A S Y L U M | W A S | T H E | W O R S T | I ' V E | O N L Y | B E E N | I N | I T | F O U R | M O N T H S | B U T | T H A T | W A S | E N O U G H |
+I T ' S | W O R S E | T H A N | A N Y T H I N G | Y O U | C O U L D | I M A G I N E | M I S S U S | S P E N C E R | S A I D | I T | W A S | W I C K E D | O F | M E | T O | T A L K | L I K E | T H A T |
+T H E Y | W E R E | G O O D | Y O U | K N O W | T H E | A S Y L U M | P E O P L E | B U T | T H E R E | I S | S O | L I T T L E | S C O P E | F O R | T H E | I M A G I N A T I O N | I N | A N | A S Y L U M | O N L Y | J U S T | I N | T H E | O T H E R | O R P H A N S | I T | W A S | P R E T T Y | I N T E R E S T I N G | T O | I M A G I N E | T H I N G S | A B O U T | T H E M |
+W H O | H A D | B E E N | S T O L E N | A W A Y | F R O M | H E R | P A R E N T S | I N | H E R | I N F A N C Y | B Y | A | C R U E L | N U R S E | W H O | D I E D | B E F O R E | S H E | C O U L D | C O N F E S S | I | U S E D | T O | L I E | A W A K E | A T | N I G H T S | A N D | I M A G I N E | T H I N G S | L I K E | T H A T | B E C A U S E | I | D I D N ' T | H A V E | T I M E | I N | T H E | D A Y | I | G U E S S | T H A T ' S | W H Y | I ' M | S O | T H I N | I | A M | D R E A D F U L | T H I N | A I N ' T | I | T H E R E | I S N ' T | A | P I C K | O N | M Y | B O N E S |
+I | D O | L O V E | T O | I M A G I N E | I ' M | N I C E | A N D | P L U M P | W I T H | D I M P L E S | I N | M Y | E L B O W S | W I T H | T H I S | M A T T H E W ' S | C O M P A N I O N | S T O P P E D | T A L K I N G | P A R T L Y | B E C A U S E | S H E | W A S | O U T | O F | B R E A T H | A N D | P A R T L Y | B E C A U S E | T H E Y | H A D | R E A C H E D | T H E | B U G G Y | N O T | A N O T H E R | W O R D | D I D | S H E | S A Y | U N T I L | T H E Y | H A D | L E F T | T H E | V I L L A G E | A N D | W E R E | D R I V I N G | D O W N | A | S T E E P | L I T T L E | H I L L |
+T H E | R O A D | P A R T | O F | W H I C H | H A D | B E E N | C U T | S O | D E E P L Y | I N T O | T H E | S O F T | S O I L | T H A T | T H E | B A N K S | F R I N G E D | W I T H | B L O O M I N G | W I L D | C H E R R Y | T R E E S | A N D | S L I M | W H I T E | B I R C H E S | W E R E | S E V E R A L | F E E T | A B O V E | T H E I R | H E A D S | T H E | C H I L D | P U T | O U T | H E R | H A N D | A N D | B R O K E | O F F | A | B R A N C H | O F | W I L D | P L U M | T H A T | B R U S H E D | A G A I N S T | T H E | S I D E | O F | T H E | B U G G Y |
+I S N ' T | T H A T | B E A U T I F U L | W H A T | D I D | T H A T | T R E E | L E A N I N G | O U T | F R O M | T H E | B A N K | A L L | W H I T E | A N D | L A C Y | M A K E | Y O U | T H I N K | O F | S H E | A S K E D | W E L L | N O W | I | D U N N O | S A I D | M A T T H E W | W H Y | A | B R I D E | O F | C O U R S E | A | B R I D E | A L L | I N | W H I T E | W I T H | A | L O V E L Y | M I S T Y | V E I L |
+I ' V E | N E V E R | S E E N | O N E | B U T | I | C A N | I M A G I N E | W H A T | S H E | W O U L D | L O O K | L I K E | I | D O N ' T | E V E R | E X P E C T | T O | B E | A | B R I D E | M Y S E L F | I ' M | S O | H O M E L Y | N O B O D Y | W I L L | E V E R | W A N T | T O | M A R R Y | M E | U N L E S S | I T | M I G H T | B E | A | F O R E I G N | M I S S I O N A R Y | I | S U P P O S E | A | F O R E I G N | M I S S I O N A R Y | M I G H T N ' T | B E | V E R Y | P A R T I C U L A R |
+B U T | I | D O | H O P E | T H A T | S O M E | D A Y | I | S H A L L | H A V E | A | W H I T E | D R E S S | T H A T | I S | M Y | H I G H E S T | I D E A L | O F | E A R T H L Y | B L I S S | I | J U S T | L O V E | P R E T T Y | C L O T H E S | A N D | I ' V E | N E V E R | H A D | A | P R E T T Y | D R E S S | I N | M Y | L I F E | T H A T | I | C A N | R E M E M B E R | B U T | O F | C O U R S E | I T ' S | A L L | T H E | M O R E | T O | L O O K | F O R W A R D | T O | I S N ' T | I T | A N D | T H E N |
+I | C A N | I M A G I N E | T H A T | I ' M | D R E S S E D | G O R G E O U S L Y | T H I S | M O R N I N G | W H E N | I | L E F T | T H E | A S Y L U M | I | F E L T | S O | A S H A M E D | B E C A U S E | I | H A D | T O | W E A R | T H I S | H O R R I D | O L D | W I N C E Y | D R E S S | A L L | T H E | O R P H A N S | H A D | T O | W E A R | T H E M | Y O U | K N O W | A | M E R C H A N T | I N | H O P E T O N | L A S T | W I N T E R | D O N A T E D | T H R E E | H U N D R E D | Y A R D S | O F | W I N C E Y | T O | T H E | A S Y L U M | S O M E | P E O P L E | S A I D | I T | W A S | B E C A U S E | H E | C O U L D N ' T | S E L L | I T |
+B U T | I ' D | R A T H E R | B E L I E V E | T H A T | I T | W A S | O U T | O F | T H E | K I N D N E S S | O F | H I S | H E A R T | W O U L D N ' T | Y O U | W H E N | W E | G O T | O N | T H E | T R A I N | I | F E L T | A S | I F | E V E R Y B O D Y | M U S T | B E | L O O K I N G | A T | M E | A N D | P I T Y I N G | M E | B U T | I | J U S T | W E N T | T O | W O R K | A N D | I M A G I N E D | T H A T | I | H A D | O N | T H E | M O S T | B E A U T I F U L | P A L E | B L U E | S I L K | D R E S S | B E C A U S E | W H E N | Y O U | A R E | I M A G I N I N G | Y O U | M I G H T | A S | W E L L | I M A G I N E | S O M E T H I N G | W O R T H | W H I L E |
+A N D | A | B I G | H A T | A L L | F L O W E R S | A N D | N O D D I N G | P L U M E S | A N D | A | G O L D | W A T C H | A N D | K I D | G L O V E S | A N D | B O O T S | I | F E L T | C H E E R E D | U P | R I G H T | A W A Y | A N D | I | E N J O Y E D | M Y | T R I P | T O | T H E | I S L A N D | W I T H | A L L | M Y | M I G H T | I | W A S N ' T | A | B I T | S I C K | C O M I N G | O V E R | I N | T H E | B O A T | N E I T H E R | W A S | M I S S U S | S P E N C E R | A L T H O U G H | S H E | G E N E R A L L Y | I S |
+S H E | S A I D | S H E | H A D N ' T | T I M E | T O | G E T | S I C K | W A T C H I N G | T O | S E E | T H A T | I | D I D N ' T | F A L L | O V E R B O A R D | S H E | S A I D | S H E | N E V E R | S A W | T H E | B E A T | O F | M E | F O R | P R O W L I N G | A B O U T | B U T | I F | I T | K E P T | H E R | F R O M | B E I N G | S E A S I C K | I T ' S | A | M E R C Y | I | D I D | P R O W L | I S N ' T | I T | A N D | I | W A N T E D | T O | S E E | E V E R Y T H I N G | T H A T | W A S | T O | B E | S E E N | O N | T H A T | B O A T | B E C A U S E | I | D I D N ' T | K N O W | W H E T H E R | I ' D | E V E R | H A V E | A N O T H E R | O P P O R T U N I T Y |
+O H | T H E R E | A R E | A | L O T | M O R E | C H E R R Y | T R E E S | A L L | I N | B L O O M | T H I S | I S L A N D | I S | T H E | B L O O M I E S T | P L A C E | I | J U S T | L O V E | I T | A L R E A D Y | A N D | I ' M | S O | G L A D | I ' M | G O I N G | T O | L I V E | H E R E | I ' V E | A L W A Y S | H E A R D | T H A T | P R I N C E | E D W A R D | I S L A N D | W A S | T H E | P R E T T I E S T | P L A C E | I N | T H E | W O R L D |
+A N D | I | U S E D | T O | I M A G I N E | I | W A S | L I V I N G | H E R E | B U T | I | N E V E R | R E A L L Y | E X P E C T E D | I | W O U L D | I T ' S | D E L I G H T F U L | W H E N | Y O U R | I M A G I N A T I O N S | C O M E | T R U E | I S N ' T | I T | B U T | T H O S E | R E D | R O A D S | A R E | S O | F U N N Y | W H E N | W E | G O T | I N T O | T H E | T R A I N | A T | C H A R L O T T E T O W N | A N D | T H E | R E D | R O A D S | B E G A N | T O | F L A S H | P A S T | I | A S K E D | M I S S U S | S P E N C E R | W H A T | M A D E | T H E M | R E D |
+A N D | S H E | S A I D | S H E | D I D N ' T | K N O W | A N D | F O R | P I T Y ' S | S A K E | N O T | T O | A S K | H E R | A N Y | M O R E | Q U E S T I O N S | S H E | S A I D | I | M U S T | H A V E | A S K E D | H E R | A | T H O U S A N D | A L R E A D Y | I | S U P P O S E | I | H A D | T O O | B U T | H O W | Y O U | G O I N G | T O | F I N D | O U T | A B O U T | T H I N G S | I F | Y O U | D O N ' T | A S K | Q U E S T I O N S | A N D | W H A T | D O E S | M A K E | T H E | R O A D S | R E D | W E L L | N O W | I | D U N N O | S A I D | M A T T H E W |
+T H E R E ' D | B E | N O | S C O P E | F O R | I M A G I N A T I O N | T H E N | W O U L D | T H E R E | B U T | A M | I | T A L K I N G | T O O | M U C H | P E O P L E | A R E | A L W A Y S | T E L L I N G | M E | I | D O | W O U L D | Y O U | R A T H E R | I | D I D N ' T | T A L K | I F | Y O U | S A Y | S O | I ' L L | S T O P | I | C A N | S T O P | W H E N | I | M A K E | U P | M Y | M I N D | T O | I T | A L T H O U G H | I T ' S | D I F F I C U L T | M A T T H E W |
+W A S | E N J O Y I N G | H I M S E L F | L I K E | M O S T | Q U I E T | F O L K S | H E | L I K E D | T A L K A T I V E | P E O P L E | W H E N | T H E Y | W E R E | W I L L I N G | T O | D O | T H E | T A L K I N G | T H E M S E L V E S | A N D | D I D | N O T | E X P E C T | H I M | T O | K E E P | U P | H I S | E N D | O F | I T | B U T | H E | H A D | N E V E R | E X P E C T E D | T O | E N J O Y | T H E | S O C I E T Y | O F | A | L I T T L E | G I R L | W O M E N | W E R E | B A D | E N O U G H | I N | A L L | C O N S C I E N C E | B U T | L I T T L E | G I R L S | W E R E | W O R S E |
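Each line of `train_sample100.ltr` is one transcript spelled out character by character, with `|` standing in for the word boundary and a trailing `|` closing the utterance; the N-th line pairs by position with the N-th audio entry of the companion `.tsv` manifest added below. A throwaway sketch of the plain-text-to-`.ltr` conversion, mirroring the labelling convention used by the fairseq wav2vec 2.0 recipes:

```python
# Sketch only: turn a plain transcript into the letter (.ltr) target format
# used above -- spaces become "|" word boundaries, every character is a token,
# and a trailing "|" closes the utterance.
def to_ltr(transcript: str) -> str:
    return " ".join(transcript.strip().upper().replace(" ", "|")) + " |"

print(to_ltr("chapter one missus rachel lynde is surprised"))
# C H A P T E R | O N E | M I S S U S | R A C H E L | L Y N D E | I S | S U R P R I S E D |
```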
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/train_sample100.tsv b/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/train_sample100.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..8b50a1d2f1e06553881ec3352bee2e6360814635
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/asr/train_sample100.tsv
@@ -0,0 +1,101 @@
+/LocalData/dataset/LibriSpeech/train-clean-100
+103/1240/103-1240-0000.flac 225360
+103/1240/103-1240-0001.flac 255120
+103/1240/103-1240-0002.flac 223120
+103/1240/103-1240-0003.flac 235360
+103/1240/103-1240-0004.flac 200240
+103/1240/103-1240-0005.flac 242800
+103/1240/103-1240-0006.flac 153280
+103/1240/103-1240-0007.flac 240560
+103/1240/103-1240-0008.flac 246960
+103/1240/103-1240-0009.flac 160480
+103/1240/103-1240-0010.flac 236880
+103/1240/103-1240-0011.flac 234480
+103/1240/103-1240-0012.flac 243040
+103/1240/103-1240-0013.flac 244160
+103/1240/103-1240-0014.flac 223360
+103/1240/103-1240-0015.flac 60960
+103/1240/103-1240-0016.flac 250640
+103/1240/103-1240-0017.flac 229040
+103/1240/103-1240-0018.flac 185760
+103/1240/103-1240-0019.flac 246480
+103/1240/103-1240-0020.flac 214640
+103/1240/103-1240-0021.flac 236960
+103/1240/103-1240-0022.flac 262000
+103/1240/103-1240-0023.flac 194400
+103/1240/103-1240-0024.flac 244320
+103/1240/103-1240-0025.flac 241920
+103/1240/103-1240-0026.flac 133360
+103/1240/103-1240-0027.flac 223440
+103/1240/103-1240-0028.flac 250400
+103/1240/103-1240-0029.flac 244320
+103/1240/103-1240-0030.flac 232320
+103/1240/103-1240-0031.flac 269760
+103/1240/103-1240-0032.flac 236400
+103/1240/103-1240-0033.flac 230640
+103/1240/103-1240-0034.flac 246480
+103/1240/103-1240-0035.flac 256720
+103/1240/103-1240-0036.flac 200320
+103/1240/103-1240-0037.flac 237040
+103/1240/103-1240-0038.flac 114480
+103/1240/103-1240-0039.flac 230800
+103/1240/103-1240-0040.flac 234720
+103/1240/103-1240-0041.flac 216160
+103/1240/103-1240-0042.flac 249680
+103/1240/103-1240-0043.flac 236160
+103/1240/103-1240-0044.flac 262240
+103/1240/103-1240-0045.flac 250800
+103/1240/103-1240-0046.flac 222800
+103/1240/103-1240-0047.flac 206320
+103/1240/103-1240-0048.flac 236320
+103/1240/103-1240-0049.flac 244560
+103/1240/103-1240-0050.flac 224400
+103/1240/103-1240-0051.flac 245760
+103/1240/103-1240-0052.flac 236640
+103/1240/103-1240-0053.flac 218640
+103/1240/103-1240-0054.flac 261360
+103/1240/103-1240-0055.flac 179920
+103/1240/103-1240-0056.flac 229040
+103/1240/103-1240-0057.flac 109680
+103/1241/103-1241-0000.flac 255440
+103/1241/103-1241-0001.flac 248800
+103/1241/103-1241-0002.flac 249040
+103/1241/103-1241-0003.flac 222160
+103/1241/103-1241-0004.flac 236080
+103/1241/103-1241-0005.flac 224400
+103/1241/103-1241-0006.flac 243760
+103/1241/103-1241-0007.flac 242320
+103/1241/103-1241-0008.flac 242160
+103/1241/103-1241-0009.flac 222400
+103/1241/103-1241-0010.flac 253920
+103/1241/103-1241-0011.flac 231760
+103/1241/103-1241-0012.flac 239680
+103/1241/103-1241-0013.flac 236960
+103/1241/103-1241-0014.flac 242080
+103/1241/103-1241-0015.flac 224160
+103/1241/103-1241-0016.flac 234640
+103/1241/103-1241-0017.flac 254240
+103/1241/103-1241-0018.flac 150960
+103/1241/103-1241-0019.flac 48400
+103/1241/103-1241-0020.flac 155360
+103/1241/103-1241-0021.flac 242880
+103/1241/103-1241-0022.flac 261600
+103/1241/103-1241-0023.flac 266720
+103/1241/103-1241-0024.flac 254240
+103/1241/103-1241-0025.flac 77280
+103/1241/103-1241-0026.flac 176080
+103/1241/103-1241-0027.flac 238080
+103/1241/103-1241-0028.flac 248880
+103/1241/103-1241-0029.flac 244960
+103/1241/103-1241-0030.flac 247520
+103/1241/103-1241-0031.flac 209600
+103/1241/103-1241-0032.flac 224080
+103/1241/103-1241-0033.flac 251920
+103/1241/103-1241-0034.flac 270560
+103/1241/103-1241-0035.flac 248800
+103/1241/103-1241-0036.flac 249040
+103/1241/103-1241-0037.flac 204400
+103/1241/103-1241-0038.flac 238960
+103/1241/103-1241-0039.flac 258160
+103/1241/103-1241-0040.flac 220560
+103/1241/103-1241-0041.flac 252240
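`train_sample100.tsv` is a standard audio manifest: the first line gives the audio root directory, and every following line gives a FLAC path relative to that root plus its length in samples (at LibriSpeech's 16 kHz, the 225360-sample first entry is roughly 14.1 s). Audio and transcripts are paired purely by line order with the `.ltr` file above. A hedged sketch of walking the two files together (the file names refer to the files added in this diff and would need to exist locally):

```python
# Illustrative only: iterate over (audio_path, num_samples, transcript) by
# zipping the .tsv manifest with the parallel .ltr label file.
import os

def load_manifest(tsv_path, ltr_path):
    with open(tsv_path) as tsv, open(ltr_path) as ltr:
        root = tsv.readline().strip()                 # first line = audio root dir
        for audio_line, label_line in zip(tsv, ltr):
            rel_path, n_samples = audio_line.split()  # "<relative path>\t<samples>"
            yield os.path.join(root, rel_path), int(n_samples), label_line.strip()

# for path, n_samples, text in load_manifest("train_sample100.tsv",
#                                            "train_sample100.ltr"):
#     print(path, f"{n_samples / 16000:.1f}s", text[:40])
```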
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/config.yaml b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eaec2ce8655ebfa043cf73a2ee2d85ac5bcdfb21
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/config.yaml
@@ -0,0 +1,13 @@
+audio_root: /home/v-ziqzhang/dataset/librispeech_phone2unit
+features:
+ energy_max: 5.733445167541504
+ energy_min: 1.0e-08
+ eps: 1.0e-05
+ hop_length: 256
+ pitch_max: 6.608609099713706
+ pitch_min: 1.0e-08
+ sample_rate: 16000
+sample_rate: 16000
+vocab_filename: dict.km.txt
+src_vocab_filename: dict.phn.txt
+
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/config_generate.yaml b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/config_generate.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d9fa74529728fe81f41edd55689f43f6ae2da83
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/config_generate.yaml
@@ -0,0 +1,13 @@
+audio_root: /home/v-ziqzhang/dataset/librispeech_phone2unit
+features:
+ energy_max: 5.733445167541504
+ energy_min: 1.0e-08
+ eps: 1.0e-05
+ hop_length: 256
+ pitch_max: 6.608609099713706
+ pitch_min: 1.0e-08
+ sample_rate: 16000
+sample_rate: 16000
+vocab_filename: dict.km.txt
+src_vocab_filename: dict.PHN.txt
+
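These two YAML files are the data configs for the `fast_phone2unit` (phone-to-unit) setup; they are identical apart from the source-vocabulary filename (`dict.phn.txt` vs. `dict.PHN.txt`), and both point at the unit dictionary `dict.km.txt`, pitch/energy normalisation ranges, and a 256-sample hop at 16 kHz. A quick look at how such a config might be consumed, using plain PyYAML purely as a sketch (the task code presumably resolves these filenames relative to the data directory):

```python
# Sketch: inspect the fast_phone2unit data config with plain PyYAML.
# Nothing here is SpeechLM API; it only shows what the fields above hold.
import yaml  # pip install pyyaml

with open("config_generate.yaml") as f:
    cfg = yaml.safe_load(f)

hop_seconds = cfg["features"]["hop_length"] / cfg["features"]["sample_rate"]
print("unit vocab :", cfg["vocab_filename"])        # dict.km.txt
print("phone vocab:", cfg["src_vocab_filename"])    # dict.PHN.txt
print(f"frame hop  : {hop_seconds * 1000:.0f} ms")  # 256 / 16000 = 16 ms
```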
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/dict.PHN.txt b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/dict.PHN.txt
new file mode 100644
index 0000000000000000000000000000000000000000..60232ecf55c10e9ab673168262af28951ecbec2f
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/dict.PHN.txt
@@ -0,0 +1,42 @@
+| 0
+ 1
+' 2
+AA 3
+AE 4
+AH 5
+AO 6
+AW 7
+AY 8
+B 9
+CH 10
+D 11
+DH 12
+EH 13
+ER 14
+EY 15
+F 16
+G 17
+HH 18
+IH 19
+IY 20
+JH 21
+K 22
+L 23
+M 24
+N 25
+NG 26
+OW 27
+OY 28
+P 29
+R 30
+S 31
+SH 32
+T 33
+TH 34
+UH 35
+UW 36
+V 37
+W 38
+Y 39
+Z 40
+ZH 41
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/dict.km.txt b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/dict.km.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bbfe59e554d6234f3631d8d09d9281c2160f4675
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/dict.km.txt
@@ -0,0 +1,500 @@
+0 0
+1 1
+2 2
+3 3
+4 4
+5 5
+6 6
+7 7
+8 8
+9 9
+10 10
+11 11
+12 12
+13 13
+14 14
+15 15
+16 16
+17 17
+18 18
+19 19
+20 20
+21 21
+22 22
+23 23
+24 24
+25 25
+26 26
+27 27
+28 28
+29 29
+30 30
+31 31
+32 32
+33 33
+34 34
+35 35
+36 36
+37 37
+38 38
+39 39
+40 40
+41 41
+42 42
+43 43
+44 44
+45 45
+46 46
+47 47
+48 48
+49 49
+50 50
+51 51
+52 52
+53 53
+54 54
+55 55
+56 56
+57 57
+58 58
+59 59
+60 60
+61 61
+62 62
+63 63
+64 64
+65 65
+66 66
+67 67
+68 68
+69 69
+70 70
+71 71
+72 72
+73 73
+74 74
+75 75
+76 76
+77 77
+78 78
+79 79
+80 80
+81 81
+82 82
+83 83
+84 84
+85 85
+86 86
+87 87
+88 88
+89 89
+90 90
+91 91
+92 92
+93 93
+94 94
+95 95
+96 96
+97 97
+98 98
+99 99
+100 100
+101 101
+102 102
+103 103
+104 104
+105 105
+106 106
+107 107
+108 108
+109 109
+110 110
+111 111
+112 112
+113 113
+114 114
+115 115
+116 116
+117 117
+118 118
+119 119
+120 120
+121 121
+122 122
+123 123
+124 124
+125 125
+126 126
+127 127
+128 128
+129 129
+130 130
+131 131
+132 132
+133 133
+134 134
+135 135
+136 136
+137 137
+138 138
+139 139
+140 140
+141 141
+142 142
+143 143
+144 144
+145 145
+146 146
+147 147
+148 148
+149 149
+150 150
+151 151
+152 152
+153 153
+154 154
+155 155
+156 156
+157 157
+158 158
+159 159
+160 160
+161 161
+162 162
+163 163
+164 164
+165 165
+166 166
+167 167
+168 168
+169 169
+170 170
+171 171
+172 172
+173 173
+174 174
+175 175
+176 176
+177 177
+178 178
+179 179
+180 180
+181 181
+182 182
+183 183
+184 184
+185 185
+186 186
+187 187
+188 188
+189 189
+190 190
+191 191
+192 192
+193 193
+194 194
+195 195
+196 196
+197 197
+198 198
+199 199
+200 200
+201 201
+202 202
+203 203
+204 204
+205 205
+206 206
+207 207
+208 208
+209 209
+210 210
+211 211
+212 212
+213 213
+214 214
+215 215
+216 216
+217 217
+218 218
+219 219
+220 220
+221 221
+222 222
+223 223
+224 224
+225 225
+226 226
+227 227
+228 228
+229 229
+230 230
+231 231
+232 232
+233 233
+234 234
+235 235
+236 236
+237 237
+238 238
+239 239
+240 240
+241 241
+242 242
+243 243
+244 244
+245 245
+246 246
+247 247
+248 248
+249 249
+250 250
+251 251
+252 252
+253 253
+254 254
+255 255
+256 256
+257 257
+258 258
+259 259
+260 260
+261 261
+262 262
+263 263
+264 264
+265 265
+266 266
+267 267
+268 268
+269 269
+270 270
+271 271
+272 272
+273 273
+274 274
+275 275
+276 276
+277 277
+278 278
+279 279
+280 280
+281 281
+282 282
+283 283
+284 284
+285 285
+286 286
+287 287
+288 288
+289 289
+290 290
+291 291
+292 292
+293 293
+294 294
+295 295
+296 296
+297 297
+298 298
+299 299
+300 300
+301 301
+302 302
+303 303
+304 304
+305 305
+306 306
+307 307
+308 308
+309 309
+310 310
+311 311
+312 312
+313 313
+314 314
+315 315
+316 316
+317 317
+318 318
+319 319
+320 320
+321 321
+322 322
+323 323
+324 324
+325 325
+326 326
+327 327
+328 328
+329 329
+330 330
+331 331
+332 332
+333 333
+334 334
+335 335
+336 336
+337 337
+338 338
+339 339
+340 340
+341 341
+342 342
+343 343
+344 344
+345 345
+346 346
+347 347
+348 348
+349 349
+350 350
+351 351
+352 352
+353 353
+354 354
+355 355
+356 356
+357 357
+358 358
+359 359
+360 360
+361 361
+362 362
+363 363
+364 364
+365 365
+366 366
+367 367
+368 368
+369 369
+370 370
+371 371
+372 372
+373 373
+374 374
+375 375
+376 376
+377 377
+378 378
+379 379
+380 380
+381 381
+382 382
+383 383
+384 384
+385 385
+386 386
+387 387
+388 388
+389 389
+390 390
+391 391
+392 392
+393 393
+394 394
+395 395
+396 396
+397 397
+398 398
+399 399
+400 400
+401 401
+402 402
+403 403
+404 404
+405 405
+406 406
+407 407
+408 408
+409 409
+410 410
+411 411
+412 412
+413 413
+414 414
+415 415
+416 416
+417 417
+418 418
+419 419
+420 420
+421 421
+422 422
+423 423
+424 424
+425 425
+426 426
+427 427
+428 428
+429 429
+430 430
+431 431
+432 432
+433 433
+434 434
+435 435
+436 436
+437 437
+438 438
+439 439
+440 440
+441 441
+442 442
+443 443
+444 444
+445 445
+446 446
+447 447
+448 448
+449 449
+450 450
+451 451
+452 452
+453 453
+454 454
+455 455
+456 456
+457 457
+458 458
+459 459
+460 460
+461 461
+462 462
+463 463
+464 464
+465 465
+466 466
+467 467
+468 468
+469 469
+470 470
+471 471
+472 472
+473 473
+474 474
+475 475
+476 476
+477 477
+478 478
+479 479
+480 480
+481 481
+482 482
+483 483
+484 484
+485 485
+486 486
+487 487
+488 488
+489 489
+490 490
+491 491
+492 492
+493 493
+494 494
+495 495
+496 496
+497 497
+498 498
+499 499
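dict.km.txt above enumerates the 500 hidden-unit labels (presumably K-means cluster IDs, given the filename) in the same symbol/index layout as the phoneme dictionary. The genset_examples.tsv manifest added next is tab-separated with columns id, speaker, n_frames, tgt_text (the phoneme sequence) and unit. A sketch of reading it, assuming the illustrative path below and Python's standard csv module:

```python
# Sketch: iterate over the generation manifest added below. Column names come
# straight from the TSV header; the path is illustrative.
import csv

path = "dataset/LibriSpeech/fast_phone2unit/genset_examples.tsv"
with open(path, encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t")
    for row in reader:
        phones = row["tgt_text"].split()  # e.g. ['AH', 'B', 'EH', 'T', 'R', ...]
        print(row["id"], row["n_frames"], len(phones), row["unit"])
```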
diff --git a/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/genset_examples.tsv b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/genset_examples.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..fe4a9a1b21a77835afaacc936f59963e8ed0090c
--- /dev/null
+++ b/SpeechT5/SpeechLM/dataset/LibriSpeech/fast_phone2unit/genset_examples.tsv
@@ -0,0 +1,101 @@
+id speaker n_frames tgt_text unit
+librilm-9899 librilm 323 AH B EH T R OW TH AH L AH W EH D IH NG AO R AH K R IH S AH N IH NG AO R DH AH M IH R P R AA K S IH M AH T IY AH V AH G IH T AA R IH Z S AH F IH SH AH N T AH K EY ZH AH N AH N D IH F DH AH AH K EY ZH AH N L AE K S S EH N D F AO R DH AH G IH T AA R AH N D D AE N S EH N IY W EY 0
+librilm-9900 librilm 449 AH B EH T R OW TH AH L B R OW K AH N AO F W AA Z AH TH IH NG DH AE T HH AE D HH AE P AH N D B IH F AO R AH N D M AY T HH AE P AH N AH G EH N IH T W AA Z AH T R AY F AH L IY V IH N AH M IH R N AH TH IH NG IH F OW N L IY HH AH N ER W ER AH N T AH CH T B AY IH T IH F OW N L IY AA T AH M AA R K UH D S T EY K HH IH Z L AY F AH P AA N HH IH Z AH N IH M P IY CH AH B AH L HH AH N ER 0
+librilm-9901 librilm 211 AH B EH T R OW TH AH L S EH R AH M OW N IY AH K AO R D IH NG L IY IH Z DH AH IH M IY D IY AH T AH K EY ZH AH N AH V DH AH K AH M IH NG T AH G EH DH ER AH V AW ER AH K W EY N T AH N S IH Z 0
+librilm-9902 librilm 45 AH B EH T R OW TH AH L D EY 0
+librilm-9903 librilm 141 AH B EH T R OW TH AH L HH IY R IH Z AO L M OW S T AE Z B AY N D IH NG AH N D K W AY T AE Z S AA L AH M AE Z AH M EH R IH JH 0
+librilm-9904 librilm 59 AH B EH T R OW TH AH L IH Z S EY K R AH D 0
+librilm-9905 librilm 79 AH B EH T R OW TH AH L IH Z S AH M TH IH NG HH EH V AH N L IY 0
+librilm-9906 librilm 225 AH B EH T R OW TH AH L R IH NG W AA Z P ER CH AH S T AH N D DH EH N HH ER K AA N SH AH N S B IY IH NG AH P IY Z D SH IY G EY V HH ER S EH L F K AH M P L IY T L IY T UW HH ER L AH V ER 0
+librilm-9907 librilm 288 AH B EH T R OW TH AH L T UH K P L EY S AO L W AA Z HH AA R M AH N IY AH N D F AO R AH T AY M N OW M AO R W AA Z S EH D AH V D IH S IH N ER SH IH T IH NG M AE D AH M D IY L AA P EH L T R IY AO R P AH T IH NG HH ER IH N W AO R D SH IH P 0
+librilm-9908 librilm 139 AH B EH T R OW TH AH L W IY HH AE V B IH N T UW AH B AO L AH V W IH CH AY M AH S T G IH V Y UW AH D IH S K R IH P SH AH N 0
+librilm-9909 librilm 491 AH B EH T R OW TH AH L W IH CH HH AE D T EY K AH N P L EY S AA N DH AH P R IY V IY AH S IY V N IH NG G EY V K AA Z F AO R P L EH N T AH F AH L SH R AH G IH NG AH V SH OW L D ER Z B IH K AO Z DH AH JH EH N T AH L M AH N AE Z Y EH T HH EH L D N OW R IH S P EH K T AH B AH L P AH Z IH SH AH N IH N L AY F AH N D DH AH F IY AE N S IY AE Z S EH V R AH L F IY M EY L F R EH N D Z AH S ER T AH D AH V EH R IY AH N S ER T AH N W AH N 0
+librilm-9910 librilm 269 AH B EH T R OW TH AH L W IH TH AW T AH N D IY V IH N AH G EH N S T DH AH K AH N S EH N T AH V P EH R AH N T S W AA Z S AH M TH IH NG K W AY T AW T S AY D AH V DH AH Y AH NG L EY D IY Z P AW ER AH V K AA M P R IY HH EH N SH AH N 0
+librilm-9911 librilm 29 AH B EH T R AH TH 0
+librilm-9912 librilm 127 AH B EH T R AH TH B R AY D AO T T UW B IY HH AE P IY Y UW AA R AO L W EY Z T EH L IH NG M IY S OW 0
+librilm-9913 librilm 108 AH B EH T R AH TH G ER L IH Z W AH N TH IH NG AH W AY F K W AY T AH N AH DH ER 0
+librilm-9914 librilm 168 AH B EH T R AH TH L AH V ER K AE N AA T T AA L ER EY T EH N IY AH S P ER ZH AH N K AE S T AH P AA N DH AH F EH R S EH K S S EH D JH AO R JH AH Z 0
+librilm-9915 librilm 61 AH B EH T R AH TH L AH V ER Z F EH R W EH L 0
+librilm-9916 librilm 335 AH B EH T R AH TH Y AH NG M AE N AO R HH IH Z F IY M EY L R EH L AH T IH V Z AH S IH S T IH NG HH IH M W AA Z AH K AH S T AH M D T UW M EY K AH P R EH Z AH N T AH V W AH N AO R M AO R P EH T IY K OW T S T UW HH IH Z S W IY T HH AA R T T UW IH N K R IY S HH ER W AO R D R OW B 0
+librilm-9917 librilm 24 AH B EH T ER 0
+librilm-9918 librilm 182 AH B EH T ER AH F AY N ER AH N OW B L ER F EH L OW DH AE N HH IY N EH V ER L AY V D AH N D DH AE T S W AH T AY W AA N T Y UW