diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..bf07816c74bac9b682df196e02c6482e474e9b52 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,31 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dea9b08cbd468d29f58606cfa7e7c3e88d3cb696 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +*.pyc +*.aux +*.log +*.out +*.synctex.gz +*.suo +*__pycache__ +*.idea +*.ipynb_checkpoints +*.pickle +*.npy +*.blg +*.bbl +*.bcf +*.toc +*.sh +wavs +log \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..77b56791b7e900eb0a7d8a258bf78ecb449c8468 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,130 @@ +# Contributor Covenant Code of Conduct +## First of all +Don't be evil, never + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +babysor00@gmail.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ed721bf8f29f5c8d947c2d333cc371021135fb0 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,24 @@ +MIT License + +Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ) +Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah) +Original work Copyright (c) 2019 fatchord (https://github.com/fatchord) +Original work Copyright (c) 2015 braindead (https://github.com/braindead) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
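Note on the `.gitattributes` added at the top of this diff: it routes checkpoints and other large binaries (`*.pt`, `*.onnx`, archives, …) through Git LFS, so a clone made without `git lfs pull` holds only small pointer stubs instead of real weights (the `encoder/saved_models/pretrained.pt` entry later in this diff is exactly such a stub). A minimal illustrative check, not part of this patch — the helper name and the scanned glob pattern are assumptions:

```
from pathlib import Path

# First line of every Git LFS pointer stub (see the pretrained.pt entry later in this diff).
LFS_SPEC = "version https://git-lfs.github.com/spec/v1"

def is_lfs_pointer(path: Path) -> bool:
    """Return True if `path` still holds an un-fetched Git LFS pointer instead of real data."""
    try:
        with path.open("rb") as fh:
            head = fh.read(len(LFS_SPEC))
    except OSError:
        return False
    return head.decode("ascii", errors="ignore") == LFS_SPEC

if __name__ == "__main__":
    # Warn about checkpoints that still need `git lfs pull`.
    for ckpt in Path(".").glob("**/saved_models/*.pt"):
        if is_lfs_pointer(ckpt):
            print(f"{ckpt} is an LFS pointer stub; run `git lfs pull` to fetch the weights.")
```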
diff --git a/README-CN.md b/README-CN.md new file mode 100644 index 0000000000000000000000000000000000000000..738b37f21a840026f64fd5bf699b013f459108a4 --- /dev/null +++ b/README-CN.md @@ -0,0 +1,230 @@ +## 实时语音克隆 - 中文/普通话 +![mockingbird](https://user-images.githubusercontent.com/12797292/131216767-6eb251d6-14fc-4951-8324-2722f0cd4c63.jpg) + +[![MIT License](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](http://choosealicense.com/licenses/mit/) + +### [English](README.md) | 中文 + +### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki教程](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [训练教程](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd) + +## 特性 +🌍 **中文** 支持普通话并使用多种中文数据集进行测试:aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell 等 + +🤩 **PyTorch** 适用于 pytorch,已在 1.9.0 版本(最新于 2021 年 8 月)中测试,GPU Tesla T4 和 GTX 2060 + +🌍 **Windows + Linux** 可在 Windows 操作系统和 linux 操作系统中运行(苹果系统M1版也有社区成功运行案例) + +🤩 **Easy & Awesome** 仅需下载或新训练合成器(synthesizer)就有良好效果,复用预训练的编码器/声码器,或实时的HiFi-GAN作为vocoder + +🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用 + +### 进行中的工作 +* GUI/客户端大升级与合并 +[X] 初始化框架 `./mkgui` (基于streamlit + fastapi)和 [技术设计](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee) +[X] 增加 Voice Cloning and Conversion的演示页面 +[X] 增加Voice Conversion的预处理preprocessing 和训练 training 页面 +[ ] 增加其他的的预处理preprocessing 和训练 training 页面 +* 模型后端基于ESPnet2升级 + + +## 开始 +### 1. 安装要求 +> 按照原始存储库测试您是否已准备好所有环境。 +运行工具箱(demo_toolbox.py)需要 **Python 3.7 或更高版本** 。 + +* 安装 [PyTorch](https://pytorch.org/get-started/locally/)。 +> 如果在用 pip 方式安装的时候出现 `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)` 这个错误可能是 python 版本过低,3.9 可以安装成功 +* 安装 [ffmpeg](https://ffmpeg.org/download.html#get-packages)。 +* 运行`pip install -r requirements.txt` 来安装剩余的必要包。 +* 安装 webrtcvad `pip install webrtcvad-wheels`。 + +### 2. 
准备预训练模型
+考虑训练您自己专属的模型或者下载社区他人训练好的模型:
+> 近期创建了[知乎专题](https://www.zhihu.com/column/c_1425605280340504576) 将不定期更新炼丹小技巧or心得,也欢迎提问
+#### 2.1 使用数据集自己训练encoder模型 (可选)
+
+* 进行音频和梅尔频谱图预处理:
+`python encoder_preprocess.py <datasets_root>`
+使用`-d {dataset}` 指定数据集,支持 librispeech_other,voxceleb1,aidatatang_200zh,使用逗号分割处理多数据集。
+* 训练encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
+> 训练encoder使用了visdom。你可以加上`-no_visdom`禁用visdom,但是有可视化会更好。在单独的命令行/进程中运行"visdom"来启动visdom服务器。
+
+#### 2.2 使用数据集自己训练合成器模型(与2.3二选一)
+* 下载数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
+* 进行音频和梅尔频谱图预处理:
+`python pre.py <datasets_root> -d {dataset} -n {number}`
+可传入参数:
+* `-d {dataset}` 指定数据集,支持 aidatatang_200zh, magicdata, aishell3, data_aishell, 不传默认为aidatatang_200zh
+* `-n {number}` 指定并行数,CPU 11770k + 32GB实测10没有问题
+> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`
+
+* 训练合成器:
+`python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
+
+* 当您在训练文件夹 *synthesizer/saved_models/* 中看到注意线显示和损失满足您的需要时,请转到`启动程序`一步。
+
+#### 2.3 使用社区预先训练好的合成器(与2.2二选一)
+> 当实在没有设备或者不想慢慢调试,可以使用社区贡献的模型(欢迎持续分享):
+
+| 作者 | 下载链接 | 效果预览 | 信息 |
+| --- | ----------- | ----- | ----- |
+| 作者 | [百度盘链接](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 提取码:4j5d | | 75k steps 用3个开源数据集混合训练 |
+| 作者 | [百度盘链接](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) 提取码:om7f | | 25k steps 用3个开源数据集混合训练, 切换到tag v0.0.1使用 |
+| @FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [百度盘链接](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) 提取码:1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps 台湾口音需切换到tag v0.0.1使用 |
+| @miven | https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ 提取码:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps 注意:根据[issue](https://github.com/babysor/MockingBird/issues/37)修复 并切换到tag v0.0.1使用 |
+
+#### 2.4 训练声码器 (可选)
+对效果影响不大,已经预置3款,如果希望自己训练可以参考以下命令。
+* 预处理数据:
+`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
+> `<datasets_root>`替换为你的数据集目录,`<synthesizer_model_path>`替换为一个你最好的synthesizer模型目录,例如 *synthesizer\saved_models\xxx*
+
+
+* 训练wavernn声码器:
+`python vocoder_train.py <trainid> <datasets_root>`
+> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
+
+* 训练hifigan声码器:
+`python vocoder_train.py <trainid> <datasets_root> hifigan`
+> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
+* 训练fregan声码器:
+`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
+> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
+* 将GAN声码器的训练切换为多GPU模式:修改GAN文件夹下.json文件中的"num_gpus"参数
+### 3. 启动程序或工具箱
+您可以尝试使用以下命令:
+
+### 3.1 启动Web程序(v2):
+`python web.py`
+运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`
+> * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
+
+### 3.2 启动工具箱:
+`python demo_toolbox.py -d <datasets_root>`
+> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
+
+### 4. 番外:语音转换Voice Conversion(PPG based)
+想像柯南拿着变声器然后发出毛利小五郎的声音吗?本项目现基于PPG-VC,引入额外两个模块(PPG extractor + PPG2Mel), 可以实现变声功能。(文档不全,尤其是训练部分,正在努力补充中)
+#### 4.0 准备环境
+* 确保项目以上环境已经安装ok,运行`pip install espnet` 来安装剩余的必要包。
+* 下载以下模型 链接:https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
+提取码:gh41
+    * 24K采样率专用的vocoder(hifigan)到 *vocoder\saved_models\xxx*
+    * 预训练的ppg特征encoder(ppg_extractor)到 *ppg_extractor\saved_models\xxx*
+    * 预训练的PPG2Mel到 *ppg2mel\saved_models\xxx*
+
+#### 4.1 使用数据集自己训练PPG2Mel模型 (可选)
+
+* 下载aidatatang_200zh数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
+* 进行音频和梅尔频谱图预处理:
+`python pre4ppg.py <datasets_root> -d {dataset} -n {number}`
+可传入参数:
+* `-d {dataset}` 指定数据集,支持 aidatatang_200zh, 不传默认为aidatatang_200zh
+* `-n {number}` 指定并行数,CPU 11770k在8的情况下,需要运行12到18小时!待优化
+> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`
+
+* 训练合成器, 注意在上一步先下载好`ppg2mel.yaml`, 修改里面的地址指向预训练好的文件夹:
+`python ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc`
+* 如果想要继续上一次的训练,可以通过`--load .\ppg2mel\saved_models\` 参数指定一个预训练模型文件。
+
+#### 4.2 启动工具箱VC模式
+您可以尝试使用以下命令:
+`python demo_toolbox.py -vc -d <datasets_root>`
+> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
+
+## 引用及论文
+> 该库一开始从仅支持英语的[Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) 分叉出来的,鸣谢作者。
+
+| URL | Designation | 标题 | 实现源码 |
+| --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer) | Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | 本代码库 |
+| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder) | Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | 本代码库 |
+| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder) | Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | 本代码库 |
+| [**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | 本代码库 |
+| [1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
+| [1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
+| [1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder) | Generalized End-To-End Loss for Speaker Verification | 本代码库 |
+
+## 常見問題(FAQ)
+#### 1.數據集哪裡下載?
+| 数据集 | OpenSLR地址 | 其他源 (Google Drive, Baidu网盘等) |
+| --- | ----------- | ---------------|
+| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
+| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
+| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
+| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
+> 解壓 aidatatang_200zh 後,還需將 `aidatatang_200zh\corpus\train`下的檔案全選解壓縮
+
+#### 2.`<datasets_root>`是什麼意思? 
+假如數據集路徑為 `D:\data\aidatatang_200zh`,那麼 ``就是 `D:\data` + +#### 3.訓練模型顯存不足 +訓練合成器時:將 `synthesizer/hparams.py`中的batch_size參數調小 +``` +//調整前 +tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule + (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size) + (2, 2e-4, 80_000, 12), # + (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames + (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration) + (2, 1e-5, 640_000, 12)], # lr = learning rate +//調整後 +tts_schedule = [(2, 1e-3, 20_000, 8), # Progressive training schedule + (2, 5e-4, 40_000, 8), # (r, lr, step, batch_size) + (2, 2e-4, 80_000, 8), # + (2, 1e-4, 160_000, 8), # r = reduction factor (# of mel frames + (2, 3e-5, 320_000, 8), # synthesized for each decoder iteration) + (2, 1e-5, 640_000, 8)], # lr = learning rate +``` + +聲碼器-預處理數據集時:將 `synthesizer/hparams.py`中的batch_size參數調小 +``` +//調整前 +### Data Preprocessing + max_mel_frames = 900, + rescale = True, + rescaling_max = 0.9, + synthesis_batch_size = 16, # For vocoder preprocessing and inference. +//調整後 +### Data Preprocessing + max_mel_frames = 900, + rescale = True, + rescaling_max = 0.9, + synthesis_batch_size = 8, # For vocoder preprocessing and inference. +``` + +聲碼器-訓練聲碼器時:將 `vocoder/wavernn/hparams.py`中的batch_size參數調小 +``` +//調整前 +# Training +voc_batch_size = 100 +voc_lr = 1e-4 +voc_gen_at_checkpoint = 5 +voc_pad = 2 + +//調整後 +# Training +voc_batch_size = 6 +voc_lr = 1e-4 +voc_gen_at_checkpoint = 5 +voc_pad =2 +``` + +#### 4.碰到`RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).` +請參照 issue [#37](https://github.com/babysor/MockingBird/issues/37) + +#### 5.如何改善CPU、GPU佔用率? +適情況調整batch_size參數來改善 + +#### 6.發生 `頁面文件太小,無法完成操作` +請參考這篇[文章](https://blog.csdn.net/qq_17755303/article/details/112564030),將虛擬內存更改為100G(102400),例如:档案放置D槽就更改D槽的虚拟内存 + +#### 7.什么时候算训练完成? 
+首先一定要出现注意力模型,其次是loss足够低,取决于硬件设备和数据集。拿本人的供参考,我的注意力是在 18k 步之后出现的,并且在 50k 步之后损失变得低于 0.4 +![attention_step_20500_sample_1](https://user-images.githubusercontent.com/7423248/128587252-f669f05a-f411-4811-8784-222156ea5e9d.png) + +![step-135500-mel-spectrogram_sample_1](https://user-images.githubusercontent.com/7423248/128587255-4945faa0-5517-46ea-b173-928eff999330.png) + diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..92735e994857265333218adbbf2d3ea1980baa3d --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +--- +title: MockingBird +emoji: 🔥 +colorFrom: red +colorTo: red +sdk: gradio +sdk_version: 3.1.3 +app_file: app.py +pinned: false +license: mit +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1a41430fac74af8193b0d42728291b319c5011 --- /dev/null +++ b/app.py @@ -0,0 +1,80 @@ + +import gradio as gr + +import re +import random +import string +import librosa +import numpy as np + +from pathlib import Path +from scipy.io.wavfile import write + +from encoder import inference as encoder +from vocoder.hifigan import inference as gan_vocoder +from synthesizer.inference import Synthesizer + +class Mandarin: + def __init__(self): + self.encoder_path = "encoder/saved_models/pretrained.pt" + self.vocoder_path = "vocoder/saved_models/pretrained/g_hifigan.pt" + self.config_fpath = "vocoder/hifigan/config_16k_.json" + self.accent = "synthesizer/saved_models/普通话.pt" + + synthesizers_cache = {} + if synthesizers_cache.get(self.accent) is None: + self.current_synt = Synthesizer(Path(self.accent)) + synthesizers_cache[self.accent] = self.current_synt + else: + self.current_synt = synthesizers_cache[self.accent] + + encoder.load_model(Path(self.encoder_path)) + gan_vocoder.load_model(Path(self.vocoder_path), self.config_fpath) + + def setVoice(self, timbre): + self.timbre = timbre + wav, sample_rate, = librosa.load(self.timbre) + + encoder_wav = encoder.preprocess_wav(wav, sample_rate) + self.embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + def say(self, text): + texts = filter(None, text.split("\n")) + punctuation = "!,。、?!,.?::" # punctuate and split/clean text + processed_texts = [] + for text in texts: + for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'): + if processed_text: + processed_texts.append(processed_text.strip()) + texts = processed_texts + embeds = [self.embed] * len(texts) + + specs = self.current_synt.synthesize_spectrograms(texts, embeds) + spec = np.concatenate(specs, axis=1) + wav, sample_rate = gan_vocoder.infer_waveform(spec) + + return wav, sample_rate + +def greet(audio, text, voice=None): + + if voice is None: + voice = Mandarin() + voice.setVoice(audio.name) + voice.say("加载成功") + wav, sample_rate = voice.say(text) + + output_file = "".join( random.sample(string.ascii_lowercase + string.digits, 11) ) + ".wav" + + write(output_file, sample_rate, wav.astype(np.float32)) + + return output_file, voice + +def main(): + gr.Interface( + fn=greet, + inputs=[gr.inputs.Audio(type="file"),"text", "state"], + outputs=[gr.outputs.Audio(type="file"), "state"] + ).launch() + +if __name__=="__main__": + main() diff --git a/demo_toolbox.py b/demo_toolbox.py new file mode 100644 index 0000000000000000000000000000000000000000..7030bd5a1d57647061064aa91c734e2f496e9b83 --- /dev/null +++ b/demo_toolbox.py @@ -0,0 +1,49 @@ +from pathlib import 
Path
+from toolbox import Toolbox
+from utils.argutils import print_args
+from utils.modelutils import check_model_paths
+import argparse
+import os
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Runs the toolbox",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument("-d", "--datasets_root", type=Path, help= \
+        "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
+        "supported datasets.", default=None)
+    parser.add_argument("-vc", "--vc_mode", action="store_true",
+                        help="Voice Conversion Mode (PPG based)")
+    parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
+                        help="Directory containing saved encoder models")
+    parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
+                        help="Directory containing saved synthesizer models")
+    parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
+                        help="Directory containing saved vocoder models")
+    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+                        help="Directory containing saved extractor models")
+    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+                        help="Directory containing saved converter models")
+    parser.add_argument("--cpu", action="store_true", help=\
+        "If True, processing is done on CPU, even when a GPU is available.")
+    parser.add_argument("--seed", type=int, default=None, help=\
+        "Optional random number seed value to make toolbox deterministic.")
+    parser.add_argument("--no_mp3_support", action="store_true", help=\
+        "If True, no mp3 files are allowed.")
+    args = parser.parse_args()
+    print_args(args, parser)
+
+    if args.cpu:
+        # Hide GPUs from Pytorch to force CPU processing
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    del args.cpu
+
+    ## Remind the user to download pretrained models if needed
+    check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir,
+                      vocoder_path=args.voc_models_dir)
+
+    # Launch the toolbox
+    Toolbox(**vars(args))
diff --git a/encoder/__init__.py b/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/encoder/audio.py b/encoder/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c087eea5d23d7e23ea0ef277ea4b92e9f4f2d55
--- /dev/null
+++ b/encoder/audio.py
@@ -0,0 +1,117 @@
+from scipy.ndimage import binary_dilation
+from encoder.params_data import *
+from pathlib import Path
+from typing import Optional, Union
+from warnings import warn
+import numpy as np
+import librosa
+import struct
+
+try:
+    import webrtcvad
+except Exception:
+    warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+    webrtcvad = None
+
+int16_max = (2 ** 15) - 1
+
+
+def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+                   source_sr: Optional[int] = None,
+                   normalize: Optional[bool] = True,
+                   trim_silence: Optional[bool] = True):
+    """
+    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+    just .wav), or the waveform as a numpy array of floats.
+    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+    preprocessing. After preprocessing, the waveform's sampling rate will match the data
+    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+    this argument will be ignored.
+    """
+    # Load the wav from disk if needed
+    if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+        wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+    else:
+        wav = fpath_or_wav
+
+    # Resample the wav if needed
+    if source_sr is not None and source_sr != sampling_rate:
+        wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate)
+
+    # Apply the preprocessing: normalize volume and shorten long silences
+    if normalize:
+        wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+    if webrtcvad and trim_silence:
+        wav = trim_long_silences(wav)
+
+    return wav
+
+
+def wav_to_mel_spectrogram(wav):
+    """
+    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+    Note: this is not a log-mel spectrogram.
+    """
+    frames = librosa.feature.melspectrogram(
+        y=wav,
+        sr=sampling_rate,
+        n_fft=int(sampling_rate * mel_window_length / 1000),
+        hop_length=int(sampling_rate * mel_window_step / 1000),
+        n_mels=mel_n_channels
+    )
+    return frames.astype(np.float32).T
+
+
+def trim_long_silences(wav):
+    """
+    Ensures that segments without voice in the waveform remain no longer than a
+    threshold determined by the VAD parameters in params.py.
+
+    :param wav: the raw waveform as a numpy array of floats
+    :return: the same waveform with silences trimmed away (length <= original wav length)
+    """
+    # Compute the voice detection window size
+    samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+    # Trim the end of the audio to have a multiple of the window size
+    wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+    # Convert the float waveform to 16-bit mono PCM
+    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+    # Perform voice activity detection
+    voice_flags = []
+    vad = webrtcvad.Vad(mode=3)
+    for window_start in range(0, len(wav), samples_per_window):
+        window_end = window_start + samples_per_window
+        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                         sample_rate=sampling_rate))
+    voice_flags = np.array(voice_flags)
+
+    # Smooth the voice detection with a moving average
+    def moving_average(array, width):
+        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+        ret = np.cumsum(array_padded, dtype=float)
+        ret[width:] = ret[width:] - ret[:-width]
+        return ret[width - 1:] / width
+
+    audio_mask = moving_average(voice_flags, vad_moving_average_width)
+    audio_mask = np.round(audio_mask).astype(bool)
+
+    # Dilate the voiced regions
+    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+    audio_mask = np.repeat(audio_mask, samples_per_window)
+
+    return wav[audio_mask]
+
+
+def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+    if increase_only and decrease_only:
+        raise ValueError("Both increase only and decrease only are set")
+    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+        return wav
+    return wav * (10 ** (dBFS_change / 20))
diff --git a/encoder/config.py b/encoder/config.py
new file mode 100644
index 
0000000000000000000000000000000000000000..1c21312f3de971bfa008254c6035cebc09f05e4c --- /dev/null +++ b/encoder/config.py @@ -0,0 +1,45 @@ +librispeech_datasets = { + "train": { + "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], + "other": ["LibriSpeech/train-other-500"] + }, + "test": { + "clean": ["LibriSpeech/test-clean"], + "other": ["LibriSpeech/test-other"] + }, + "dev": { + "clean": ["LibriSpeech/dev-clean"], + "other": ["LibriSpeech/dev-other"] + }, +} +libritts_datasets = { + "train": { + "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], + "other": ["LibriTTS/train-other-500"] + }, + "test": { + "clean": ["LibriTTS/test-clean"], + "other": ["LibriTTS/test-other"] + }, + "dev": { + "clean": ["LibriTTS/dev-clean"], + "other": ["LibriTTS/dev-other"] + }, +} +voxceleb_datasets = { + "voxceleb1" : { + "train": ["VoxCeleb1/wav"], + "test": ["VoxCeleb1/test_wav"] + }, + "voxceleb2" : { + "train": ["VoxCeleb2/dev/aac"], + "test": ["VoxCeleb2/test_wav"] + } +} + +other_datasets = [ + "LJSpeech-1.1", + "VCTK-Corpus/wav48", +] + +anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] diff --git a/encoder/data_objects/__init__.py b/encoder/data_objects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef04ade68544d0477a7f6deb4e7d51e97f592910 --- /dev/null +++ b/encoder/data_objects/__init__.py @@ -0,0 +1,2 @@ +from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset +from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader diff --git a/encoder/data_objects/random_cycler.py b/encoder/data_objects/random_cycler.py new file mode 100644 index 0000000000000000000000000000000000000000..c405db6b27f46d874d8feb37e3f9c1e12c251109 --- /dev/null +++ b/encoder/data_objects/random_cycler.py @@ -0,0 +1,37 @@ +import random + +class RandomCycler: + """ + Creates an internal copy of a sequence and allows access to its items in a constrained random + order. For a source sequence of n items and one or several consecutive queries of a total + of m items, the following guarantees hold (one implies the other): + - Each item will be returned between m // n and ((m - 1) // n) + 1 times. + - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. 
+ """ + + def __init__(self, source): + if len(source) == 0: + raise Exception("Can't create RandomCycler from an empty collection") + self.all_items = list(source) + self.next_items = [] + + def sample(self, count: int): + shuffle = lambda l: random.sample(l, len(l)) + + out = [] + while count > 0: + if count >= len(self.all_items): + out.extend(shuffle(list(self.all_items))) + count -= len(self.all_items) + continue + n = min(count, len(self.next_items)) + out.extend(self.next_items[:n]) + count -= n + self.next_items = self.next_items[n:] + if len(self.next_items) == 0: + self.next_items = shuffle(list(self.all_items)) + return out + + def __next__(self): + return self.sample(1)[0] + diff --git a/encoder/data_objects/speaker.py b/encoder/data_objects/speaker.py new file mode 100644 index 0000000000000000000000000000000000000000..494e882fe34fc38dcc793ab8c74a6cc2376bb7b5 --- /dev/null +++ b/encoder/data_objects/speaker.py @@ -0,0 +1,40 @@ +from encoder.data_objects.random_cycler import RandomCycler +from encoder.data_objects.utterance import Utterance +from pathlib import Path + +# Contains the set of utterances of a single speaker +class Speaker: + def __init__(self, root: Path): + self.root = root + self.name = root.name + self.utterances = None + self.utterance_cycler = None + + def _load_utterances(self): + with self.root.joinpath("_sources.txt").open("r") as sources_file: + sources = [l.split(",") for l in sources_file] + sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} + self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] + self.utterance_cycler = RandomCycler(self.utterances) + + def random_partial(self, count, n_frames): + """ + Samples a batch of unique partial utterances from the disk in a way that all + utterances come up at least once every two cycles and in a random order every time. + + :param count: The number of partial utterances to sample from the set of utterances from + that speaker. Utterances are guaranteed not to be repeated if is not larger than + the number of utterances available. + :param n_frames: The number of frames in the partial utterance. + :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, + frames are the frames of the partial utterances and range is the range of the partial + utterance with regard to the complete utterance. + """ + if self.utterances is None: + self._load_utterances() + + utterances = self.utterance_cycler.sample(count) + + a = [(u,) + u.random_partial(n_frames) for u in utterances] + + return a diff --git a/encoder/data_objects/speaker_batch.py b/encoder/data_objects/speaker_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..56651dba5804a0c59c334e49ac18f8f5a4bfa444 --- /dev/null +++ b/encoder/data_objects/speaker_batch.py @@ -0,0 +1,12 @@ +import numpy as np +from typing import List +from encoder.data_objects.speaker import Speaker + +class SpeakerBatch: + def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): + self.speakers = speakers + self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} + + # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. 
for 3 speakers with + # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) + self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) diff --git a/encoder/data_objects/speaker_verification_dataset.py b/encoder/data_objects/speaker_verification_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..77a6e05eae6a939ae7575ae70b7173644141fffe --- /dev/null +++ b/encoder/data_objects/speaker_verification_dataset.py @@ -0,0 +1,56 @@ +from encoder.data_objects.random_cycler import RandomCycler +from encoder.data_objects.speaker_batch import SpeakerBatch +from encoder.data_objects.speaker import Speaker +from encoder.params_data import partials_n_frames +from torch.utils.data import Dataset, DataLoader +from pathlib import Path + +# TODO: improve with a pool of speakers for data efficiency + +class SpeakerVerificationDataset(Dataset): + def __init__(self, datasets_root: Path): + self.root = datasets_root + speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] + if len(speaker_dirs) == 0: + raise Exception("No speakers found. Make sure you are pointing to the directory " + "containing all preprocessed speaker directories.") + self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] + self.speaker_cycler = RandomCycler(self.speakers) + + def __len__(self): + return int(1e10) + + def __getitem__(self, index): + return next(self.speaker_cycler) + + def get_logs(self): + log_string = "" + for log_fpath in self.root.glob("*.txt"): + with log_fpath.open("r") as log_file: + log_string += "".join(log_file.readlines()) + return log_string + + +class SpeakerVerificationDataLoader(DataLoader): + def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, + batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, + worker_init_fn=None): + self.utterances_per_speaker = utterances_per_speaker + + super().__init__( + dataset=dataset, + batch_size=speakers_per_batch, + shuffle=False, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=self.collate, + pin_memory=pin_memory, + drop_last=False, + timeout=timeout, + worker_init_fn=worker_init_fn + ) + + def collate(self, speakers): + return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) + \ No newline at end of file diff --git a/encoder/data_objects/utterance.py b/encoder/data_objects/utterance.py new file mode 100644 index 0000000000000000000000000000000000000000..0768c3420f422a7464f305b4c1fb6752c57ceda7 --- /dev/null +++ b/encoder/data_objects/utterance.py @@ -0,0 +1,26 @@ +import numpy as np + + +class Utterance: + def __init__(self, frames_fpath, wave_fpath): + self.frames_fpath = frames_fpath + self.wave_fpath = wave_fpath + + def get_frames(self): + return np.load(self.frames_fpath) + + def random_partial(self, n_frames): + """ + Crops the frames into a partial utterance of n_frames + + :param n_frames: The number of frames of the partial utterance + :return: the partial utterance frames and a tuple indicating the start and end of the + partial utterance in the complete utterance. 
+ """ + frames = self.get_frames() + if frames.shape[0] == n_frames: + start = 0 + else: + start = np.random.randint(0, frames.shape[0] - n_frames) + end = start + n_frames + return frames[start:end], (start, end) \ No newline at end of file diff --git a/encoder/inference.py b/encoder/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..af9a529f13878d939054625b64413e5060b6028e --- /dev/null +++ b/encoder/inference.py @@ -0,0 +1,195 @@ +from encoder.params_data import * +from encoder.model import SpeakerEncoder +from encoder.audio import preprocess_wav # We want to expose this function from here +from matplotlib import cm +from encoder import audio +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +import torch + +_model = None # type: SpeakerEncoder +_device = None # type: torch.device + + +def load_model(weights_fpath: Path, device=None): + """ + Loads the model in memory. If this function is not explicitely called, it will be run on the + first call to embed_frames() with the default weights file. + + :param weights_fpath: the path to saved model weights. + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The + model will be loaded and will run on this device. Outputs will however always be on the cpu. + If None, will default to your GPU if it"s available, otherwise your CPU. + """ + # TODO: I think the slow loading of the encoder might have something to do with the device it + # was saved on. Worth investigating. + global _model, _device + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + _device = torch.device(device) + _model = SpeakerEncoder(_device, torch.device("cpu")) + checkpoint = torch.load(weights_fpath, _device) + _model.load_state_dict(checkpoint["model_state"]) + _model.eval() + print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) + return _model + +def set_model(model, device=None): + global _model, _device + _model = model + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _device = device + _model.to(device) + +def is_loaded(): + return _model is not None + + +def embed_frames_batch(frames_batch): + """ + Computes embeddings for a batch of mel spectrogram. + + :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) + """ + if _model is None: + raise Exception("Model was not loaded. Call load_model() before inference.") + + frames = torch.from_numpy(frames_batch).to(_device) + embed = _model.forward(frames).detach().cpu().numpy() + return embed + + +def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, + min_pad_coverage=0.75, overlap=0.5, rate=None): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain + partial utterances of each. Both the waveform and the mel + spectrogram slices are returned, so as to make each partial utterance waveform correspond to + its spectrogram. This function assumes that the mel spectrogram parameters used are those + defined in params_data.py. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wave_slices[-1].stop. 
+ + :param n_samples: the number of samples in the waveform + :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial + utterance + :param min_pad_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered, as if we padded the audio. Otherwise, + it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial + utterance, this parameter is ignored so that the function always returns at least 1 slice. + :param overlap: by how much the partial utterance should overlap. If set to 0, the partial + utterances are entirely disjoint. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. + """ + assert 0 <= overlap < 1 + assert 0 < min_pad_coverage <= 1 + + if rate != None: + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + else: + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + + assert 0 < frame_step, "The rate is too high" + assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \ + (sampling_rate / (samples_per_frame * partials_n_frames)) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partial_utterance_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + if coverage < min_pad_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + +def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): + """ + Computes an embedding for a single utterance. + + # TODO: handle multiple wavs to benefit from batching on GPU + :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 + :param using_partials: if True, then the utterance is split in partial utterances of + frames and the utterance embedding is computed from their + normalized average. If False, the utterance is instead computed from feeding the entire + spectogram to the network. + :param return_partials: if True, the partial embeddings will also be returned along with the + wav slices that correspond to the partial embeddings. + :param kwargs: additional arguments to compute_partial_splits() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. If is simultaneously set to False, both these values will be None + instead. 
+ """ + # Process the entire utterance if not using partials + if not using_partials: + frames = audio.wav_to_mel_spectrogram(wav) + embed = embed_frames_batch(frames[None, ...])[0] + if return_partials: + return embed, None, None + return embed + + # Compute where to split the utterance into partials and pad if necessary + wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) + max_wave_length = wave_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = audio.wav_to_mel_spectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + partial_embeds = embed_frames_batch(frames_batch) + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wave_slices + return embed + + +def embed_speaker(wavs, **kwargs): + raise NotImplemented() + + +def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): + if ax is None: + ax = plt.gca() + + if shape is None: + height = int(np.sqrt(len(embed))) + shape = (height, -1) + embed = embed.reshape(shape) + + cmap = cm.get_cmap() + mappable = ax.imshow(embed, cmap=cmap) + cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) + sm = cm.ScalarMappable(cmap=cmap) + sm.set_clim(*color_range) + + ax.set_xticks([]), ax.set_yticks([]) + ax.set_title(title) diff --git a/encoder/model.py b/encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e050d3204d8f1becdf0f8b3133470708e5420cea --- /dev/null +++ b/encoder/model.py @@ -0,0 +1,135 @@ +from encoder.params_model import * +from encoder.params_data import * +from scipy.interpolate import interp1d +from sklearn.metrics import roc_curve +from torch.nn.utils import clip_grad_norm_ +from scipy.optimize import brentq +from torch import nn +import numpy as np +import torch + + +class SpeakerEncoder(nn.Module): + def __init__(self, device, loss_device): + super().__init__() + self.loss_device = loss_device + + # Network defition + self.lstm = nn.LSTM(input_size=mel_n_channels, + hidden_size=model_hidden_size, + num_layers=model_num_layers, + batch_first=True).to(device) + self.linear = nn.Linear(in_features=model_hidden_size, + out_features=model_embedding_size).to(device) + self.relu = torch.nn.ReLU().to(device) + + # Cosine similarity scaling (with fixed initial parameter values) + self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) + self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) + + # Loss + self.loss_fn = nn.CrossEntropyLoss().to(loss_device) + + def do_gradient_ops(self): + # Gradient scale + self.similarity_weight.grad *= 0.01 + self.similarity_bias.grad *= 0.01 + + # Gradient clipping + clip_grad_norm_(self.parameters(), 3, norm_type=2) + + def forward(self, utterances, hidden_init=None): + """ + Computes the embeddings of a batch of utterance spectrograms. + + :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape + (batch_size, n_frames, n_channels) + :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, + batch_size, hidden_size). Will default to a tensor of zeros if None. 
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size) + """ + # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state + # and the final cell state. + out, (hidden, cell) = self.lstm(utterances, hidden_init) + + # We take only the hidden state of the last layer + embeds_raw = self.relu(self.linear(hidden[-1])) + + # L2-normalize it + embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5) + + return embeds + + def similarity_matrix(self, embeds): + """ + Computes the similarity matrix according the section 2.1 of GE2E. + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + utterances_per_speaker, embedding_size) + :return: the similarity matrix as a tensor of shape (speakers_per_batch, + utterances_per_speaker, speakers_per_batch) + """ + speakers_per_batch, utterances_per_speaker = embeds.shape[:2] + + # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation + centroids_incl = torch.mean(embeds, dim=1, keepdim=True) + centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5) + + # Exclusive centroids (1 per utterance) + centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) + centroids_excl /= (utterances_per_speaker - 1) + centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5) + + # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot + # product of these vectors (which is just an element-wise multiplication reduced by a sum). + # We vectorize the computation for efficiency. + sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, + speakers_per_batch).to(self.loss_device) + mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) + for j in range(speakers_per_batch): + mask = np.where(mask_matrix[j])[0] + sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) + sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) + + ## Even more vectorized version (slower maybe because of transpose) + # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker + # ).to(self.loss_device) + # eye = np.eye(speakers_per_batch, dtype=np.int) + # mask = np.where(1 - eye) + # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) + # mask = np.where(eye) + # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) + # sim_matrix2 = sim_matrix2.transpose(1, 2) + + sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias + return sim_matrix + + def loss(self, embeds): + """ + Computes the softmax loss according the section 2.1 of GE2E. + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + utterances_per_speaker, embedding_size) + :return: the loss and the EER for this batch of embeddings. 
+ """ + speakers_per_batch, utterances_per_speaker = embeds.shape[:2] + + # Loss + sim_matrix = self.similarity_matrix(embeds) + sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, + speakers_per_batch)) + ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) + target = torch.from_numpy(ground_truth).long().to(self.loss_device) + loss = self.loss_fn(sim_matrix, target) + + # EER (not backpropagated) + with torch.no_grad(): + inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] + labels = np.array([inv_argmax(i) for i in ground_truth]) + preds = sim_matrix.detach().cpu().numpy() + + # Snippet from https://yangcha.github.io/EER-ROC/ + fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) + eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) + + return loss, eer diff --git a/encoder/params_data.py b/encoder/params_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb1716ed45617f2b127a7fb8885afe6cc74fb71 --- /dev/null +++ b/encoder/params_data.py @@ -0,0 +1,29 @@ + +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms +# Number of spectrogram frames at inference +inference_n_frames = 80 # 800 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + diff --git a/encoder/params_model.py b/encoder/params_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3e356472fb5a27f370cb3920976a11d12a76c1b7 --- /dev/null +++ b/encoder/params_model.py @@ -0,0 +1,11 @@ + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 + + +## Training parameters +learning_rate_init = 1e-4 +speakers_per_batch = 64 +utterances_per_speaker = 10 diff --git a/encoder/preprocess.py b/encoder/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..69986bb3bb0a2d8a0e352d1cb330a375d55f7e2c --- /dev/null +++ b/encoder/preprocess.py @@ -0,0 +1,184 @@ +from multiprocess.pool import ThreadPool +from encoder.params_data import * +from encoder.config import librispeech_datasets, anglophone_nationalites +from datetime import datetime +from encoder import audio +from pathlib import Path +from tqdm import tqdm +import numpy as np + + +class DatasetLog: + """ + Registers metadata about the dataset in a text file. 
+ """ + def __init__(self, root, name): + self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") + self.sample_data = dict() + + start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) + self.write_line("Creating dataset %s on %s" % (name, start_time)) + self.write_line("-----") + self._log_params() + + def _log_params(self): + from encoder import params_data + self.write_line("Parameter values:") + for param_name in (p for p in dir(params_data) if not p.startswith("__")): + value = getattr(params_data, param_name) + self.write_line("\t%s: %s" % (param_name, value)) + self.write_line("-----") + + def write_line(self, line): + self.text_file.write("%s\n" % line) + + def add_sample(self, **kwargs): + for param_name, value in kwargs.items(): + if not param_name in self.sample_data: + self.sample_data[param_name] = [] + self.sample_data[param_name].append(value) + + def finalize(self): + self.write_line("Statistics:") + for param_name, values in self.sample_data.items(): + self.write_line("\t%s:" % param_name) + self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) + self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values))) + self.write_line("-----") + end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) + self.write_line("Finished on %s" % end_time) + self.text_file.close() + + +def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog): + dataset_root = datasets_root.joinpath(dataset_name) + if not dataset_root.exists(): + print("Couldn\'t find %s, skipping this dataset." % dataset_root) + return None, None + return dataset_root, DatasetLog(out_dir, dataset_name) + + +def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension, + skip_existing, logger): + print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs))) + + # Function to preprocess utterances for one speaker + def preprocess_speaker(speaker_dir: Path): + # Give a name to the speaker that includes its dataset + speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) + + # Create an output directory with that name, as well as a txt file containing a + # reference to each source file. + speaker_out_dir = out_dir.joinpath(speaker_name) + speaker_out_dir.mkdir(exist_ok=True) + sources_fpath = speaker_out_dir.joinpath("_sources.txt") + + # There's a possibility that the preprocessing was interrupted earlier, check if + # there already is a sources file. 
+ if sources_fpath.exists(): + try: + with sources_fpath.open("r") as sources_file: + existing_fnames = {line.split(",")[0] for line in sources_file} + except: + existing_fnames = {} + else: + existing_fnames = {} + + # Gather all audio files for that speaker recursively + sources_file = sources_fpath.open("a" if skip_existing else "w") + for in_fpath in speaker_dir.glob("**/*.%s" % extension): + # Check if the target output file already exists + out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) + out_fname = out_fname.replace(".%s" % extension, ".npy") + if skip_existing and out_fname in existing_fnames: + continue + + # Load and preprocess the waveform + wav = audio.preprocess_wav(in_fpath) + if len(wav) == 0: + continue + + # Create the mel spectrogram, discard those that are too short + frames = audio.wav_to_mel_spectrogram(wav) + if len(frames) < partials_n_frames: + continue + + out_fpath = speaker_out_dir.joinpath(out_fname) + np.save(out_fpath, frames) + logger.add_sample(duration=len(wav) / sampling_rate) + sources_file.write("%s,%s\n" % (out_fname, in_fpath)) + + sources_file.close() + + # Process the utterances for each speaker + with ThreadPool(8) as pool: + list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), + unit="speakers")) + logger.finalize() + print("Done preprocessing %s.\n" % dataset_name) + +def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False): + dataset_name = "aidatatang_200zh" + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + # Preprocess all speakers + speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*")) + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav", + skip_existing, logger) + +def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False): + for dataset_name in librispeech_datasets["train"]["other"]: + # Initialize the preprocessing + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + + # Preprocess all speakers + speaker_dirs = list(dataset_root.glob("*")) + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac", + skip_existing, logger) + + +def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False): + # Initialize the preprocessing + dataset_name = "VoxCeleb1" + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + + # Get the contents of the meta file + with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: + metadata = [line.split("\t") for line in metafile][1:] + + # Select the ID and the nationality, filter out non-anglophone speakers + nationalities = {line[0]: line[3] for line in metadata} + keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if + nationality.lower() in anglophone_nationalites] + print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." % + (len(keep_speaker_ids), len(nationalities))) + + # Get the speaker directories for anglophone speakers only + speaker_dirs = dataset_root.joinpath("wav").glob("*") + speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if + speaker_dir.name in keep_speaker_ids] + print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." 
% + (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs))) + + # Preprocess all speakers + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav", + skip_existing, logger) + + +def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False): + # Initialize the preprocessing + dataset_name = "VoxCeleb2" + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + + # Get the speaker directories + # Preprocess all speakers + speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*")) + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", + skip_existing, logger) diff --git a/encoder/saved_models/pretrained.pt b/encoder/saved_models/pretrained.pt new file mode 100644 index 0000000000000000000000000000000000000000..1d0676b25e7c930c6743c78fcb77b8d12c8bef05 --- /dev/null +++ b/encoder/saved_models/pretrained.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57715adc6f36047166ab06e37b904240aee2f4d10fc88f78ed91510cf4b38666 +size 17095158 diff --git a/encoder/train.py b/encoder/train.py new file mode 100644 index 0000000000000000000000000000000000000000..619952e8de6c390912fe341403a39169592e585d --- /dev/null +++ b/encoder/train.py @@ -0,0 +1,123 @@ +from encoder.visualizations import Visualizations +from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset +from encoder.params_model import * +from encoder.model import SpeakerEncoder +from utils.profiler import Profiler +from pathlib import Path +import torch + +def sync(device: torch.device): + # For correct profiling (cuda operations are async) + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, + backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, + no_visdom: bool): + # Create a dataset and a dataloader + dataset = SpeakerVerificationDataset(clean_data_root) + loader = SpeakerVerificationDataLoader( + dataset, + speakers_per_batch, + utterances_per_speaker, + num_workers=8, + ) + + # Setup the device on which to run the forward pass and the loss. These can be different, + # because the forward pass is faster on the GPU whereas the loss is often (depending on your + # hyperparameters) faster on the CPU. + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # FIXME: currently, the gradient is None if loss_device is cuda + loss_device = torch.device("cpu") + + # Create the model and the optimizer + model = SpeakerEncoder(device, loss_device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) + init_step = 1 + + # Configure file path for the model + state_fpath = models_dir.joinpath(run_id + ".pt") + backup_dir = models_dir.joinpath(run_id + "_backups") + + # Load any existing model + if not force_restart: + if state_fpath.exists(): + print("Found existing model \"%s\", loading it and resuming training." % run_id) + checkpoint = torch.load(state_fpath) + init_step = checkpoint["step"] + model.load_state_dict(checkpoint["model_state"]) + optimizer.load_state_dict(checkpoint["optimizer_state"]) + optimizer.param_groups[0]["lr"] = learning_rate_init + else: + print("No model \"%s\" found, starting training from scratch." 
% run_id) + else: + print("Starting the training from scratch.") + model.train() + + # Initialize the visualization environment + vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) + vis.log_dataset(dataset) + vis.log_params() + device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") + vis.log_implementation({"Device": device_name}) + + # Training loop + profiler = Profiler(summarize_every=10, disabled=False) + for step, speaker_batch in enumerate(loader, init_step): + profiler.tick("Blocking, waiting for batch (threaded)") + + # Forward pass + inputs = torch.from_numpy(speaker_batch.data).to(device) + sync(device) + profiler.tick("Data to %s" % device) + embeds = model(inputs) + sync(device) + profiler.tick("Forward pass") + embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) + loss, eer = model.loss(embeds_loss) + sync(loss_device) + profiler.tick("Loss") + + # Backward pass + model.zero_grad() + loss.backward() + profiler.tick("Backward pass") + model.do_gradient_ops() + optimizer.step() + profiler.tick("Parameter update") + + # Update visualizations + # learning_rate = optimizer.param_groups[0]["lr"] + vis.update(loss.item(), eer, step) + + # Draw projections and save them to the backup folder + if umap_every != 0 and step % umap_every == 0: + print("Drawing and saving projections (step %d)" % step) + backup_dir.mkdir(exist_ok=True) + projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) + embeds = embeds.detach().cpu().numpy() + vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) + vis.save() + + # Overwrite the latest version of the model + if save_every != 0 and step % save_every == 0: + print("Saving the model (step %d)" % step) + torch.save({ + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, state_fpath) + + # Make a backup + if backup_every != 0 and step % backup_every == 0: + print("Making a backup (step %d)" % step) + backup_dir.mkdir(exist_ok=True) + backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) + torch.save({ + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, backup_fpath) + + profiler.tick("Extras (visualizations, saving)") diff --git a/encoder/visualizations.py b/encoder/visualizations.py new file mode 100644 index 0000000000000000000000000000000000000000..980c74f95f1f7df41ebccc983600b2713c0b0502 --- /dev/null +++ b/encoder/visualizations.py @@ -0,0 +1,178 @@ +from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset +from datetime import datetime +from time import perf_counter as timer +import matplotlib.pyplot as plt +import numpy as np +# import webbrowser +import visdom +import umap + +colormap = np.array([ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], +], dtype=np.float) / 255 + + +class Visualizations: + def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False): + # Tracking data + self.last_update_timestamp = timer() + self.update_every = update_every + self.step_times = [] + self.losses = [] + self.eers = [] + print("Updating the visualizations every %d steps." 
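
The checkpoints written by this training loop are plain dictionaries; a generic sketch of saving and resuming one (toy model and file name, not the speaker encoder itself):

```python
import torch

model = torch.nn.Linear(40, 256)               # stand-in for SpeakerEncoder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
step = 500

# Same keys as the torch.save calls in the training loop above
torch.save({
    "step": step + 1,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}, "my_run.pt")

checkpoint = torch.load("my_run.pt")
init_step = checkpoint["step"]
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
```
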
% update_every) + + # If visdom is disabled TODO: use a better paradigm for that + self.disabled = disabled + if self.disabled: + return + + # Set the environment name + now = str(datetime.now().strftime("%d-%m %Hh%M")) + if env_name is None: + self.env_name = now + else: + self.env_name = "%s (%s)" % (env_name, now) + + # Connect to visdom and open the corresponding window in the browser + try: + self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) + except ConnectionError: + raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to " + "start it.") + # webbrowser.open("http://localhost:8097/env/" + self.env_name) + + # Create the windows + self.loss_win = None + self.eer_win = None + # self.lr_win = None + self.implementation_win = None + self.projection_win = None + self.implementation_string = "" + + def log_params(self): + if self.disabled: + return + from encoder import params_data + from encoder import params_model + param_string = "Model parameters:
" + for param_name in (p for p in dir(params_model) if not p.startswith("__")): + value = getattr(params_model, param_name) + param_string += "\t%s: %s
" % (param_name, value) + param_string += "Data parameters:
" + for param_name in (p for p in dir(params_data) if not p.startswith("__")): + value = getattr(params_data, param_name) + param_string += "\t%s: %s
" % (param_name, value) + self.vis.text(param_string, opts={"title": "Parameters"}) + + def log_dataset(self, dataset: SpeakerVerificationDataset): + if self.disabled: + return + dataset_string = "" + dataset_string += "Speakers: %s\n" % len(dataset.speakers) + dataset_string += "\n" + dataset.get_logs() + dataset_string = dataset_string.replace("\n", "
") + self.vis.text(dataset_string, opts={"title": "Dataset"}) + + def log_implementation(self, params): + if self.disabled: + return + implementation_string = "" + for param, value in params.items(): + implementation_string += "%s: %s\n" % (param, value) + implementation_string = implementation_string.replace("\n", "
") + self.implementation_string = implementation_string + self.implementation_win = self.vis.text( + implementation_string, + opts={"title": "Training implementation"} + ) + + def update(self, loss, eer, step): + # Update the tracking data + now = timer() + self.step_times.append(1000 * (now - self.last_update_timestamp)) + self.last_update_timestamp = now + self.losses.append(loss) + self.eers.append(eer) + print(".", end="") + + # Update the plots every steps + if step % self.update_every != 0: + return + time_string = "Step time: mean: %5dms std: %5dms" % \ + (int(np.mean(self.step_times)), int(np.std(self.step_times))) + print("\nStep %6d Loss: %.4f EER: %.4f %s" % + (step, np.mean(self.losses), np.mean(self.eers), time_string)) + if not self.disabled: + self.loss_win = self.vis.line( + [np.mean(self.losses)], + [step], + win=self.loss_win, + update="append" if self.loss_win else None, + opts=dict( + legend=["Avg. loss"], + xlabel="Step", + ylabel="Loss", + title="Loss", + ) + ) + self.eer_win = self.vis.line( + [np.mean(self.eers)], + [step], + win=self.eer_win, + update="append" if self.eer_win else None, + opts=dict( + legend=["Avg. EER"], + xlabel="Step", + ylabel="EER", + title="Equal error rate" + ) + ) + if self.implementation_win is not None: + self.vis.text( + self.implementation_string + ("%s" % time_string), + win=self.implementation_win, + opts={"title": "Training implementation"}, + ) + + # Reset the tracking + self.losses.clear() + self.eers.clear() + self.step_times.clear() + + def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, + max_speakers=10): + max_speakers = min(max_speakers, len(colormap)) + embeds = embeds[:max_speakers * utterances_per_speaker] + + n_speakers = len(embeds) // utterances_per_speaker + ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) + colors = [colormap[i] for i in ground_truth] + + reducer = umap.UMAP() + projected = reducer.fit_transform(embeds) + plt.scatter(projected[:, 0], projected[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection (step %d)" % step) + if not self.disabled: + self.projection_win = self.vis.matplot(plt, win=self.projection_win) + if out_fpath is not None: + plt.savefig(out_fpath) + plt.clf() + + def save(self): + if not self.disabled: + self.vis.save([self.env_name]) + \ No newline at end of file diff --git a/encoder_preprocess.py b/encoder_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..853c6cb6c5cdda5c2e53ce3370d2570f2925f01a --- /dev/null +++ b/encoder_preprocess.py @@ -0,0 +1,61 @@ +from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh +from utils.argutils import print_args +from pathlib import Path +import argparse + +if __name__ == "__main__": + class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): + pass + + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, encodes them as mel spectrograms and " + "writes them to the disk. This will allow you to train the encoder. The " + "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. 
", + formatter_class=MyFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms. If left out, " + "defaults to /SV2TTS/encoder/") + parser.add_argument("-d", "--datasets", type=str, + default="librispeech_other,voxceleb1,aidatatang_200zh", help=\ + "Comma-separated list of the name of the datasets you want to preprocess. Only the train " + "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, " + "voxceleb2.") + parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + "Whether to skip existing output files with the same name. Useful if this script was " + "interrupted.") + parser.add_argument("--no_trim", action="store_true", help=\ + "Preprocess audio without trimming silences (not recommended).") + args = parser.parse_args() + + # Verify webrtcvad is available + if not args.no_trim: + try: + import webrtcvad + except: + raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables " + "noise removal and is recommended. Please install and try again. If installation fails, " + "use --no_trim to disable this error message.") + del args.no_trim + + # Process the arguments + args.datasets = args.datasets.split(",") + if not hasattr(args, "out_dir"): + args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder") + assert args.datasets_root.exists() + args.out_dir.mkdir(exist_ok=True, parents=True) + + # Preprocess the datasets + print_args(args, parser) + preprocess_func = { + "librispeech_other": preprocess_librispeech, + "voxceleb1": preprocess_voxceleb1, + "voxceleb2": preprocess_voxceleb2, + "aidatatang_200zh": preprocess_aidatatang_200zh, + } + args = vars(args) + for dataset in args.pop("datasets"): + print("Preprocessing %s" % dataset) + preprocess_func[dataset](**args) diff --git a/encoder_train.py b/encoder_train.py new file mode 100644 index 0000000000000000000000000000000000000000..b8740a894d615aadfe529cb36068fc8e3496125f --- /dev/null +++ b/encoder_train.py @@ -0,0 +1,47 @@ +from utils.argutils import print_args +from encoder.train import train +from pathlib import Path +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Trains the speaker encoder. You must have run encoder_preprocess.py first.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("run_id", type=str, help= \ + "Name for this model instance. If a model state from the same run ID was previously " + "saved, the training will restart from there. Pass -f to overwrite saved states and " + "restart from scratch.") + parser.add_argument("clean_data_root", type=Path, help= \ + "Path to the output directory of encoder_preprocess.py. 
If you left the default " + "output directory when preprocessing, it should be /SV2TTS/encoder/.") + parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\ + "Path to the output directory that will contain the saved model weights, as well as " + "backups of those weights and plots generated during training.") + parser.add_argument("-v", "--vis_every", type=int, default=10, help= \ + "Number of steps between updates of the loss and the plots.") + parser.add_argument("-u", "--umap_every", type=int, default=100, help= \ + "Number of steps between updates of the umap projection. Set to 0 to never update the " + "projections.") + parser.add_argument("-s", "--save_every", type=int, default=500, help= \ + "Number of steps between updates of the model on the disk. Set to 0 to never save the " + "model.") + parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \ + "Number of steps between backups of the model. Set to 0 to never make backups of the " + "model.") + parser.add_argument("-f", "--force_restart", action="store_true", help= \ + "Do not load any saved model.") + parser.add_argument("--visdom_server", type=str, default="http://localhost") + parser.add_argument("--no_visdom", action="store_true", help= \ + "Disable visdom.") + args = parser.parse_args() + + # Process the arguments + args.models_dir.mkdir(exist_ok=True) + + # Run the training + print_args(args, parser) + train(**vars(args)) + \ No newline at end of file diff --git a/gen_voice.py b/gen_voice.py new file mode 100644 index 0000000000000000000000000000000000000000..3be4159e29e36851be761163c3e3ace02cf8d29c --- /dev/null +++ b/gen_voice.py @@ -0,0 +1,128 @@ +from encoder.params_model import model_embedding_size as speaker_embedding_size +from utils.argutils import print_args +from utils.modelutils import check_model_paths +from synthesizer.inference import Synthesizer +from encoder import inference as encoder +from vocoder.wavernn import inference as rnn_vocoder +from vocoder.hifigan import inference as gan_vocoder +from pathlib import Path +import numpy as np +import soundfile as sf +import librosa +import argparse +import torch +import sys +import os +import re +import cn2an +import glob + +from audioread.exceptions import NoBackendError +vocoder = gan_vocoder + +def gen_one_wav(synthesizer, in_fpath, embed, texts, file_name, seq): + embeds = [embed] * len(texts) + # If you know what the attention layer alignments are, you can retrieve them here by + # passing return_alignments=True + specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=-1, min_stop_token=4, steps=400) + #spec = specs[0] + breaks = [spec.shape[1] for spec in specs] + spec = np.concatenate(specs, axis=1) + + # If seed is specified, reset torch seed and reload vocoder + # Synthesizing the waveform is fairly straightforward. Remember that the longer the + # spectrogram, the more time-efficient the vocoder. + generated_wav, output_sample_rate = vocoder.infer_waveform(spec) + + # Add breaks + b_ends = np.cumsum(np.array(breaks) * synthesizer.hparams.hop_size) + b_starts = np.concatenate(([0], b_ends[:-1])) + wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)] + breaks = [np.zeros(int(0.15 * synthesizer.sample_rate))] * len(breaks) + generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)]) + + ## Post-generation + # There's a bug with sounddevice that makes the audio cut one second earlier, so we + # pad it. 
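
The break-insertion logic in gen_one_wav() is easy to miss; the sketch below reproduces it with dummy numbers (hop size, sample rate and frame counts are placeholders for the synthesizer's hparams):

```python
import numpy as np

sample_rate, hop_size = 16000, 200
breaks = [120, 80, 95]                                   # frames per synthesized piece
generated_wav = np.random.randn(sum(breaks) * hop_size)  # stand-in for the vocoder output

# Convert frame counts to sample offsets, slice the waveform at those offsets,
# then interleave 150 ms of silence between the pieces.
b_ends = np.cumsum(np.array(breaks) * hop_size)
b_starts = np.concatenate(([0], b_ends[:-1]))
pieces = [generated_wav[start:end] for start, end in zip(b_starts, b_ends)]
silence = np.zeros(int(0.15 * sample_rate))
wav_with_pauses = np.concatenate([x for piece in pieces for x in (piece, silence)])
```
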
+ + # Trim excess silences to compensate for gaps in spectrograms (issue #53) + generated_wav = encoder.preprocess_wav(generated_wav) + generated_wav = generated_wav / np.abs(generated_wav).max() * 0.97 + + # Save it on the disk + model=os.path.basename(in_fpath) + filename = "%s_%d_%s.wav" %(file_name, seq, model) + sf.write(filename, generated_wav, synthesizer.sample_rate) + + print("\nSaved output as %s\n\n" % filename) + + +def generate_wav(enc_model_fpath, syn_model_fpath, voc_model_fpath, in_fpath, input_txt, file_name): + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device_id) + ## Print some environment information (for debugging purposes) + print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with " + "%.1fGb total memory.\n" % + (torch.cuda.device_count(), + device_id, + gpu_properties.name, + gpu_properties.major, + gpu_properties.minor, + gpu_properties.total_memory / 1e9)) + else: + print("Using CPU for inference.\n") + + print("Preparing the encoder, the synthesizer and the vocoder...") + encoder.load_model(enc_model_fpath) + synthesizer = Synthesizer(syn_model_fpath) + vocoder.load_model(voc_model_fpath) + + encoder_wav = synthesizer.load_preprocess_wav(in_fpath) + embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + texts = input_txt.split("\n") + seq=0 + each_num=1500 + + punctuation = '!,。、,' # punctuate and split/clean text + processed_texts = [] + cur_num = 0 + for text in texts: + for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'): + if processed_text: + processed_texts.append(processed_text.strip()) + cur_num += len(processed_text.strip()) + if cur_num > each_num: + seq = seq +1 + gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq) + processed_texts = [] + cur_num = 0 + + if len(processed_texts)>0: + seq = seq +1 + gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq) + +if (len(sys.argv)>=3): + my_txt = "" + print("reading from :", sys.argv[1]) + with open(sys.argv[1], "r") as f: + for line in f.readlines(): + #line = line.strip('\n') + my_txt += line + txt_file_name = sys.argv[1] + wav_file_name = sys.argv[2] + + output = cn2an.transform(my_txt, "an2cn") + print(output) + generate_wav( + Path("encoder/saved_models/pretrained.pt"), + Path("synthesizer/saved_models/mandarin.pt"), + Path("vocoder/saved_models/pretrained/g_hifigan.pt"), wav_file_name, output, txt_file_name + ) + +else: + print("please input the file name") + exit(1) + + diff --git a/mkgui/__init__.py b/mkgui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mkgui/app.py b/mkgui/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d4364aafd85208155ef4cae5f0e8daef8a5034eb --- /dev/null +++ b/mkgui/app.py @@ -0,0 +1,145 @@ +from pydantic import BaseModel, Field +import os +from pathlib import Path +from enum import Enum +from encoder import inference as encoder +import librosa +from scipy.io.wavfile import write +import re +import numpy as np +from mkgui.base.components.types import FileContent +from vocoder.hifigan import inference as gan_vocoder +from synthesizer.inference import Synthesizer +from typing import Any, Tuple +import matplotlib.pyplot as plt + +# Constants +AUDIO_SAMPLES_DIR = f"samples{os.sep}" +SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models" +ENC_MODELS_DIRT = 
f"encoder{os.sep}saved_models" +VOC_MODELS_DIRT = f"vocoder{os.sep}saved_models" +TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav" +TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav" +if not os.path.isdir("wavs"): + os.makedirs("wavs") + +# Load local sample audio as options TODO: load dataset +if os.path.isdir(AUDIO_SAMPLES_DIR): + audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav"))) +# Pre-Load models +if os.path.isdir(SYN_MODELS_DIRT): + synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded synthesizer models: " + str(len(synthesizers))) +else: + raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(ENC_MODELS_DIRT): + encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded encoders models: " + str(len(encoders))) +else: + raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(VOC_MODELS_DIRT): + vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt"))) + print("Loaded vocoders models: " + str(len(synthesizers))) +else: + raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.") + + + +class Input(BaseModel): + message: str = Field( + ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容" + ) + local_audio_file: audio_input_selection = Field( + ..., alias="输入语音(本地wav)", + description="选择本地语音文件." + ) + upload_audio_file: FileContent = Field(default=None, alias="或上传语音", + description="拖拽或点击上传.", mime_type="audio/wav") + encoder: encoders = Field( + ..., alias="编码模型", + description="选择语音编码模型文件." + ) + synthesizer: synthesizers = Field( + ..., alias="合成模型", + description="选择语音合成模型文件." + ) + vocoder: vocoders = Field( + ..., alias="语音解码模型", + description="选择语音解码模型文件(目前只支持HifiGan类型)." + ) + +class AudioEntity(BaseModel): + content: bytes + mel: Any + +class Output(BaseModel): + __root__: Tuple[AudioEntity, AudioEntity] + + def render_output_ui(self, streamlit_app, input) -> None: # type: ignore + """Custom output UI. + If this method is implmeneted, it will be used instead of the default Output UI renderer. 
+ """ + src, result = self.__root__ + + streamlit_app.subheader("Synthesized Audio") + streamlit_app.audio(result.content, format="audio/wav") + + fig, ax = plt.subplots() + ax.imshow(src.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Source Audio)") + streamlit_app.pyplot(fig) + fig, ax = plt.subplots() + ax.imshow(result.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Result Audio)") + streamlit_app.pyplot(fig) + + +def synthesize(input: Input) -> Output: + """synthesize(合成)""" + # load models + encoder.load_model(Path(input.encoder.value)) + current_synt = Synthesizer(Path(input.synthesizer.value)) + gan_vocoder.load_model(Path(input.vocoder.value)) + + # load file + if input.upload_audio_file != None: + with open(TEMP_SOURCE_AUDIO, "w+b") as f: + f.write(input.upload_audio_file.as_bytes()) + f.seek(0) + wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO) + else: + wav, sample_rate = librosa.load(input.local_audio_file.value) + write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav + + source_spec = Synthesizer.make_spectrogram(wav) + + # preprocess + encoder_wav = encoder.preprocess_wav(wav, sample_rate) + embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + # Load input text + texts = filter(None, input.message.split("\n")) + punctuation = '!,。、,' # punctuate and split/clean text + processed_texts = [] + for text in texts: + for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'): + if processed_text: + processed_texts.append(processed_text.strip()) + texts = processed_texts + + # synthesize and vocode + embeds = [embed] * len(texts) + specs = current_synt.synthesize_spectrograms(texts, embeds) + spec = np.concatenate(specs, axis=1) + sample_rate = Synthesizer.sample_rate + wav, sample_rate = gan_vocoder.infer_waveform(spec) + + # write and output + write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav + with open(TEMP_SOURCE_AUDIO, "rb") as f: + source_file = f.read() + with open(TEMP_RESULT_AUDIO, "rb") as f: + result_file = f.read() + return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec))) \ No newline at end of file diff --git a/mkgui/app_vc.py b/mkgui/app_vc.py new file mode 100644 index 0000000000000000000000000000000000000000..1d69b4a23c80f800b775705f53bc483108307d6c --- /dev/null +++ b/mkgui/app_vc.py @@ -0,0 +1,166 @@ +from synthesizer.inference import Synthesizer +from pydantic import BaseModel, Field +from encoder import inference as speacker_encoder +import torch +import os +from pathlib import Path +from enum import Enum +import ppg_extractor as Extractor +import ppg2mel as Convertor +import librosa +from scipy.io.wavfile import write +import re +import numpy as np +from mkgui.base.components.types import FileContent +from vocoder.hifigan import inference as gan_vocoder +from typing import Any, Tuple +import matplotlib.pyplot as plt + + +# Constants +AUDIO_SAMPLES_DIR = f'sample{os.sep}' +EXT_MODELS_DIRT = f'ppg_extractor{os.sep}saved_models' +CONV_MODELS_DIRT = f'ppg2mel{os.sep}saved_models' +VOC_MODELS_DIRT = f'vocoder{os.sep}saved_models' +TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav' +TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav' +TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav' + +# Load local sample audio as options TODO: load dataset +if os.path.isdir(AUDIO_SAMPLES_DIR): + audio_input_selection = Enum('samples', list((file.name, file) for file 
in Path(AUDIO_SAMPLES_DIR).glob("*.wav"))) +# Pre-Load models +if os.path.isdir(EXT_MODELS_DIRT): + extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded extractor models: " + str(len(extractors))) +else: + raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(CONV_MODELS_DIRT): + convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth"))) + print("Loaded convertor models: " + str(len(convertors))) +else: + raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(VOC_MODELS_DIRT): + vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt"))) + print("Loaded vocoders models: " + str(len(vocoders))) +else: + raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.") + +class Input(BaseModel): + local_audio_file: audio_input_selection = Field( + ..., alias="输入语音(本地wav)", + description="选择本地语音文件." + ) + upload_audio_file: FileContent = Field(default=None, alias="或上传语音", + description="拖拽或点击上传.", mime_type="audio/wav") + local_audio_file_target: audio_input_selection = Field( + ..., alias="目标语音(本地wav)", + description="选择本地语音文件." + ) + upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音", + description="拖拽或点击上传.", mime_type="audio/wav") + extractor: extractors = Field( + ..., alias="编码模型", + description="选择语音编码模型文件." + ) + convertor: convertors = Field( + ..., alias="转换模型", + description="选择语音转换模型文件." + ) + vocoder: vocoders = Field( + ..., alias="语音解码模型", + description="选择语音解码模型文件(目前只支持HifiGan类型)." + ) + +class AudioEntity(BaseModel): + content: bytes + mel: Any + +class Output(BaseModel): + __root__: Tuple[AudioEntity, AudioEntity, AudioEntity] + + def render_output_ui(self, streamlit_app, input) -> None: # type: ignore + """Custom output UI. + If this method is implmeneted, it will be used instead of the default Output UI renderer. 
+ """ + src, target, result = self.__root__ + + streamlit_app.subheader("Synthesized Audio") + streamlit_app.audio(result.content, format="audio/wav") + + fig, ax = plt.subplots() + ax.imshow(src.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Source Audio)") + streamlit_app.pyplot(fig) + fig, ax = plt.subplots() + ax.imshow(target.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Target Audio)") + streamlit_app.pyplot(fig) + fig, ax = plt.subplots() + ax.imshow(result.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Result Audio)") + streamlit_app.pyplot(fig) + +def convert(input: Input) -> Output: + """convert(转换)""" + # load models + extractor = Extractor.load_model(Path(input.extractor.value)) + convertor = Convertor.load_model(Path(input.convertor.value)) + # current_synt = Synthesizer(Path(input.synthesizer.value)) + gan_vocoder.load_model(Path(input.vocoder.value)) + + # load file + if input.upload_audio_file != None: + with open(TEMP_SOURCE_AUDIO, "w+b") as f: + f.write(input.upload_audio_file.as_bytes()) + f.seek(0) + src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO) + else: + src_wav, sample_rate = librosa.load(input.local_audio_file.value) + write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav + + if input.upload_audio_file_target != None: + with open(TEMP_TARGET_AUDIO, "w+b") as f: + f.write(input.upload_audio_file_target.as_bytes()) + f.seek(0) + ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO) + else: + ref_wav, _ = librosa.load(input.local_audio_file_target.value) + write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav + + ppg = extractor.extract_from_wav(src_wav) + # Import necessary dependency of Voice Conversion + from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv + ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav))) + speacker_encoder.load_model(Path("encoder{os.sep}saved_models{os.sep}pretrained_bak_5805000.pt")) + embed = speacker_encoder.embed_utterance(ref_wav) + lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True) + min_len = min(ppg.shape[1], len(lf0_uv)) + ppg = ppg[:, :min_len] + lf0_uv = lf0_uv[:min_len] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _, mel_pred, att_ws = convertor.inference( + ppg, + logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), + spembs=torch.from_numpy(embed).unsqueeze(0).to(device), + ) + mel_pred= mel_pred.transpose(0, 1) + breaks = [mel_pred.shape[1]] + mel_pred= mel_pred.detach().cpu().numpy() + + # synthesize and vocode + wav, sample_rate = gan_vocoder.infer_waveform(mel_pred) + + # write and output + write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav + with open(TEMP_SOURCE_AUDIO, "rb") as f: + source_file = f.read() + with open(TEMP_TARGET_AUDIO, "rb") as f: + target_file = f.read() + with open(TEMP_RESULT_AUDIO, "rb") as f: + result_file = f.read() + + + return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav)))) \ No newline at end of file diff --git a/mkgui/base/__init__.py b/mkgui/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6905fa0da4ea5b5b30797d5dae08dd2a199318ad --- /dev/null +++ b/mkgui/base/__init__.py @@ -0,0 
+1,2 @@ + +from .core import Opyrator diff --git a/mkgui/base/api/__init__.py b/mkgui/base/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c41028aab20a323d33f3107ef6483557fb74bb --- /dev/null +++ b/mkgui/base/api/__init__.py @@ -0,0 +1 @@ +from .fastapi_app import create_api diff --git a/mkgui/base/api/fastapi_utils.py b/mkgui/base/api/fastapi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..adf582a7c33c2d68ed32fb8b3382fdeb388db0d0 --- /dev/null +++ b/mkgui/base/api/fastapi_utils.py @@ -0,0 +1,102 @@ +"""Collection of utilities for FastAPI apps.""" + +import inspect +from typing import Any, Type + +from fastapi import FastAPI, Form +from pydantic import BaseModel + + +def as_form(cls: Type[BaseModel]) -> Any: + """Adds an as_form class method to decorated models. + + The as_form class method can be used with FastAPI endpoints + """ + new_params = [ + inspect.Parameter( + field.alias, + inspect.Parameter.POSITIONAL_ONLY, + default=(Form(field.default) if not field.required else Form(...)), + ) + for field in cls.__fields__.values() + ] + + async def _as_form(**data): # type: ignore + return cls(**data) + + sig = inspect.signature(_as_form) + sig = sig.replace(parameters=new_params) + _as_form.__signature__ = sig # type: ignore + setattr(cls, "as_form", _as_form) + return cls + + +def patch_fastapi(app: FastAPI) -> None: + """Patch function to allow relative url resolution. + + This patch is required to make fastapi fully functional with a relative url path. + This code snippet can be copy-pasted to any Fastapi application. + """ + from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html + from starlette.requests import Request + from starlette.responses import HTMLResponse + + async def redoc_ui_html(req: Request) -> HTMLResponse: + assert app.openapi_url is not None + redoc_ui = get_redoc_html( + openapi_url="./" + app.openapi_url.lstrip("/"), + title=app.title + " - Redoc UI", + ) + + return HTMLResponse(redoc_ui.body.decode("utf-8")) + + async def swagger_ui_html(req: Request) -> HTMLResponse: + assert app.openapi_url is not None + swagger_ui = get_swagger_ui_html( + openapi_url="./" + app.openapi_url.lstrip("/"), + title=app.title + " - Swagger UI", + oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url, + ) + + # insert request interceptor to have all request run on relativ path + request_interceptor = ( + "requestInterceptor: (e) => {" + "\n\t\t\tvar url = window.location.origin + window.location.pathname" + '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);' + "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605 + "\n\t\t\te.contextUrl = url" + "\n\t\t\te.url = url" + "\n\t\t\treturn e;}" + ) + + return HTMLResponse( + swagger_ui.body.decode("utf-8").replace( + "dom_id: '#swagger-ui',", + "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",", + ) + ) + + # remove old docs route and add our patched route + routes_new = [] + for app_route in app.routes: + if app_route.path == "/docs": # type: ignore + continue + + if app_route.path == "/redoc": # type: ignore + continue + + routes_new.append(app_route) + + app.router.routes = routes_new + + assert app.docs_url is not None + app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False) + assert app.redoc_url is not None + app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False) + + # Make graphql realtive + from starlette import graphql + + graphql.GRAPHIQL = graphql.GRAPHIQL.replace( + 
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}' + ) diff --git a/mkgui/base/components/__init__.py b/mkgui/base/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mkgui/base/components/outputs.py b/mkgui/base/components/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..f4859c64b9e21114436e57863fedd5fd161da627 --- /dev/null +++ b/mkgui/base/components/outputs.py @@ -0,0 +1,43 @@ +from typing import List + +from pydantic import BaseModel + + +class ScoredLabel(BaseModel): + label: str + score: float + + +class ClassificationOutput(BaseModel): + __root__: List[ScoredLabel] + + def __iter__(self): # type: ignore + return iter(self.__root__) + + def __getitem__(self, item): # type: ignore + return self.__root__[item] + + def render_output_ui(self, streamlit) -> None: # type: ignore + import plotly.express as px + + sorted_predictions = sorted( + [prediction.dict() for prediction in self.__root__], + key=lambda k: k["score"], + ) + + num_labels = len(sorted_predictions) + if len(sorted_predictions) > 10: + num_labels = streamlit.slider( + "Maximum labels to show: ", + min_value=1, + max_value=len(sorted_predictions), + value=len(sorted_predictions), + ) + fig = px.bar( + sorted_predictions[len(sorted_predictions) - num_labels :], + x="score", + y="label", + orientation="h", + ) + streamlit.plotly_chart(fig, use_container_width=True) + # fig.show() diff --git a/mkgui/base/components/types.py b/mkgui/base/components/types.py new file mode 100644 index 0000000000000000000000000000000000000000..125809a81b306ddeab4cf6ab0ba6abdbe8d0c4ed --- /dev/null +++ b/mkgui/base/components/types.py @@ -0,0 +1,46 @@ +import base64 +from typing import Any, Dict, overload + + +class FileContent(str): + def as_bytes(self) -> bytes: + return base64.b64decode(self, validate=True) + + def as_str(self) -> str: + return self.as_bytes().decode() + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(format="byte") + + @classmethod + def __get_validators__(cls) -> Any: # type: ignore + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> "FileContent": + if isinstance(value, FileContent): + return value + elif isinstance(value, str): + return FileContent(value) + elif isinstance(value, (bytes, bytearray, memoryview)): + return FileContent(base64.b64encode(value).decode()) + else: + raise Exception("Wrong type") + +# # 暂时无法使用,因为浏览器中没有考虑选择文件夹 +# class DirectoryContent(FileContent): +# @classmethod +# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: +# field_schema.update(format="path") + +# @classmethod +# def validate(cls, value: Any) -> "DirectoryContent": +# if isinstance(value, DirectoryContent): +# return value +# elif isinstance(value, str): +# return DirectoryContent(value) +# elif isinstance(value, (bytes, bytearray, memoryview)): +# return DirectoryContent(base64.b64encode(value).decode()) +# else: +# raise Exception("Wrong type") diff --git a/mkgui/base/core.py b/mkgui/base/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8166a330c6e486e916533fcea2ae4393217a852e --- /dev/null +++ b/mkgui/base/core.py @@ -0,0 +1,203 @@ +import importlib +import inspect +import re +from typing import Any, Callable, Type, Union, get_type_hints + +from pydantic import BaseModel, parse_raw_as +from pydantic.tools import parse_obj_as + + +def name_to_title(name: str) -> str: + """Converts a camelCase or 
snake_case name to title case.""" + # If camelCase -> convert to snake case + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + # Convert to title case + return name.replace("_", " ").strip().title() + + +def is_compatible_type(type: Type) -> bool: + """Returns `True` if the type is opyrator-compatible.""" + try: + if issubclass(type, BaseModel): + return True + except Exception: + pass + + try: + # valid list type + if type.__origin__ is list and issubclass(type.__args__[0], BaseModel): + return True + except Exception: + pass + + return False + + +def get_input_type(func: Callable) -> Type: + """Returns the input type of a given function (callable). + + Args: + func: The function for which to get the input type. + + Raises: + ValueError: If the function does not have a valid input type annotation. + """ + type_hints = get_type_hints(func) + + if "input" not in type_hints: + raise ValueError( + "The callable MUST have a parameter with the name `input` with typing annotation. " + "For example: `def my_opyrator(input: InputModel) -> OutputModel:`." + ) + + input_type = type_hints["input"] + + if not is_compatible_type(input_type): + raise ValueError( + "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models." + ) + + # TODO: return warning if more than one input parameters + + return input_type + + +def get_output_type(func: Callable) -> Type: + """Returns the output type of a given function (callable). + + Args: + func: The function for which to get the output type. + + Raises: + ValueError: If the function does not have a valid output type annotation. + """ + type_hints = get_type_hints(func) + if "return" not in type_hints: + raise ValueError( + "The return type of the callable MUST be annotated with type hints." + "For example: `def my_opyrator(input: InputModel) -> OutputModel:`." + ) + + output_type = type_hints["return"] + + if not is_compatible_type(output_type): + raise ValueError( + "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models." + ) + + return output_type + + +def get_callable(import_string: str) -> Callable: + """Import a callable from an string.""" + callable_seperator = ":" + if callable_seperator not in import_string: + # Use dot as seperator + callable_seperator = "." + + if callable_seperator not in import_string: + raise ValueError("The callable path MUST specify the function. ") + + mod_name, callable_name = import_string.rsplit(callable_seperator, 1) + mod = importlib.import_module(mod_name) + return getattr(mod, callable_name) + + +class Opyrator: + def __init__(self, func: Union[Callable, str]) -> None: + if isinstance(func, str): + # Try to load the function from a string notion + self.function = get_callable(func) + else: + self.function = func + + self._action = "Execute" + self._input_type = None + self._output_type = None + + if not callable(self.function): + raise ValueError("The provided function parameters is not a callable.") + + if inspect.isclass(self.function): + raise ValueError( + "The provided callable is an uninitialized Class. This is not allowed." 
+ ) + + if inspect.isfunction(self.function): + # The provided callable is a function + self._input_type = get_input_type(self.function) + self._output_type = get_output_type(self.function) + + try: + # Get name + self._name = name_to_title(self.function.__name__) + except Exception: + pass + + try: + # Get description from function + doc_string = inspect.getdoc(self.function) + if doc_string: + self._action = doc_string + except Exception: + pass + elif hasattr(self.function, "__call__"): + # The provided callable is a function + self._input_type = get_input_type(self.function.__call__) # type: ignore + self._output_type = get_output_type(self.function.__call__) # type: ignore + + try: + # Get name + self._name = name_to_title(type(self.function).__name__) + except Exception: + pass + + try: + # Get action from + doc_string = inspect.getdoc(self.function.__call__) # type: ignore + if doc_string: + self._action = doc_string + + if ( + not self._action + or self._action == "Call" + ): + # Get docstring from class instead of __call__ function + doc_string = inspect.getdoc(self.function) + if doc_string: + self._action = doc_string + except Exception: + pass + else: + raise ValueError("Unknown callable type.") + + @property + def name(self) -> str: + return self._name + + @property + def action(self) -> str: + return self._action + + @property + def input_type(self) -> Any: + return self._input_type + + @property + def output_type(self) -> Any: + return self._output_type + + def __call__(self, input: Any, **kwargs: Any) -> Any: + + input_obj = input + + if isinstance(input, str): + # Allow json input + input_obj = parse_raw_as(self.input_type, input) + + if isinstance(input, dict): + # Allow dict input + input_obj = parse_obj_as(self.input_type, input) + + return self.function(input_obj, **kwargs) diff --git a/mkgui/base/ui/__init__.py b/mkgui/base/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..593b254ea68dc4c5b3c3f5d4622334133316866f --- /dev/null +++ b/mkgui/base/ui/__init__.py @@ -0,0 +1 @@ +from .streamlit_ui import render_streamlit_ui diff --git a/mkgui/base/ui/schema_utils.py b/mkgui/base/ui/schema_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a2be43c07175f18e6f285eae5fddc5c0c2faa7aa --- /dev/null +++ b/mkgui/base/ui/schema_utils.py @@ -0,0 +1,129 @@ +from typing import Dict + + +def resolve_reference(reference: str, references: Dict) -> Dict: + return references[reference.split("/")[-1]] + + +def get_single_reference_item(property: Dict, references: Dict) -> Dict: + # Ref can either be directly in the properties or the first element of allOf + reference = property.get("$ref") + if reference is None: + reference = property["allOf"][0]["$ref"] + return resolve_reference(reference, references) + + +def is_single_string_property(property: Dict) -> bool: + return property.get("type") == "string" + + +def is_single_datetime_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + return property.get("format") in ["date-time", "time", "date"] + + +def is_single_boolean_property(property: Dict) -> bool: + return property.get("type") == "boolean" + + +def is_single_number_property(property: Dict) -> bool: + return property.get("type") in ["integer", "number"] + + +def is_single_file_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + # TODO: binary? 
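
To see what the Opyrator wrapper above expects from a callable, here is a minimal, hypothetical function in that shape: pydantic v1 models, a parameter literally named `input` with a type hint, and a typed return value:

```python
from pydantic import BaseModel
from mkgui.base import Opyrator

class EchoInput(BaseModel):
    message: str

class EchoOutput(BaseModel):
    message: str

def echo(input: EchoInput) -> EchoOutput:
    """Echo the message back."""
    return EchoOutput(message=input.message)

op = Opyrator(echo)
print(op.name, "-", op.action)    # title-cased function name and the docstring
print(op('{"message": "hi"}'))    # JSON string input is parsed via parse_raw_as
```
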
+ return property.get("format") == "byte" + + +def is_single_directory_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + return property.get("format") == "path" + +def is_multi_enum_property(property: Dict, references: Dict) -> bool: + if property.get("type") != "array": + return False + + if property.get("uniqueItems") is not True: + # Only relevant if it is a set or other datastructures with unique items + return False + + try: + _ = resolve_reference(property["items"]["$ref"], references)["enum"] + return True + except Exception: + return False + + +def is_single_enum_property(property: Dict, references: Dict) -> bool: + try: + _ = get_single_reference_item(property, references)["enum"] + return True + except Exception: + return False + + +def is_single_dict_property(property: Dict) -> bool: + if property.get("type") != "object": + return False + return "additionalProperties" in property + + +def is_single_reference(property: Dict) -> bool: + if property.get("type") is not None: + return False + + return bool(property.get("$ref")) + + +def is_multi_file_property(property: Dict) -> bool: + if property.get("type") != "array": + return False + + if property.get("items") is None: + return False + + try: + # TODO: binary + return property["items"]["format"] == "byte" + except Exception: + return False + + +def is_single_object(property: Dict, references: Dict) -> bool: + try: + object_reference = get_single_reference_item(property, references) + if object_reference["type"] != "object": + return False + return "properties" in object_reference + except Exception: + return False + + +def is_property_list(property: Dict) -> bool: + if property.get("type") != "array": + return False + + if property.get("items") is None: + return False + + try: + return property["items"]["type"] in ["string", "number", "integer"] + except Exception: + return False + + +def is_object_list_property(property: Dict, references: Dict) -> bool: + if property.get("type") != "array": + return False + + try: + object_reference = resolve_reference(property["items"]["$ref"], references) + if object_reference["type"] != "object": + return False + return "properties" in object_reference + except Exception: + return False diff --git a/mkgui/base/ui/streamlit_ui.py b/mkgui/base/ui/streamlit_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..479fe1c3e3ec6cd9f2c785c777ea9fe892853d8b --- /dev/null +++ b/mkgui/base/ui/streamlit_ui.py @@ -0,0 +1,888 @@ +import datetime +import inspect +import mimetypes +import sys +from os import getcwd, unlink +from platform import system +from tempfile import NamedTemporaryFile +from typing import Any, Callable, Dict, List, Type +from PIL import Image + +import pandas as pd +import streamlit as st +from fastapi.encoders import jsonable_encoder +from loguru import logger +from pydantic import BaseModel, ValidationError, parse_obj_as + +from mkgui.base import Opyrator +from mkgui.base.core import name_to_title +from mkgui.base.ui import schema_utils +from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS + +STREAMLIT_RUNNER_SNIPPET = """ +from mkgui.base.ui import render_streamlit_ui +from mkgui.base import Opyrator + +import streamlit as st + +# TODO: Make it configurable +# Page config can only be setup once +st.set_page_config( + page_title="MockingBird", + page_icon="🧊", + layout="wide") + +render_streamlit_ui() +""" + +# with st.spinner("Loading MockingBird GUI. 
Please wait..."): +# opyrator = Opyrator("{opyrator_path}") + + +def launch_ui(port: int = 8501) -> None: + with NamedTemporaryFile( + suffix=".py", mode="w", encoding="utf-8", delete=False + ) as f: + f.write(STREAMLIT_RUNNER_SNIPPET) + f.seek(0) + + import subprocess + + python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"' + if system() == "Windows": + python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&" + subprocess.run( + f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""", + shell=True, + ) + + subprocess.run( + f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""", + shell=True, + ) + + f.close() + unlink(f.name) + + +def function_has_named_arg(func: Callable, parameter: str) -> bool: + try: + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.name == "input": + return True + except Exception: + return False + return False + + +def has_output_ui_renderer(data_item: BaseModel) -> bool: + return hasattr(data_item, "render_output_ui") + + +def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool: + return hasattr(input_class, "render_input_ui") + + +def is_compatible_audio(mime_type: str) -> bool: + return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"] + + +def is_compatible_image(mime_type: str) -> bool: + return mime_type in ["image/png", "image/jpeg"] + + +def is_compatible_video(mime_type: str) -> bool: + return mime_type in ["video/mp4"] + + +class InputUI: + def __init__(self, session_state, input_class: Type[BaseModel]): + self._session_state = session_state + self._input_class = input_class + + self._schema_properties = input_class.schema(by_alias=True).get( + "properties", {} + ) + self._schema_references = input_class.schema(by_alias=True).get( + "definitions", {} + ) + + def render_ui(self, streamlit_app_root) -> None: + if has_input_ui_renderer(self._input_class): + # The input model has a rendering function + # The rendering also returns the current state of input data + self._session_state.input_data = self._input_class.render_input_ui( # type: ignore + st, self._session_state.input_data + ) + return + + # print(self._schema_properties) + for property_key in self._schema_properties.keys(): + property = self._schema_properties[property_key] + + if not property.get("title"): + # Set property key as fallback title + property["title"] = name_to_title(property_key) + + try: + if "input_data" in self._session_state: + self._store_value( + property_key, + self._render_property(streamlit_app_root, property_key, property), + ) + except Exception as e: + print("Exception!", e) + pass + + def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict: + streamlit_kwargs = { + "label": property.get("title"), + "key": key, + } + + if property.get("description"): + streamlit_kwargs["help"] = property.get("description") + return streamlit_kwargs + + def _store_value(self, key: str, value: Any) -> None: + data_element = self._session_state.input_data + key_elements = key.split(".") + for i, key_element in enumerate(key_elements): + if i == len(key_elements) - 1: + # add value to this element + data_element[key_element] = value + return + if key_element not in data_element: + data_element[key_element] = {} + data_element = data_element[key_element] + + def _get_value(self, key: str) -> Any: + data_element = self._session_state.input_data + key_elements = key.split(".") 
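
The _store_value/_get_value helpers above address nested form data with dotted keys; the idea in isolation (not the repo's exact code):

```python
def store_value(data: dict, key: str, value) -> None:
    """Write `value` into nested dicts addressed by a dotted key."""
    *parents, last = key.split(".")
    for part in parents:
        data = data.setdefault(part, {})
    data[last] = value

input_data = {}
store_value(input_data, "options.vocoder.name", "hifigan")
print(input_data)   # {'options': {'vocoder': {'name': 'hifigan'}}}
```
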
+ for i, key_element in enumerate(key_elements): + if i == len(key_elements) - 1: + # add value to this element + if key_element not in data_element: + return None + return data_element[key_element] + if key_element not in data_element: + data_element[key_element] = {} + data_element = data_element[key_element] + return None + + def _render_single_datetime_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + if property.get("format") == "time": + if property.get("default"): + try: + streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore + property.get("default") + ) + except Exception: + pass + return streamlit_app.time_input(**streamlit_kwargs) + elif property.get("format") == "date": + if property.get("default"): + try: + streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore + property.get("default") + ) + except Exception: + pass + return streamlit_app.date_input(**streamlit_kwargs) + elif property.get("format") == "date-time": + if property.get("default"): + try: + streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore + property.get("default") + ) + except Exception: + pass + with streamlit_app.container(): + streamlit_app.subheader(streamlit_kwargs.get("label")) + if streamlit_kwargs.get("description"): + streamlit_app.text(streamlit_kwargs.get("description")) + selected_date = None + selected_time = None + date_col, time_col = streamlit_app.columns(2) + with date_col: + date_kwargs = {"label": "Date", "key": key + "-date-input"} + if streamlit_kwargs.get("value"): + try: + date_kwargs["value"] = streamlit_kwargs.get( # type: ignore + "value" + ).date() + except Exception: + pass + selected_date = streamlit_app.date_input(**date_kwargs) + + with time_col: + time_kwargs = {"label": "Time", "key": key + "-time-input"} + if streamlit_kwargs.get("value"): + try: + time_kwargs["value"] = streamlit_kwargs.get( # type: ignore + "value" + ).time() + except Exception: + pass + selected_time = streamlit_app.time_input(**time_kwargs) + return datetime.datetime.combine(selected_date, selected_time) + else: + streamlit_app.warning( + "Date format is not supported: " + str(property.get("format")) + ) + + def _render_single_file_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + file_extension = None + if "mime_type" in property: + file_extension = mimetypes.guess_extension(property["mime_type"]) + + uploaded_file = streamlit_app.file_uploader( + **streamlit_kwargs, accept_multiple_files=False, type=file_extension + ) + if uploaded_file is None: + return None + + bytes = uploaded_file.getvalue() + if property.get("mime_type"): + if is_compatible_audio(property["mime_type"]): + # Show audio + streamlit_app.audio(bytes, format=property.get("mime_type")) + if is_compatible_image(property["mime_type"]): + # Show image + streamlit_app.image(bytes) + if is_compatible_video(property["mime_type"]): + # Show video + streamlit_app.video(bytes, format=property.get("mime_type")) + return bytes + + def _render_single_string_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + if property.get("default"): + streamlit_kwargs["value"] = property.get("default") + elif property.get("example"): + # TODO: also use example for other property types + # Use example as value 
if it is provided + streamlit_kwargs["value"] = property.get("example") + + if property.get("maxLength") is not None: + streamlit_kwargs["max_chars"] = property.get("maxLength") + + if ( + property.get("format") + or ( + property.get("maxLength") is not None + and int(property.get("maxLength")) < 140 # type: ignore + ) + or property.get("writeOnly") + ): + # If any format is set, use single text input + # If max chars is set to less than 140, use single text input + # If write only -> password field + if property.get("writeOnly"): + streamlit_kwargs["type"] = "password" + return streamlit_app.text_input(**streamlit_kwargs) + else: + # Otherwise use multiline text area + return streamlit_app.text_area(**streamlit_kwargs) + + def _render_multi_enum_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + reference_item = schema_utils.resolve_reference( + property["items"]["$ref"], self._schema_references + ) + # TODO: how to select defaults + return streamlit_app.multiselect( + **streamlit_kwargs, options=reference_item["enum"] + ) + + def _render_single_enum_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + reference_item = schema_utils.get_single_reference_item( + property, self._schema_references + ) + + if property.get("default") is not None: + try: + streamlit_kwargs["index"] = reference_item["enum"].index( + property.get("default") + ) + except Exception: + # Use default selection + pass + + return streamlit_app.selectbox( + **streamlit_kwargs, options=reference_item["enum"] + ) + + def _render_single_dict_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + # Add title and subheader + streamlit_app.subheader(property.get("title")) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + streamlit_app.markdown("---") + + current_dict = self._get_value(key) + if not current_dict: + current_dict = {} + + key_col, value_col = streamlit_app.columns(2) + + with key_col: + updated_key = streamlit_app.text_input( + "Key", value="", key=key + "-new-key" + ) + + with value_col: + # TODO: also add boolean? 
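+            # Explanatory note (added commentary, not in the original source): the widget
+            # for the value column is picked from the schema's additionalProperties type.
+            # For a hypothetical field such as
+            #     scores: Dict[str, float] = Field({}, description="name -> score")
+            # pydantic emits {"additionalProperties": {"type": "number"}}, so a float
+            # number_input is shown; "integer" gets an integer number_input, and any other
+            # value type falls back to a plain text_input.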
+ value_kwargs = {"label": "Value", "key": key + "-new-value"} + if property["additionalProperties"].get("type") == "integer": + value_kwargs["value"] = 0 # type: ignore + updated_value = streamlit_app.number_input(**value_kwargs) + elif property["additionalProperties"].get("type") == "number": + value_kwargs["value"] = 0.0 # type: ignore + value_kwargs["format"] = "%f" + updated_value = streamlit_app.number_input(**value_kwargs) + else: + value_kwargs["value"] = "" + updated_value = streamlit_app.text_input(**value_kwargs) + + streamlit_app.markdown("---") + + with streamlit_app.container(): + clear_col, add_col = streamlit_app.columns([1, 2]) + + with clear_col: + if streamlit_app.button("Clear Items", key=key + "-clear-items"): + current_dict = {} + + with add_col: + if ( + streamlit_app.button("Add Item", key=key + "-add-item") + and updated_key + ): + current_dict[updated_key] = updated_value + + streamlit_app.write(current_dict) + + return current_dict + + def _render_single_reference( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + reference_item = schema_utils.get_single_reference_item( + property, self._schema_references + ) + return self._render_property(streamlit_app, key, reference_item) + + def _render_multi_file_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + file_extension = None + if "mime_type" in property: + file_extension = mimetypes.guess_extension(property["mime_type"]) + + uploaded_files = streamlit_app.file_uploader( + **streamlit_kwargs, accept_multiple_files=True, type=file_extension + ) + uploaded_files_bytes = [] + if uploaded_files: + for uploaded_file in uploaded_files: + uploaded_files_bytes.append(uploaded_file.read()) + return uploaded_files_bytes + + def _render_single_boolean_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + if property.get("default"): + streamlit_kwargs["value"] = property.get("default") + return streamlit_app.checkbox(**streamlit_kwargs) + + def _render_single_number_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + number_transform = int + if property.get("type") == "number": + number_transform = float # type: ignore + streamlit_kwargs["format"] = "%f" + + if "multipleOf" in property: + # Set stepcount based on multiple of parameter + streamlit_kwargs["step"] = number_transform(property["multipleOf"]) + elif number_transform == int: + # Set step size to 1 as default + streamlit_kwargs["step"] = 1 + elif number_transform == float: + # Set step size to 0.01 as default + # TODO: adapt to default value + streamlit_kwargs["step"] = 0.01 + + if "minimum" in property: + streamlit_kwargs["min_value"] = number_transform(property["minimum"]) + if "exclusiveMinimum" in property: + streamlit_kwargs["min_value"] = number_transform( + property["exclusiveMinimum"] + streamlit_kwargs["step"] + ) + if "maximum" in property: + streamlit_kwargs["max_value"] = number_transform(property["maximum"]) + + if "exclusiveMaximum" in property: + streamlit_kwargs["max_value"] = number_transform( + property["exclusiveMaximum"] - streamlit_kwargs["step"] + ) + + if property.get("default") is not None: + streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore + else: + if "min_value" in streamlit_kwargs: + 
streamlit_kwargs["value"] = streamlit_kwargs["min_value"] + elif number_transform == int: + streamlit_kwargs["value"] = 0 + else: + # Set default value to step + streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"]) + + if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs: + # TODO: Only if less than X steps + return streamlit_app.slider(**streamlit_kwargs) + else: + return streamlit_app.number_input(**streamlit_kwargs) + + def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any: + properties = property["properties"] + object_inputs = {} + for property_key in properties: + property = properties[property_key] + if not property.get("title"): + # Set property key as fallback title + property["title"] = name_to_title(property_key) + # construct full key based on key parts -> required later to get the value + full_key = key + "." + property_key + object_inputs[property_key] = self._render_property( + streamlit_app, full_key, property + ) + return object_inputs + + def _render_single_object_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + # Add title and subheader + title = property.get("title") + streamlit_app.subheader(title) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + object_reference = schema_utils.get_single_reference_item( + property, self._schema_references + ) + return self._render_object_input(streamlit_app, key, object_reference) + + def _render_property_list_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + # Add title and subheader + streamlit_app.subheader(property.get("title")) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + streamlit_app.markdown("---") + + current_list = self._get_value(key) + if not current_list: + current_list = [] + + value_kwargs = {"label": "Value", "key": key + "-new-value"} + if property["items"]["type"] == "integer": + value_kwargs["value"] = 0 # type: ignore + new_value = streamlit_app.number_input(**value_kwargs) + elif property["items"]["type"] == "number": + value_kwargs["value"] = 0.0 # type: ignore + value_kwargs["format"] = "%f" + new_value = streamlit_app.number_input(**value_kwargs) + else: + value_kwargs["value"] = "" + new_value = streamlit_app.text_input(**value_kwargs) + + streamlit_app.markdown("---") + + with streamlit_app.container(): + clear_col, add_col = streamlit_app.columns([1, 2]) + + with clear_col: + if streamlit_app.button("Clear Items", key=key + "-clear-items"): + current_list = [] + + with add_col: + if ( + streamlit_app.button("Add Item", key=key + "-add-item") + and new_value is not None + ): + current_list.append(new_value) + + streamlit_app.write(current_list) + + return current_list + + def _render_object_list_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + # TODO: support max_items, and min_items properties + + # Add title and subheader + streamlit_app.subheader(property.get("title")) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + streamlit_app.markdown("---") + + current_list = self._get_value(key) + if not current_list: + current_list = [] + + object_reference = schema_utils.resolve_reference( + property["items"]["$ref"], self._schema_references + ) + input_data = self._render_object_input(streamlit_app, key, object_reference) + + streamlit_app.markdown("---") + + with streamlit_app.container(): + clear_col, add_col = streamlit_app.columns([1, 2]) + 
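+            # Explanatory note (added commentary, not part of the original code): for a list
+            # of models, e.g. a hypothetical field `speakers: List[Speaker]`, the nested
+            # Speaker fields are rendered above via _render_object_input, and "Add Item"
+            # appends the collected dict to current_list, which is then stored back into
+            # the session input_data under this key.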
+ with clear_col: + if streamlit_app.button("Clear Items", key=key + "-clear-items"): + current_list = [] + + with add_col: + if ( + streamlit_app.button("Add Item", key=key + "-add-item") + and input_data + ): + current_list.append(input_data) + + streamlit_app.write(current_list) + return current_list + + def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any: + if schema_utils.is_single_enum_property(property, self._schema_references): + return self._render_single_enum_input(streamlit_app, key, property) + + if schema_utils.is_multi_enum_property(property, self._schema_references): + return self._render_multi_enum_input(streamlit_app, key, property) + + if schema_utils.is_single_file_property(property): + return self._render_single_file_input(streamlit_app, key, property) + + if schema_utils.is_multi_file_property(property): + return self._render_multi_file_input(streamlit_app, key, property) + + if schema_utils.is_single_datetime_property(property): + return self._render_single_datetime_input(streamlit_app, key, property) + + if schema_utils.is_single_boolean_property(property): + return self._render_single_boolean_input(streamlit_app, key, property) + + if schema_utils.is_single_dict_property(property): + return self._render_single_dict_input(streamlit_app, key, property) + + if schema_utils.is_single_number_property(property): + return self._render_single_number_input(streamlit_app, key, property) + + if schema_utils.is_single_string_property(property): + return self._render_single_string_input(streamlit_app, key, property) + + if schema_utils.is_single_object(property, self._schema_references): + return self._render_single_object_input(streamlit_app, key, property) + + if schema_utils.is_object_list_property(property, self._schema_references): + return self._render_object_list_input(streamlit_app, key, property) + + if schema_utils.is_property_list(property): + return self._render_property_list_input(streamlit_app, key, property) + + if schema_utils.is_single_reference(property): + return self._render_single_reference(streamlit_app, key, property) + + streamlit_app.warning( + "The type of the following property is currently not supported: " + + str(property.get("title")) + ) + raise Exception("Unsupported property") + + +class OutputUI: + def __init__(self, output_data: Any, input_data: Any): + self._output_data = output_data + self._input_data = input_data + + def render_ui(self, streamlit_app) -> None: + try: + if isinstance(self._output_data, BaseModel): + self._render_single_output(streamlit_app, self._output_data) + return + if type(self._output_data) == list: + self._render_list_output(streamlit_app, self._output_data) + return + except Exception as ex: + streamlit_app.exception(ex) + # Fallback to + streamlit_app.json(jsonable_encoder(self._output_data)) + + def _render_single_text_property( + self, streamlit: st, property_schema: Dict, value: Any + ) -> None: + # Add title and subheader + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + if value is None or value == "": + streamlit.info("No value returned!") + else: + streamlit.code(str(value), language="plain") + + def _render_single_file_property( + self, streamlit: st, property_schema: Dict, value: Any + ) -> None: + # Add title and subheader + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + if value 
is None or value == "": + streamlit.info("No value returned!") + else: + # TODO: Detect if it is a FileContent instance + # TODO: detect if it is base64 + file_extension = "" + if "mime_type" in property_schema: + mime_type = property_schema["mime_type"] + file_extension = mimetypes.guess_extension(mime_type) or "" + + if is_compatible_audio(mime_type): + streamlit.audio(value.as_bytes(), format=mime_type) + return + + if is_compatible_image(mime_type): + streamlit.image(value.as_bytes()) + return + + if is_compatible_video(mime_type): + streamlit.video(value.as_bytes(), format=mime_type) + return + + filename = ( + (property_schema["title"] + file_extension) + .lower() + .strip() + .replace(" ", "-") + ) + streamlit.markdown( + f'', + unsafe_allow_html=True, + ) + + def _render_single_complex_property( + self, streamlit: st, property_schema: Dict, value: Any + ) -> None: + # Add title and subheader + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + + streamlit.json(jsonable_encoder(value)) + + def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None: + try: + if has_output_ui_renderer(output_data): + if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore + # render method also requests the input data + output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore + else: + output_data.render_output_ui(streamlit) # type: ignore + return + except Exception: + # Use default auto-generation methods if the custom rendering throws an exception + logger.exception( + "Failed to execute custom render_output_ui function. Using auto-generation instead" + ) + + model_schema = output_data.schema(by_alias=False) + model_properties = model_schema.get("properties") + definitions = model_schema.get("definitions") + + if model_properties: + for property_key in output_data.__dict__: + property_schema = model_properties.get(property_key) + if not property_schema.get("title"): + # Set property key as fallback title + property_schema["title"] = property_key + + output_property_value = output_data.__dict__[property_key] + + if has_output_ui_renderer(output_property_value): + output_property_value.render_output_ui(streamlit) # type: ignore + continue + + if isinstance(output_property_value, BaseModel): + # Render output recursivly + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + self._render_single_output(streamlit, output_property_value) + continue + + if property_schema: + if schema_utils.is_single_file_property(property_schema): + self._render_single_file_property( + streamlit, property_schema, output_property_value + ) + continue + + if ( + schema_utils.is_single_string_property(property_schema) + or schema_utils.is_single_number_property(property_schema) + or schema_utils.is_single_datetime_property(property_schema) + or schema_utils.is_single_boolean_property(property_schema) + ): + self._render_single_text_property( + streamlit, property_schema, output_property_value + ) + continue + if definitions and schema_utils.is_single_enum_property( + property_schema, definitions + ): + self._render_single_text_property( + streamlit, property_schema, output_property_value.value + ) + continue + + # TODO: render dict as table + + self._render_single_complex_property( + streamlit, property_schema, output_property_value + ) + return + + def 
_render_list_output(self, streamlit: st, output_data: List) -> None:
+        try:
+            data_items: List = []
+            for data_item in output_data:
+                if has_output_ui_renderer(data_item):
+                    # Render using the render function
+                    data_item.render_output_ui(streamlit)  # type: ignore
+                    continue
+                data_items.append(data_item.dict())
+            # Try to show as dataframe
+            streamlit.table(pd.DataFrame(data_items))
+        except Exception:
+            # Fallback to JSON output
+            streamlit.json(jsonable_encoder(output_data))
+
+
+def getOpyrator(mode: str) -> Opyrator:
+    mode = mode or ""
+    if mode.startswith('VC'):
+        from mkgui.app_vc import convert
+        return Opyrator(convert)
+    if mode.startswith('预处理'):
+        from mkgui.preprocess import preprocess
+        return Opyrator(preprocess)
+    # Check the more specific "模型训练(VC)" prefix before "模型训练",
+    # otherwise the VC training branch can never be reached.
+    if mode.startswith('模型训练(VC)'):
+        from mkgui.train_vc import train_vc
+        return Opyrator(train_vc)
+    if mode.startswith('模型训练'):
+        from mkgui.train import train
+        return Opyrator(train)
+    from mkgui.app import synthesize
+    return Opyrator(synthesize)
+
+
+def render_streamlit_ui() -> None:
+    # init
+    session_state = st.session_state
+    session_state.input_data = {}
+    # Add custom css settings (CUSTOM_STREAMLIT_CSS is defined in
+    # mkgui/base/ui/streamlit_utils.py and assumed to be imported at the top of this file)
+    st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
+
+    with st.spinner("Loading MockingBird GUI. Please wait..."):
+        session_state.mode = st.sidebar.selectbox(
+            '模式选择',
+            ( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
+        )
+        if "mode" in session_state:
+            mode = session_state.mode
+        else:
+            mode = ""
+        opyrator = getOpyrator(mode)
+        title = opyrator.name + mode
+
+        col1, col2, _ = st.columns(3)
+        col2.title(title)
+        col2.markdown("欢迎使用MockingBird Web 2")
+
+        # Use a forward-slash relative path so the logo also loads on non-Windows hosts
+        image = Image.open("mkgui/static/mb.png")
+        col1.image(image)
+
+        st.markdown("---")
+        left, right = st.columns([0.4, 0.6])
+
+        with left:
+            st.header("Control 控制")
+            InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
+            execute_selected = st.button(opyrator.action)
+            if execute_selected:
+                with st.spinner("Executing operation. Please wait..."):
+                    try:
+                        input_data_obj = parse_obj_as(
+                            opyrator.input_type, session_state.input_data
+                        )
+                        session_state.output_data = opyrator(input=input_data_obj)
+                        session_state.latest_operation_input = input_data_obj  # should this really be saved as additional session object?
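+                        # Explanatory note (added commentary, not part of the original code):
+                        # parse_obj_as validates the raw widget values gathered in
+                        # session_state.input_data against the selected opyrator's pydantic
+                        # input model; a missing required field or a type mismatch raises
+                        # ValidationError, which is reported through st.error below.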
+ except ValidationError as ex: + st.error(ex) + else: + # st.success("Operation executed successfully.") + pass + + with right: + st.header("Result 结果") + if 'output_data' in session_state: + OutputUI( + session_state.output_data, session_state.latest_operation_input + ).render_ui(st) + if st.button("Clear"): + # Clear all state + for key in st.session_state.keys(): + del st.session_state[key] + session_state.input_data = {} + st.experimental_rerun() + else: + # placeholder + st.caption("请使用左侧控制板进行输入并运行获得结果") + + diff --git a/mkgui/base/ui/streamlit_utils.py b/mkgui/base/ui/streamlit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..beb6e65c61f8a16b4376494123f31178cdb88bde --- /dev/null +++ b/mkgui/base/ui/streamlit_utils.py @@ -0,0 +1,13 @@ +CUSTOM_STREAMLIT_CSS = """ +div[data-testid="stBlock"] button { + width: 100% !important; + margin-bottom: 20px !important; + border-color: #bfbfbf !important; +} +section[data-testid="stSidebar"] div { + max-width: 10rem; +} +pre code { + white-space: pre-wrap; +} +""" diff --git a/mkgui/preprocess.py b/mkgui/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..91579b6f2d09bc161a73f199b4581a013cccb194 --- /dev/null +++ b/mkgui/preprocess.py @@ -0,0 +1,96 @@ +from pydantic import BaseModel, Field +import os +from pathlib import Path +from enum import Enum +from typing import Any, Tuple + + +# Constants +EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models" +ENC_MODELS_DIRT = f"encoder{os.sep}saved_models" + + +if os.path.isdir(EXT_MODELS_DIRT): + extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded extractor models: " + str(len(extractors))) +else: + raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(ENC_MODELS_DIRT): + encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded encoders models: " + str(len(encoders))) +else: + raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.") + +class Model(str, Enum): + VC_PPG2MEL = "ppg2mel" + +class Dataset(str, Enum): + AIDATATANG_200ZH = "aidatatang_200zh" + AIDATATANG_200ZH_S = "aidatatang_200zh_s" + +class Input(BaseModel): + # def render_input_ui(st, input) -> Dict: + # input["selected_dataset"] = st.selectbox( + # '选择数据集', + # ("aidatatang_200zh", "aidatatang_200zh_s") + # ) + # return input + model: Model = Field( + Model.VC_PPG2MEL, title="目标模型", + ) + dataset: Dataset = Field( + Dataset.AIDATATANG_200ZH, title="数据集选择", + ) + datasets_root: str = Field( + ..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)", + format=True, + example="..\\trainning_data\\" + ) + output_root: str = Field( + ..., alias="输出根目录", description="输出结果根目录(相对/绝对)", + format=True, + example="..\\trainning_data\\" + ) + n_processes: int = Field( + 2, alias="处理线程数", description="根据CPU线程数来设置", + le=32, ge=1 + ) + extractor: extractors = Field( + ..., alias="特征提取模型", + description="选择PPG特征提取模型文件." + ) + encoder: encoders = Field( + ..., alias="语音编码模型", + description="选择语音编码模型文件." + ) + +class AudioEntity(BaseModel): + content: bytes + mel: Any + +class Output(BaseModel): + __root__: Tuple[str, int] + + def render_output_ui(self, streamlit_app, input) -> None: # type: ignore + """Custom output UI. + If this method is implmeneted, it will be used instead of the default Output UI renderer. 
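+        Note (added documentation): OutputUI in mkgui/base/ui/streamlit_ui.py prefers this
+        method over the auto-generated renderer because has_output_ui_renderer() detects it,
+        and since the signature declares an ``input`` argument the original Input object is
+        passed back in as well.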
+ """ + sr, count = self.__root__ + streamlit_app.subheader(f"Dataset {sr} done processed total of {count}") + +def preprocess(input: Input) -> Output: + """Preprocess(预处理)""" + finished = 0 + if input.model == Model.VC_PPG2MEL: + from ppg2mel.preprocess import preprocess_dataset + finished = preprocess_dataset( + datasets_root=Path(input.datasets_root), + dataset=input.dataset, + out_dir=Path(input.output_root), + n_processes=input.n_processes, + ppg_encoder_model_fpath=Path(input.extractor.value), + speaker_encoder_model=Path(input.encoder.value) + ) + # TODO: pass useful return code + return Output(__root__=(input.dataset, finished)) \ No newline at end of file diff --git a/mkgui/static/mb.png b/mkgui/static/mb.png new file mode 100644 index 0000000000000000000000000000000000000000..abd804cab48147cdfafc4a385cf501322bca6e1c Binary files /dev/null and b/mkgui/static/mb.png differ diff --git a/mkgui/train.py b/mkgui/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7104d5469eebcf7046450e08d4a5836f87705c39 --- /dev/null +++ b/mkgui/train.py @@ -0,0 +1,106 @@ +from pydantic import BaseModel, Field +import os +from pathlib import Path +from enum import Enum +from typing import Any +from synthesizer.hparams import hparams +from synthesizer.train import train as synt_train + +# Constants +SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models" +ENC_MODELS_DIRT = f"encoder{os.sep}saved_models" + + +# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models" +# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models" +# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models" + +# Pre-Load models +if os.path.isdir(SYN_MODELS_DIRT): + synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded synthesizer models: " + str(len(synthesizers))) +else: + raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(ENC_MODELS_DIRT): + encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded encoders models: " + str(len(encoders))) +else: + raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.") + +class Model(str, Enum): + DEFAULT = "default" + +class Input(BaseModel): + model: Model = Field( + Model.DEFAULT, title="模型类型", + ) + # datasets_root: str = Field( + # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型", + # format=True, + # example="..\\trainning_data\\" + # ) + input_root: str = Field( + ..., alias="输入目录", description="预处理数据根目录", + format=True, + example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer" + ) + run_id: str = Field( + "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练", + ) + synthesizer: synthesizers = Field( + ..., alias="已有合成模型", + description="选择语音合成模型文件." + ) + gpu: bool = Field( + True, alias="GPU训练", description="选择“是”,则使用GPU训练", + ) + verbose: bool = Field( + True, alias="打印详情", description="选择“是”,输出更多详情", + ) + encoder: encoders = Field( + ..., alias="语音编码模型", + description="选择语音编码模型文件." + ) + save_every: int = Field( + 1000, alias="更新间隔", description="每隔n步则更新一次模型", + ) + backup_every: int = Field( + 10000, alias="保存间隔", description="每隔n步则保存一次模型", + ) + log_every: int = Field( + 500, alias="打印间隔", description="每隔n步则打印一次训练统计", + ) + +class AudioEntity(BaseModel): + content: bytes + mel: Any + +class Output(BaseModel): + __root__: int + + def render_output_ui(self, streamlit_app) -> None: # type: ignore + """Custom output UI. 
+ If this method is implmeneted, it will be used instead of the default Output UI renderer. + """ + streamlit_app.subheader(f"Training started with code: {self.__root__}") + +def train(input: Input) -> Output: + """Train(训练)""" + + print(">>> Start training ...") + force_restart = len(input.run_id) > 0 + if not force_restart: + input.run_id = Path(input.synthesizer.value).name.split('.')[0] + + synt_train( + input.run_id, + input.input_root, + f"synthesizer{os.sep}saved_models", + input.save_every, + input.backup_every, + input.log_every, + force_restart, + hparams + ) + return Output(__root__=0) \ No newline at end of file diff --git a/mkgui/train_vc.py b/mkgui/train_vc.py new file mode 100644 index 0000000000000000000000000000000000000000..8c233724b6b5572903069a1a2c2a9d41dd3f2167 --- /dev/null +++ b/mkgui/train_vc.py @@ -0,0 +1,155 @@ +from pydantic import BaseModel, Field +import os +from pathlib import Path +from enum import Enum +from typing import Any, Tuple +import numpy as np +from utils.load_yaml import HpsYaml +from utils.util import AttrDict +import torch + +# Constants +EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models" +CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models" +ENC_MODELS_DIRT = f"encoder{os.sep}saved_models" + + +if os.path.isdir(EXT_MODELS_DIRT): + extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded extractor models: " + str(len(extractors))) +else: + raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(CONV_MODELS_DIRT): + convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth"))) + print("Loaded convertor models: " + str(len(convertors))) +else: + raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(ENC_MODELS_DIRT): + encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded encoders models: " + str(len(encoders))) +else: + raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.") + +class Model(str, Enum): + VC_PPG2MEL = "ppg2mel" + +class Dataset(str, Enum): + AIDATATANG_200ZH = "aidatatang_200zh" + AIDATATANG_200ZH_S = "aidatatang_200zh_s" + +class Input(BaseModel): + # def render_input_ui(st, input) -> Dict: + # input["selected_dataset"] = st.selectbox( + # '选择数据集', + # ("aidatatang_200zh", "aidatatang_200zh_s") + # ) + # return input + model: Model = Field( + Model.VC_PPG2MEL, title="模型类型", + ) + # datasets_root: str = Field( + # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型", + # format=True, + # example="..\\trainning_data\\" + # ) + output_root: str = Field( + ..., alias="输出目录(可选)", description="建议不填,保持默认", + format=True, + example="" + ) + continue_mode: bool = Field( + True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练", + ) + gpu: bool = Field( + True, alias="GPU训练", description="选择“是”,则使用GPU训练", + ) + verbose: bool = Field( + True, alias="打印详情", description="选择“是”,输出更多详情", + ) + # TODO: Move to hiden fields by default + convertor: convertors = Field( + ..., alias="转换模型", + description="选择语音转换模型文件." + ) + extractor: extractors = Field( + ..., alias="特征提取模型", + description="选择PPG特征提取模型文件." + ) + encoder: encoders = Field( + ..., alias="语音编码模型", + description="选择语音编码模型文件." 
+    )
+    njobs: int = Field(
+        8, alias="进程数", description="适用于ppg2mel",
+    )
+    seed: int = Field(
+        default=0, alias="初始随机数", description="适用于ppg2mel",
+    )
+    model_name: str = Field(
+        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
+        example="test"
+    )
+    model_config: str = Field(
+        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
+        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    )
+
+class AudioEntity(BaseModel):
+    content: bytes
+    mel: Any
+
+class Output(BaseModel):
+    __root__: Tuple[str, int]
+
+    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+        """Custom output UI.
+        If this method is implemented, it will be used instead of the default Output UI renderer.
+        """
+        name, code = self.__root__
+        streamlit_app.subheader(f"Training {name} finished with return code {code}")
+
+def train_vc(input: Input) -> Output:
+    """Train VC(训练 VC)"""
+
+    print(">>> OneShot VC training ...")
+    params = AttrDict()
+    params.update({
+        "gpu": input.gpu,
+        "cpu": not input.gpu,
+        "njobs": input.njobs,
+        "seed": input.seed,
+        "verbose": input.verbose,
+        "load": input.convertor.value,
+        "warm_start": False,
+    })
+    if input.continue_mode:
+        # trace old model and config
+        p = Path(input.convertor.value)
+        params.name = p.parent.name
+        # search a config file
+        model_config_fpaths = list(p.parent.rglob("*.yaml"))
+        if len(model_config_fpaths) == 0:
+            raise Exception("No model yaml config found for convertor")
+        config = HpsYaml(model_config_fpaths[0])
+        params.ckpdir = p.parent.parent
+        params.config = model_config_fpaths[0]
+        params.logdir = os.path.join(p.parent, "log")
+    else:
+        # Make the config dict dot visitable
+        config = HpsYaml(input.model_config)
+    np.random.seed(input.seed)
+    torch.manual_seed(input.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(input.seed)
+    mode = "train"
+    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+    solver = Solver(config, params, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+
+    # TODO: pass useful return code
+    return Output(__root__=(input.model_name, 0))
\ No newline at end of file
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f6347069125d292a246f2e461ac5f7829b297478
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,5 @@
+libasound2-dev
+portaudio19-dev
+libportaudio2
+libportaudiocpp0
+ffmpeg
\ No newline at end of file
diff --git a/ppg2mel/__init__.py b/ppg2mel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc54db831d54907582206146e26cf797828787c3
--- /dev/null
+++ b/ppg2mel/__init__.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+
+# Copyright 2020 Songxiang Liu
+# Apache 2.0
+
+from typing import List
+
+import torch
+import torch.nn.functional as F
+
+import numpy as np
+
+from .utils.abs_model import AbsMelDecoder
+from .rnn_decoder_mol import Decoder
+from .utils.cnn_postnet import Postnet
+from .utils.vc_utils import get_mask_from_lengths
+
+from utils.load_yaml import HpsYaml
+
+class MelDecoderMOLv2(AbsMelDecoder):
+    """Use an encoder to preprocess ppg."""
+    def __init__(
+        self,
+        num_speakers: int,
+        spk_embed_dim: int,
+        bottle_neck_feature_dim: int,
+        encoder_dim: int = 256,
+        encoder_downsample_rates: List = [2, 2],
+        attention_rnn_dim: int = 512,
+        decoder_rnn_dim: int = 512,
+        num_decoder_rnn_layer: int = 1,
+        concat_context_to_last: bool = True,
+        prenet_dims: List = [256, 128],
+        num_mixtures: int = 5,
+
frames_per_step: int = 2, + mask_padding: bool = True, + ): + super().__init__() + + self.mask_padding = mask_padding + self.bottle_neck_feature_dim = bottle_neck_feature_dim + self.num_mels = 80 + self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1] + self.frames_per_step = frames_per_step + self.use_spk_dvec = True + + input_dim = bottle_neck_feature_dim + + # Downsampling convolution + self.bnf_prenet = torch.nn.Sequential( + torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[0], + stride=encoder_downsample_rates[0], + padding=encoder_downsample_rates[0]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[1], + stride=encoder_downsample_rates[1], + padding=encoder_downsample_rates[1]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + ) + decoder_enc_dim = encoder_dim + self.pitch_convs = torch.nn.Sequential( + torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[0], + stride=encoder_downsample_rates[0], + padding=encoder_downsample_rates[0]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[1], + stride=encoder_downsample_rates[1], + padding=encoder_downsample_rates[1]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + ) + + self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim) + + # Decoder + self.decoder = Decoder( + enc_dim=decoder_enc_dim, + num_mels=self.num_mels, + frames_per_step=frames_per_step, + attention_rnn_dim=attention_rnn_dim, + decoder_rnn_dim=decoder_rnn_dim, + num_decoder_rnn_layer=num_decoder_rnn_layer, + prenet_dims=prenet_dims, + num_mixtures=num_mixtures, + use_stop_tokens=True, + concat_context_to_last=concat_context_to_last, + encoder_down_factor=self.encoder_down_factor, + ) + + # Mel-Spec Postnet: some residual CNN layers + self.postnet = Postnet() + + def parse_output(self, outputs, output_lengths=None): + if self.mask_padding and output_lengths is not None: + mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1)) + mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels) + outputs[0].data.masked_fill_(mask, 0.0) + outputs[1].data.masked_fill_(mask, 0.0) + return outputs + + def forward( + self, + bottle_neck_features: torch.Tensor, + feature_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + logf0_uv: torch.Tensor = None, + spembs: torch.Tensor = None, + output_att_ws: bool = False, + ): + decoder_inputs = self.bnf_prenet( + bottle_neck_features.transpose(1, 2) + ).transpose(1, 2) + logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2) + decoder_inputs = decoder_inputs + logf0_uv + + assert spembs is not None + spk_embeds = F.normalize( + spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1) + decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1) + decoder_inputs = self.reduce_proj(decoder_inputs) + + # (B, num_mels, T_dec) + T_dec = 
torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor') + mel_outputs, predicted_stop, alignments = self.decoder( + decoder_inputs, speech, T_dec) + ## Post-processing + mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + if output_att_ws: + return self.parse_output( + [mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths) + else: + return self.parse_output( + [mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths) + + # return mel_outputs, mel_outputs_postnet + + def inference( + self, + bottle_neck_features: torch.Tensor, + logf0_uv: torch.Tensor = None, + spembs: torch.Tensor = None, + ): + decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2) + logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2) + decoder_inputs = decoder_inputs + logf0_uv + + assert spembs is not None + spk_embeds = F.normalize( + spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1) + bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1) + bottle_neck_features = self.reduce_proj(bottle_neck_features) + + ## Decoder + if bottle_neck_features.size(0) > 1: + mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features) + else: + mel_outputs, alignments = self.decoder.inference(bottle_neck_features,) + ## Post-processing + mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + # outputs = mel_outputs_postnet[0] + + return mel_outputs[0], mel_outputs_postnet[0], alignments[0] + +def load_model(model_file, device=None): + # search a config file + model_config_fpaths = list(model_file.parent.rglob("*.yaml")) + if len(model_config_fpaths) == 0: + raise "No model yaml config found for convertor" + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model_config = HpsYaml(model_config_fpaths[0]) + ppg2mel_model = MelDecoderMOLv2( + **model_config["model"] + ).to(device) + ckpt = torch.load(model_file, map_location=device) + ppg2mel_model.load_state_dict(ckpt["model"]) + ppg2mel_model.eval() + return ppg2mel_model diff --git a/ppg2mel/preprocess.py b/ppg2mel/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..0feee6e2458ee770d1b94c53a043b1146b580cef --- /dev/null +++ b/ppg2mel/preprocess.py @@ -0,0 +1,113 @@ + +import os +import torch +import numpy as np +from tqdm import tqdm +from pathlib import Path +import soundfile +import resampy + +from ppg_extractor import load_model +import encoder.inference as Encoder +from encoder.audio import preprocess_wav +from encoder import audio +from utils.f0_utils import compute_f0 + +from torch.multiprocessing import Pool, cpu_count +from functools import partial + +SAMPLE_RATE=16000 + +def _compute_bnf( + wav: any, + output_fpath: str, + device: torch.device, + ppg_model_local: any, +): + """ + Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF). 
+ """ + ppg_model_local.to(device) + wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0) + wav_length = torch.LongTensor([wav.shape[0]]).to(device) + with torch.no_grad(): + bnf = ppg_model_local(wav_tensor, wav_length) + bnf_npy = bnf.squeeze(0).cpu().numpy() + np.save(output_fpath, bnf_npy, allow_pickle=False) + return bnf_npy, len(bnf_npy) + +def _compute_f0_from_wav(wav, output_fpath): + """Compute merged f0 values.""" + f0 = compute_f0(wav, SAMPLE_RATE) + np.save(output_fpath, f0, allow_pickle=False) + return f0, len(f0) + +def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device): + Encoder.set_model(encoder_model_local) + # Compute where to split the utterance into partials and pad if necessary + wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75) + max_wave_length = wave_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = audio.wav_to_mel_spectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + partial_embeds = Encoder.embed_frames_batch(frames_batch) + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + np.save(output_fpath, embed, allow_pickle=False) + return embed, len(embed) + +def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local): + # wav = preprocess_wav(wav_path) + # try: + wav, sr = soundfile.read(wav_path) + if len(wav) < sr: + return None, sr, len(wav) + if sr != SAMPLE_RATE: + wav = resampy.resample(wav, sr, SAMPLE_RATE) + sr = SAMPLE_RATE + utt_id = os.path.basename(wav_path).rstrip(".wav") + + _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local) + _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav) + _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav) + +def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model): + # Glob wav files + wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav")) + print(f"Globbed {len(wav_file_list)} wav files.") + + out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True) + out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True) + out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True) + ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") + encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu") + if n_processes is None: + n_processes = cpu_count() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device) + job = Pool(n_processes).imap(func, wav_file_list) + list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav")) + + # finish processing and mark + t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8") + d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8") + e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8") + for file in sorted(out_dir.joinpath("f0").glob("*.npy")): + id = os.path.basename(file).split(".f0.npy")[0] + if 
id.endswith("01"): + d_fid_file.write(id + "\n") + elif id.endswith("09"): + e_fid_file.write(id + "\n") + else: + t_fid_file.write(id + "\n") + t_fid_file.close() + d_fid_file.close() + e_fid_file.close() + return len(wav_file_list) diff --git a/ppg2mel/rnn_decoder_mol.py b/ppg2mel/rnn_decoder_mol.py new file mode 100644 index 0000000000000000000000000000000000000000..9d48d7bc697baef107818569dc3e87a96708fb00 --- /dev/null +++ b/ppg2mel/rnn_decoder_mol.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .utils.mol_attention import MOLAttention +from .utils.basic_layers import Linear +from .utils.vc_utils import get_mask_from_lengths + + +class DecoderPrenet(nn.Module): + def __init__(self, in_dim, sizes): + super().__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [Linear(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, sizes)]) + + def forward(self, x): + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class Decoder(nn.Module): + """Mixture of Logistic (MoL) attention-based RNN Decoder.""" + def __init__( + self, + enc_dim, + num_mels, + frames_per_step, + attention_rnn_dim, + decoder_rnn_dim, + prenet_dims, + num_mixtures, + encoder_down_factor=1, + num_decoder_rnn_layer=1, + use_stop_tokens=False, + concat_context_to_last=False, + ): + super().__init__() + self.enc_dim = enc_dim + self.encoder_down_factor = encoder_down_factor + self.num_mels = num_mels + self.frames_per_step = frames_per_step + self.attention_rnn_dim = attention_rnn_dim + self.decoder_rnn_dim = decoder_rnn_dim + self.prenet_dims = prenet_dims + self.use_stop_tokens = use_stop_tokens + self.num_decoder_rnn_layer = num_decoder_rnn_layer + self.concat_context_to_last = concat_context_to_last + + # Mel prenet + self.prenet = DecoderPrenet(num_mels, prenet_dims) + self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims) + + # Attention RNN + self.attention_rnn = nn.LSTMCell( + prenet_dims[-1] + enc_dim, + attention_rnn_dim + ) + + # Attention + self.attention_layer = MOLAttention( + attention_rnn_dim, + r=frames_per_step/encoder_down_factor, + M=num_mixtures, + ) + + # Decoder RNN + self.decoder_rnn_layers = nn.ModuleList() + for i in range(num_decoder_rnn_layer): + if i == 0: + self.decoder_rnn_layers.append( + nn.LSTMCell( + enc_dim + attention_rnn_dim, + decoder_rnn_dim)) + else: + self.decoder_rnn_layers.append( + nn.LSTMCell( + decoder_rnn_dim, + decoder_rnn_dim)) + # self.decoder_rnn = nn.LSTMCell( + # 2 * enc_dim + attention_rnn_dim, + # decoder_rnn_dim + # ) + if concat_context_to_last: + self.linear_projection = Linear( + enc_dim + decoder_rnn_dim, + num_mels * frames_per_step + ) + else: + self.linear_projection = Linear( + decoder_rnn_dim, + num_mels * frames_per_step + ) + + + # Stop-token layer + if self.use_stop_tokens: + if concat_context_to_last: + self.stop_layer = Linear( + enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid" + ) + else: + self.stop_layer = Linear( + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid" + ) + + + def get_go_frame(self, memory): + B = memory.size(0) + go_frame = torch.zeros((B, self.num_mels), dtype=torch.float, + device=memory.device) + return go_frame + + def initialize_decoder_states(self, memory, mask): + device = next(self.parameters()).device + B = memory.size(0) + + # attention rnn states + self.attention_hidden = torch.zeros( + (B, self.attention_rnn_dim), device=device) + 
self.attention_cell = torch.zeros( + (B, self.attention_rnn_dim), device=device) + + # decoder rnn states + self.decoder_hiddens = [] + self.decoder_cells = [] + for i in range(self.num_decoder_rnn_layer): + self.decoder_hiddens.append( + torch.zeros((B, self.decoder_rnn_dim), + device=device) + ) + self.decoder_cells.append( + torch.zeros((B, self.decoder_rnn_dim), + device=device) + ) + # self.decoder_hidden = torch.zeros( + # (B, self.decoder_rnn_dim), device=device) + # self.decoder_cell = torch.zeros( + # (B, self.decoder_rnn_dim), device=device) + + self.attention_context = torch.zeros( + (B, self.enc_dim), device=device) + + self.memory = memory + # self.processed_memory = self.attention_layer.memory_layer(memory) + self.mask = mask + + def parse_decoder_inputs(self, decoder_inputs): + """Prepare decoder inputs, i.e. gt mel + Args: + decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training. + """ + decoder_inputs = decoder_inputs.reshape( + decoder_inputs.size(0), + int(decoder_inputs.size(1)/self.frames_per_step), -1) + # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels) + decoder_inputs = decoder_inputs.transpose(0, 1) + # (T_out//r, B, num_mels) + decoder_inputs = decoder_inputs[:,:,-self.num_mels:] + return decoder_inputs + + def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs): + """ Prepares decoder outputs for output + Args: + mel_outputs: + alignments: + """ + # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc) + alignments = torch.stack(alignments).transpose(0, 1) + # (T_out//r, B) -> (B, T_out//r) + if stop_outputs is not None: + if alignments.size(0) == 1: + stop_outputs = torch.stack(stop_outputs).unsqueeze(0) + else: + stop_outputs = torch.stack(stop_outputs).transpose(0, 1) + stop_outputs = stop_outputs.contiguous() + # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r) + mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() + # decouple frames per step + # (B, T_out, num_mels) + mel_outputs = mel_outputs.view( + mel_outputs.size(0), -1, self.num_mels) + return mel_outputs, alignments, stop_outputs + + def attend(self, decoder_input): + cell_input = torch.cat((decoder_input, self.attention_context), -1) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_context, attention_weights = self.attention_layer( + self.attention_hidden, self.memory, None, self.mask) + + decoder_rnn_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) + + return decoder_rnn_input, self.attention_context, attention_weights + + def decode(self, decoder_input): + for i in range(self.num_decoder_rnn_layer): + if i == 0: + self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i]( + decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i])) + else: + self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i]( + self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i])) + return self.decoder_hiddens[-1] + + def forward(self, memory, mel_inputs, memory_lengths): + """ Decoder forward pass for training + Args: + memory: (B, T_enc, enc_dim) Encoder outputs + decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing. + memory_lengths: (B, ) Encoder output lengths for attention masking. + Returns: + mel_outputs: (B, T, num_mels) mel outputs from the decoder + alignments: (B, T//r, T_enc) attention weights. 
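+            stop_outputs: (B, T//r) stop-token logits, additionally returned (between
+                mel_outputs and alignments) when use_stop_tokens is enabled.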
+ """ + # [1, B, num_mels] + go_frame = self.get_go_frame(memory).unsqueeze(0) + # [T//r, B, num_mels] + mel_inputs = self.parse_decoder_inputs(mel_inputs) + # [T//r + 1, B, num_mels] + mel_inputs = torch.cat((go_frame, mel_inputs), dim=0) + # [T//r + 1, B, prenet_dim] + decoder_inputs = self.prenet(mel_inputs) + # decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__) + + self.initialize_decoder_states( + memory, mask=~get_mask_from_lengths(memory_lengths), + ) + + self.attention_layer.init_states(memory) + # self.attention_layer_pitch.init_states(memory_pitch) + + mel_outputs, alignments = [], [] + if self.use_stop_tokens: + stop_outputs = [] + else: + stop_outputs = None + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + # decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)] + + decoder_rnn_input, context, attention_weights = self.attend(decoder_input) + + decoder_rnn_output = self.decode(decoder_rnn_input) + if self.concat_context_to_last: + decoder_rnn_output = torch.cat( + (decoder_rnn_output, context), dim=1) + + mel_output = self.linear_projection(decoder_rnn_output) + if self.use_stop_tokens: + stop_output = self.stop_layer(decoder_rnn_output) + stop_outputs += [stop_output.squeeze()] + mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze + alignments += [attention_weights] + # alignments_pitch += [attention_weights_pitch] + + mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs( + mel_outputs, alignments, stop_outputs) + if stop_outputs is None: + return mel_outputs, alignments + else: + return mel_outputs, stop_outputs, alignments + + def inference(self, memory, stop_threshold=0.5): + """ Decoder inference + Args: + memory: (1, T_enc, D_enc) Encoder outputs + Returns: + mel_outputs: mel outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + # [1, num_mels] + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, mask=None) + + self.attention_layer.init_states(memory) + + mel_outputs, alignments = [], [] + # NOTE(sx): heuristic + max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step + min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5 + while True: + decoder_input = self.prenet(decoder_input) + + decoder_input_final, context, alignment = self.attend(decoder_input) + + #mel_output, stop_output, alignment = self.decode(decoder_input) + decoder_rnn_output = self.decode(decoder_input_final) + if self.concat_context_to_last: + decoder_rnn_output = torch.cat( + (decoder_rnn_output, context), dim=1) + + mel_output = self.linear_projection(decoder_rnn_output) + stop_output = self.stop_layer(decoder_rnn_output) + + mel_outputs += [mel_output.squeeze(1)] + alignments += [alignment] + + if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step: + break + if len(mel_outputs) >= max_decoder_step: + # print("Warning! 
Decoding steps reaches max decoder steps.") + break + + decoder_input = mel_output[:,-self.num_mels:] + + + mel_outputs, alignments, _ = self.parse_decoder_outputs( + mel_outputs, alignments, None) + + return mel_outputs, alignments + + def inference_batched(self, memory, stop_threshold=0.5): + """ Decoder inference + Args: + memory: (B, T_enc, D_enc) Encoder outputs + Returns: + mel_outputs: mel outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + # [1, num_mels] + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, mask=None) + + self.attention_layer.init_states(memory) + + mel_outputs, alignments = [], [] + stop_outputs = [] + # NOTE(sx): heuristic + max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step + min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5 + while True: + decoder_input = self.prenet(decoder_input) + + decoder_input_final, context, alignment = self.attend(decoder_input) + + #mel_output, stop_output, alignment = self.decode(decoder_input) + decoder_rnn_output = self.decode(decoder_input_final) + if self.concat_context_to_last: + decoder_rnn_output = torch.cat( + (decoder_rnn_output, context), dim=1) + + mel_output = self.linear_projection(decoder_rnn_output) + # (B, 1) + stop_output = self.stop_layer(decoder_rnn_output) + stop_outputs += [stop_output.squeeze()] + # stop_outputs.append(stop_output) + + mel_outputs += [mel_output.squeeze(1)] + alignments += [alignment] + # print(stop_output.shape) + if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \ + and len(mel_outputs) >= min_decoder_step: + break + if len(mel_outputs) >= max_decoder_step: + # print("Warning! Decoding steps reaches max decoder steps.") + break + + decoder_input = mel_output[:,-self.num_mels:] + + + mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs( + mel_outputs, alignments, stop_outputs) + mel_outputs_stacked = [] + for mel, stop_logit in zip(mel_outputs, stop_outputs): + idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item() + mel_outputs_stacked.append(mel[:idx,:]) + mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0) + return mel_outputs, alignments diff --git a/ppg2mel/train.py b/ppg2mel/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ef729075a837a680175559ecbdde0b398a73a9 --- /dev/null +++ b/ppg2mel/train.py @@ -0,0 +1,62 @@ +import sys +import torch +import argparse +import numpy as np +from utils.load_yaml import HpsYaml +from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + +# For reproducibility, comment these may speed up training +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +def main(): + # Arguments + parser = argparse.ArgumentParser(description= + 'Training PPG2Mel VC model.') + parser.add_argument('--config', type=str, + help='Path to experiment config, e.g., config/vc.yaml') + parser.add_argument('--name', default=None, type=str, help='Name for logging.') + parser.add_argument('--logdir', default='log/', type=str, + help='Logging path.', required=False) + parser.add_argument('--ckpdir', default='ckpt/', type=str, + help='Checkpoint path.', required=False) + parser.add_argument('--outdir', default='result/', type=str, + help='Decode output path.', required=False) + parser.add_argument('--load', default=None, type=str, + help='Load pre-trained model (for training only)', required=False) + 
parser.add_argument('--warm_start', action='store_true', + help='Load model weights only, ignore specified layers.') + parser.add_argument('--seed', default=0, type=int, + help='Random seed for reproducable results.', required=False) + parser.add_argument('--njobs', default=8, type=int, + help='Number of threads for dataloader/decoding.', required=False) + parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') + # parser.add_argument('--no-pin', action='store_true', + # help='Disable pin-memory for dataloader') + parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') + + ### + + paras = parser.parse_args() + setattr(paras, 'gpu', not paras.cpu) + setattr(paras, 'pin_memory', not paras.no_pin) + setattr(paras, 'verbose', not paras.no_msg) + # Make the config dict dot visitable + config = HpsYaml(paras.config) + + np.random.seed(paras.seed) + torch.manual_seed(paras.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(paras.seed) + + print(">>> OneShot VC training ...") + mode = "train" + solver = Solver(config, paras, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/ppg2mel/train/__init__.py b/ppg2mel/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4287ca8617970fa8fc025b75cb319c7032706910 --- /dev/null +++ b/ppg2mel/train/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/ppg2mel/train/loss.py b/ppg2mel/train/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..301248cc1ef24c549499e10396ae6c3afab3ba09 --- /dev/null +++ b/ppg2mel/train/loss.py @@ -0,0 +1,50 @@ +from typing import Dict +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils.nets_utils import make_pad_mask + + +class MaskedMSELoss(nn.Module): + def __init__(self, frames_per_step): + super().__init__() + self.frames_per_step = frames_per_step + self.mel_loss_criterion = nn.MSELoss(reduction='none') + # self.loss = nn.MSELoss() + self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none') + + def get_mask(self, lengths, max_len=None): + # lengths: [B,] + if max_len is None: + max_len = torch.max(lengths) + batch_size = lengths.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device) + seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand) + return (seq_range_expand < seq_length_expand).float() + + def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths, + stop_target, stop_pred): + ## process stop_target + B = stop_target.size(0) + stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0] + stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long() + stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step)) + + mel_trg.requires_grad = False + # (B, T, 1) + mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1) + # (B, T, D) + mel_mask = mel_mask.expand_as(mel_trg) + mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum() + mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum() + + mel_loss = mel_loss_pre + mel_loss_post + + # stop token loss + stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum() + + 
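# mel_loss adds the masked pre-net and post-net MSE terms (each averaged
+ # over valid frames only); stop_loss is BCE-with-logits on one stop target
+ # per decoder step, masked by the matching per-step mask.
+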
return mel_loss, stop_loss diff --git a/ppg2mel/train/optim.py b/ppg2mel/train/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..62533b95864019df1eca855287cc0bcdb53745d4 --- /dev/null +++ b/ppg2mel/train/optim.py @@ -0,0 +1,45 @@ +import torch +import numpy as np + + +class Optimizer(): + def __init__(self, parameters, optimizer, lr, eps, lr_scheduler, + **kwargs): + + # Setup torch optimizer + self.opt_type = optimizer + self.init_lr = lr + self.sch_type = lr_scheduler + opt = getattr(torch.optim, optimizer) + if lr_scheduler == 'warmup': + warmup_step = 4000.0 + init_lr = lr + self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \ + np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5) + self.opt = opt(parameters, lr=1.0) + else: + self.lr_scheduler = None + self.opt = opt(parameters, lr=lr, eps=eps) # ToDo: 1e-8 better? + + def get_opt_state_dict(self): + return self.opt.state_dict() + + def load_opt_state_dict(self, state_dict): + self.opt.load_state_dict(state_dict) + + def pre_step(self, step): + if self.lr_scheduler is not None: + cur_lr = self.lr_scheduler(step) + for param_group in self.opt.param_groups: + param_group['lr'] = cur_lr + else: + cur_lr = self.init_lr + self.opt.zero_grad() + return cur_lr + + def step(self): + self.opt.step() + + def create_msg(self): + return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})' + .format(self.opt_type, self.init_lr, self.sch_type)] diff --git a/ppg2mel/train/option.py b/ppg2mel/train/option.py new file mode 100644 index 0000000000000000000000000000000000000000..f66c600b84e0404c7937bacf8653776ce9be74c0 --- /dev/null +++ b/ppg2mel/train/option.py @@ -0,0 +1,10 @@ +# Default parameters which will be imported by solver +default_hparas = { + 'GRAD_CLIP': 5.0, # Grad. clip threshold + 'PROGRESS_STEP': 100, # Std. output refresh freq. 
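+ # These defaults are attached to the Solver as attributes in
+ # BaseSolver.__init__; GRAD_CLIP is used by BaseSolver.backward().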
+ # Decode steps for objective validation (step = ratio*input_txt_len) + 'DEV_STEP_RATIO': 1.2, + # Number of examples (alignment/text) to show in tensorboard + 'DEV_N_EXAMPLE': 4, + 'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs) +} diff --git a/ppg2mel/train/solver.py b/ppg2mel/train/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca71cbf2a6b621fa299245f831d4d723ba56977 --- /dev/null +++ b/ppg2mel/train/solver.py @@ -0,0 +1,217 @@ +import os +import sys +import abc +import math +import yaml +import torch +from torch.utils.tensorboard import SummaryWriter + +from .option import default_hparas +from utils.util import human_format, Timer +from utils.load_yaml import HpsYaml + + +class BaseSolver(): + ''' + Prototype Solver for all kinds of tasks + Arguments + config - yaml-styled config + paras - argparse outcome + mode - "train"/"test" + ''' + + def __init__(self, config, paras, mode="train"): + # General Settings + self.config = config # load from yaml file + self.paras = paras # command line args + self.mode = mode # 'train' or 'test' + for k, v in default_hparas.items(): + setattr(self, k, v) + self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \ + else torch.device('cpu') + + # Name experiment + self.exp_name = paras.name + if self.exp_name is None: + if 'exp_name' in self.config: + self.exp_name = self.config.exp_name + else: + # By default, exp is named after config file + self.exp_name = paras.config.split('/')[-1].replace('.yaml', '') + if mode == 'train': + self.exp_name += '_seed{}'.format(paras.seed) + + + if mode == 'train': + # Filepath setup + os.makedirs(paras.ckpdir, exist_ok=True) + self.ckpdir = os.path.join(paras.ckpdir, self.exp_name) + os.makedirs(self.ckpdir, exist_ok=True) + + # Logger settings + self.logdir = os.path.join(paras.logdir, self.exp_name) + self.log = SummaryWriter( + self.logdir, flush_secs=self.TB_FLUSH_FREQ) + self.timer = Timer() + + # Hyper-parameters + self.step = 0 + self.valid_step = config.hparas.valid_step + self.max_step = config.hparas.max_step + + self.verbose('Exp. name : {}'.format(self.exp_name)) + self.verbose('Loading data... large corpus may took a while.') + + # elif mode == 'test': + # # Output path + # os.makedirs(paras.outdir, exist_ok=True) + # self.ckpdir = os.path.join(paras.outdir, self.exp_name) + + # Load training config to get acoustic feat and build model + # self.src_config = HpsYaml(config.src.config) + # self.paras.load = config.src.ckpt + + # self.verbose('Evaluating result of tr. 
config @ {}'.format( + # config.src.config)) + + def backward(self, loss): + ''' + Standard backward step with self.timer and debugger + Arguments + loss - the loss to perform loss.backward() + ''' + self.timer.set() + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.GRAD_CLIP) + if math.isnan(grad_norm): + self.verbose('Error : grad norm is NaN @ step '+str(self.step)) + else: + self.optimizer.step() + self.timer.cnt('bw') + return grad_norm + + def load_ckpt(self): + ''' Load ckpt if --load option is specified ''' + print(self.paras) + if self.paras.load is not None: + if self.paras.warm_start: + self.verbose(f"Warm starting model from checkpoint {self.paras.load}.") + ckpt = torch.load( + self.paras.load, map_location=self.device if self.mode == 'train' + else 'cpu') + model_dict = ckpt['model'] + if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0: + model_dict = {k:v for k, v in model_dict.items() + if k not in self.config.model.ignore_layers} + dummy_dict = self.model.state_dict() + dummy_dict.update(model_dict) + model_dict = dummy_dict + self.model.load_state_dict(model_dict) + else: + # Load weights + ckpt = torch.load( + self.paras.load, map_location=self.device if self.mode == 'train' + else 'cpu') + self.model.load_state_dict(ckpt['model']) + + # Load task-dependent items + if self.mode == 'train': + self.step = ckpt['global_step'] + self.optimizer.load_opt_state_dict(ckpt['optimizer']) + self.verbose('Load ckpt from {}, restarting at step {}'.format( + self.paras.load, self.step)) + else: + for k, v in ckpt.items(): + if type(v) is float: + metric, score = k, v + self.model.eval() + self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format( + self.paras.load, metric, score)) + + def verbose(self, msg): + ''' Verbose function for print information to stdout''' + if self.paras.verbose: + if type(msg) == list: + for m in msg: + print('[INFO]', m.ljust(100)) + else: + print('[INFO]', msg.ljust(100)) + + def progress(self, msg): + ''' Verbose function for updating progress on stdout (do not include newline) ''' + if self.paras.verbose: + sys.stdout.write("\033[K") # Clear line + print('[{}] {}'.format(human_format(self.step), msg), end='\r') + + def write_log(self, log_name, log_dict): + ''' + Write log to TensorBoard + log_name - Name of tensorboard variable + log_value - / Value of variable (e.g. dict of losses), passed if value = None + ''' + if type(log_dict) is dict: + log_dict = {key: val for key, val in log_dict.items() if ( + val is not None and not math.isnan(val))} + if log_dict is None: + pass + elif len(log_dict) > 0: + if 'align' in log_name or 'spec' in log_name: + img, form = log_dict + self.log.add_image( + log_name, img, global_step=self.step, dataformats=form) + elif 'text' in log_name or 'hyp' in log_name: + self.log.add_text(log_name, log_dict, self.step) + else: + self.log.add_scalars(log_name, log_dict, self.step) + + def save_checkpoint(self, f_name, metric, score, show_msg=True): + '''' + Ckpt saver + f_name - the name of ckpt file (w/o prefix) to store, overwrite if existed + score - The value of metric used to evaluate model + ''' + ckpt_path = os.path.join(self.ckpdir, f_name) + full_dict = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.get_opt_state_dict(), + "global_step": self.step, + metric: score + } + + torch.save(full_dict, ckpt_path) + if show_msg: + self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}". 
+ format(human_format(self.step), metric, score, ckpt_path)) + + + # ----------------------------------- Abtract Methods ------------------------------------------ # + @abc.abstractmethod + def load_data(self): + ''' + Called by main to load all data + After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set) + No return value + ''' + raise NotImplementedError + + @abc.abstractmethod + def set_model(self): + ''' + Called by main to set models + After this call, model related attributes should be setup (e.g. self.l2_loss) + The followings MUST be setup + - self.model (torch.nn.Module) + - self.optimizer (src.Optimizer), + init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas']) + Loading pre-trained model should also be performed here + No return value + ''' + raise NotImplementedError + + @abc.abstractmethod + def exec(self): + ''' + Called by main to execute training/inference + ''' + raise NotImplementedError diff --git a/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py b/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py new file mode 100644 index 0000000000000000000000000000000000000000..daf1c6a00d7fe9d0e7ef319b980f92a07bbd6774 --- /dev/null +++ b/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py @@ -0,0 +1,288 @@ +import os, sys +# sys.path.append('/home/shaunxliu/projects/nnsp') +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.ticker import MaxNLocator +import torch +from torch.utils.data import DataLoader +import numpy as np +from .solver import BaseSolver +from utils.data_load import OneshotVcDataset, MultiSpkVcCollate +# from src.rnn_ppg2mel import BiRnnPpg2MelModel +# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL +from .loss import MaskedMSELoss +from .optim import Optimizer +from utils.util import human_format +from ppg2mel import MelDecoderMOLv2 + + +class Solver(BaseSolver): + """Customized Solver.""" + def __init__(self, config, paras, mode): + super().__init__(config, paras, mode) + self.num_att_plots = 5 + self.att_ws_dir = f"{self.logdir}/att_ws" + os.makedirs(self.att_ws_dir, exist_ok=True) + self.best_loss = np.inf + + def fetch_data(self, data): + """Move data to device""" + data = [i.to(self.device) for i in data] + return data + + def load_data(self): + """ Load data for training/validation/plotting.""" + train_dataset = OneshotVcDataset( + meta_file=self.config.data.train_fid_list, + vctk_ppg_dir=self.config.data.vctk_ppg_dir, + libri_ppg_dir=self.config.data.libri_ppg_dir, + vctk_f0_dir=self.config.data.vctk_f0_dir, + libri_f0_dir=self.config.data.libri_f0_dir, + vctk_wav_dir=self.config.data.vctk_wav_dir, + libri_wav_dir=self.config.data.libri_wav_dir, + vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir, + libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir, + ppg_file_ext=self.config.data.ppg_file_ext, + min_max_norm_mel=self.config.data.min_max_norm_mel, + mel_min=self.config.data.mel_min, + mel_max=self.config.data.mel_max, + ) + dev_dataset = OneshotVcDataset( + meta_file=self.config.data.dev_fid_list, + vctk_ppg_dir=self.config.data.vctk_ppg_dir, + libri_ppg_dir=self.config.data.libri_ppg_dir, + vctk_f0_dir=self.config.data.vctk_f0_dir, + libri_f0_dir=self.config.data.libri_f0_dir, + vctk_wav_dir=self.config.data.vctk_wav_dir, + libri_wav_dir=self.config.data.libri_wav_dir, + vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir, + libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir, + ppg_file_ext=self.config.data.ppg_file_ext, 
+ min_max_norm_mel=self.config.data.min_max_norm_mel, + mel_min=self.config.data.mel_min, + mel_max=self.config.data.mel_max, + ) + self.train_dataloader = DataLoader( + train_dataset, + num_workers=self.paras.njobs, + shuffle=True, + batch_size=self.config.hparas.batch_size, + pin_memory=False, + drop_last=True, + collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step, + use_spk_dvec=True), + ) + self.dev_dataloader = DataLoader( + dev_dataset, + num_workers=self.paras.njobs, + shuffle=False, + batch_size=self.config.hparas.batch_size, + pin_memory=False, + drop_last=False, + collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step, + use_spk_dvec=True), + ) + self.plot_dataloader = DataLoader( + dev_dataset, + num_workers=self.paras.njobs, + shuffle=False, + batch_size=1, + pin_memory=False, + drop_last=False, + collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step, + use_spk_dvec=True, + give_uttids=True), + ) + msg = "Have prepared training set and dev set." + self.verbose(msg) + + def load_pretrained_params(self): + print("Load pretrained model from: ", self.config.data.pretrain_model_file) + ignore_layer_prefixes = ["speaker_embedding_table"] + pretrain_model_file = self.config.data.pretrain_model_file + pretrain_ckpt = torch.load( + pretrain_model_file, map_location=self.device + )["model"] + model_dict = self.model.state_dict() + print(self.model) + + # 1. filter out unnecessrary keys + for prefix in ignore_layer_prefixes: + pretrain_ckpt = {k : v + for k, v in pretrain_ckpt.items() if not k.startswith(prefix) + } + # 2. overwrite entries in the existing state dict + model_dict.update(pretrain_ckpt) + + # 3. load the new state dict + self.model.load_state_dict(model_dict) + + def set_model(self): + """Setup model and optimizer""" + # Model + print("[INFO] Model name: ", self.config["model_name"]) + self.model = MelDecoderMOLv2( + **self.config["model"] + ).to(self.device) + # self.load_pretrained_params() + + # model_params = [{'params': self.model.spk_embedding.weight}] + model_params = [{'params': self.model.parameters()}] + + # Loss criterion + self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step) + + # Optimizer + self.optimizer = Optimizer(model_params, **self.config["hparas"]) + self.verbose(self.optimizer.create_msg()) + + # Automatically load pre-trained model if self.paras.load is given + self.load_ckpt() + + def exec(self): + self.verbose("Total training steps {}.".format( + human_format(self.max_step))) + + mel_loss = None + n_epochs = 0 + # Set as current time + self.timer.set() + + while self.step < self.max_step: + for data in self.train_dataloader: + # Pre-step: updata lr_rate and do zero_grad + lr_rate = self.optimizer.pre_step(self.step) + total_loss = 0 + # data to device + ppgs, lf0_uvs, mels, in_lengths, \ + out_lengths, spk_ids, stop_tokens = self.fetch_data(data) + self.timer.cnt("rd") + mel_outputs, mel_outputs_postnet, predicted_stop = self.model( + ppgs, + in_lengths, + mels, + out_lengths, + lf0_uvs, + spk_ids + ) + mel_loss, stop_loss = self.loss_criterion( + mel_outputs, + mel_outputs_postnet, + mels, + out_lengths, + stop_tokens, + predicted_stop + ) + loss = mel_loss + stop_loss + + self.timer.cnt("fw") + + # Back-prop + grad_norm = self.backward(loss) + self.step += 1 + + # Logger + if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0): + self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}" + .format(loss.cpu().item(), mel_loss.cpu().item(), + stop_loss.cpu().item(), 
grad_norm, self.timer.show())) + self.write_log('loss', {'tr/loss': loss, + 'tr/mel-loss': mel_loss, + 'tr/stop-loss': stop_loss}) + + # Validation + if (self.step == 1) or (self.step % self.valid_step == 0): + self.validate() + + # End of step + # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354 + torch.cuda.empty_cache() + self.timer.set() + if self.step > self.max_step: + break + n_epochs += 1 + self.log.close() + + def validate(self): + self.model.eval() + dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0 + + for i, data in enumerate(self.dev_dataloader): + self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader))) + # Fetch data + ppgs, lf0_uvs, mels, in_lengths, \ + out_lengths, spk_ids, stop_tokens = self.fetch_data(data) + with torch.no_grad(): + mel_outputs, mel_outputs_postnet, predicted_stop = self.model( + ppgs, + in_lengths, + mels, + out_lengths, + lf0_uvs, + spk_ids + ) + mel_loss, stop_loss = self.loss_criterion( + mel_outputs, + mel_outputs_postnet, + mels, + out_lengths, + stop_tokens, + predicted_stop + ) + loss = mel_loss + stop_loss + + dev_loss += loss.cpu().item() + dev_mel_loss += mel_loss.cpu().item() + dev_stop_loss += stop_loss.cpu().item() + + dev_loss = dev_loss / (i + 1) + dev_mel_loss = dev_mel_loss / (i + 1) + dev_stop_loss = dev_stop_loss / (i + 1) + self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False) + if dev_loss < self.best_loss: + self.best_loss = dev_loss + self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss) + self.write_log('loss', {'dv/loss': dev_loss, + 'dv/mel-loss': dev_mel_loss, + 'dv/stop-loss': dev_stop_loss}) + + # plot attention + for i, data in enumerate(self.plot_dataloader): + if i == self.num_att_plots: + break + # Fetch data + ppgs, lf0_uvs, mels, in_lengths, \ + out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1]) + fid = data[-1][0] + with torch.no_grad(): + _, _, _, att_ws = self.model( + ppgs, + in_lengths, + mels, + out_lengths, + lf0_uvs, + spk_ids, + output_att_ws=True + ) + att_ws = att_ws.squeeze(0).cpu().numpy() + att_ws = att_ws[None] + w, h = plt.figaspect(1.0 / len(att_ws)) + fig = plt.Figure(figsize=(w * 1.3, h * 1.3)) + axes = fig.subplots(1, len(att_ws)) + if len(att_ws) == 1: + axes = [axes] + + for ax, aw in zip(axes, att_ws): + ax.imshow(aw.astype(np.float32), aspect="auto") + ax.set_title(f"{fid}") + ax.set_xlabel("Input") + ax.set_ylabel("Output") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png" + fig.savefig(fig_name) + + # Resume training + self.model.train() + diff --git a/ppg2mel/utils/abs_model.py b/ppg2mel/utils/abs_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d27a6df74c6988dd4355cbef149ed90f3a36cf --- /dev/null +++ b/ppg2mel/utils/abs_model.py @@ -0,0 +1,23 @@ +from abc import ABC +from abc import abstractmethod + +import torch + +class AbsMelDecoder(torch.nn.Module, ABC): + """The abstract PPG-based voice conversion class + This "model" is one of mediator objects for "Task" class. 
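+ Concrete decoders implement forward(), mapping PPG features (with
+ optional log-F0/UV and speaker/style embeddings) to mel-spectrogram
+ frames.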
+ + """ + + @abstractmethod + def forward( + self, + bottle_neck_features: torch.Tensor, + feature_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + logf0_uv: torch.Tensor = None, + spembs: torch.Tensor = None, + styleembs: torch.Tensor = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/ppg2mel/utils/basic_layers.py b/ppg2mel/utils/basic_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..45d80f1ef9e459a6e2d8494cf8d4ca1e599f772f --- /dev/null +++ b/ppg2mel/utils/basic_layers.py @@ -0,0 +1,79 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x + +class Linear(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(Linear, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + +class Conv1d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): + super(Conv1d, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1)/2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) + + def forward(self, x): + # x: BxDxT + return self.conv(x) + + + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. 
+ """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x diff --git a/ppg2mel/utils/cnn_postnet.py b/ppg2mel/utils/cnn_postnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1980cdd8421838e48fc8a977731054beb5eb8cc6 --- /dev/null +++ b/ppg2mel/utils/cnn_postnet.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .basic_layers import Linear, Conv1d + + +class Postnet(nn.Module): + """Postnet + - Five 1-d convolution with 512 channels and kernel size 5 + """ + def __init__(self, num_mels=80, + num_layers=5, + hidden_dim=512, + kernel_size=5): + super(Postnet, self).__init__() + self.convolutions = nn.ModuleList() + + self.convolutions.append( + nn.Sequential( + Conv1d( + num_mels, hidden_dim, + kernel_size=kernel_size, stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hidden_dim))) + + for i in range(1, num_layers - 1): + self.convolutions.append( + nn.Sequential( + Conv1d( + hidden_dim, + hidden_dim, + kernel_size=kernel_size, stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hidden_dim))) + + self.convolutions.append( + nn.Sequential( + Conv1d( + hidden_dim, num_mels, + kernel_size=kernel_size, stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, w_init_gain='linear'), + nn.BatchNorm1d(num_mels))) + + def forward(self, x): + # x: (B, num_mels, T_dec) + for i in range(len(self.convolutions) - 1): + x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) + x = F.dropout(self.convolutions[-1](x), 0.5, self.training) + return x diff --git a/ppg2mel/utils/mol_attention.py b/ppg2mel/utils/mol_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa91f8a4d3878efe8316798df9b87995a2fff4b --- /dev/null +++ b/ppg2mel/utils/mol_attention.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MOLAttention(nn.Module): + """ Discretized Mixture of Logistic (MOL) attention. + C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and + GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis". + """ + def __init__( + self, + query_dim, + r=1, + M=5, + ): + """ + Args: + query_dim: attention_rnn_dim. + M: number of mixtures. 
+ """ + super().__init__() + if r < 1: + self.r = float(r) + else: + self.r = int(r) + self.M = M + self.score_mask_value = 0.0 # -float("inf") + self.eps = 1e-5 + # Position arrary for encoder time steps + self.J = None + # Query layer: [w, sigma,] + self.query_layer = torch.nn.Sequential( + nn.Linear(query_dim, 256, bias=True), + nn.ReLU(), + nn.Linear(256, 3*M, bias=True) + ) + self.mu_prev = None + self.initialize_bias() + + def initialize_bias(self): + """Initialize sigma and Delta.""" + # sigma + torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0) + # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0 + # softplus(-0.432) = 0.5003 + if self.r == 2: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545) + elif self.r == 4: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815) + elif self.r == 1: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413) + else: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432) + + + def init_states(self, memory): + """Initialize mu_prev and J. + This function should be called by the decoder before decoding one batch. + Args: + memory: (B, T, D_enc) encoder output. + """ + B, T_enc, _ = memory.size() + device = memory.device + self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage + # self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float) + self.mu_prev = torch.zeros(B, self.M).to(device) + + def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None): + """ + att_rnn_h: attetion rnn hidden state. + memory: encoder outputs (B, T_enc, D). + mask: binary mask for padded data (B, T_enc). + """ + # [B, 3M] + mixture_params = self.query_layer(att_rnn_h) + + # [B, M] + w_hat = mixture_params[:, :self.M] + sigma_hat = mixture_params[:, self.M:2*self.M] + Delta_hat = mixture_params[:, 2*self.M:3*self.M] + + # print("w_hat: ", w_hat) + # print("sigma_hat: ", sigma_hat) + # print("Delta_hat: ", Delta_hat) + + # Dropout to de-correlate attention heads + w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed? 
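+ # The block below follows the GMMv2b-style recipe cited in the class
+ # docstring: mixture weights via softmax, scales via softplus, and a
+ # strictly forward-moving mean mu_cur = mu_prev + softplus(Delta_hat);
+ # alpha_t is then obtained by differencing the accumulated phi_t values
+ # on the half-integer grid J, i.e. a discretized attention distribution.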
+ + # Mixture parameters + w = torch.softmax(w_hat, dim=-1) + self.eps + sigma = F.softplus(sigma_hat) + self.eps + Delta = F.softplus(Delta_hat) + mu_cur = self.mu_prev + Delta + # print("w:", w) + j = self.J[:memory.size(1) + 1] + + # Attention weights + # CDF of logistic distribution + phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid( + (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1)))) + # print("phi_t:", phi_t) + + # Discretize attention weights + # (B, T_enc + 1) + alpha_t = torch.sum(phi_t, dim=1) + alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = self.eps + # print("alpha_t: ", alpha_t.size()) + # Apply masking + if mask is not None: + alpha_t.data.masked_fill_(mask, self.score_mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1) + if memory_pitch is not None: + context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1) + + self.mu_prev = mu_cur + + if memory_pitch is not None: + return context, context_pitch, alpha_t + return context, alpha_t + diff --git a/ppg2mel/utils/nets_utils.py b/ppg2mel/utils/nets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..098e3b4c5dfded0c05df1cf0138496c3303eb1e3 --- /dev/null +++ b/ppg2mel/utils/nets_utils.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + assert isinstance(m, torch.nn.Module) + device = next(m.parameters()).device + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, :xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. 
+ + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. + + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError('length_dim cannot be 0: {}'.format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple(slice(None) if i in (0, length_dim) else None + for i in range(xs.dim())) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), + pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == 'c': + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if 'real' not in x or 'imag' not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x['real'], x['imag']) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ("x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x))) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the training args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == 'transformer': + return np.array([1]) + + elif mode == 'mt' and arch == 'rnn': + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + logging.warning('Subsampling is not performed for machine translation.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \ + (mode == 'mt' and arch == 'rnn') or \ + (mode == 'st' and arch == 'rnn'): + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mix': + subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. 
It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mulenc': + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) + if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Encoder %d: Subsampling is not performed for vgg*. ' + 'It is performed in max pooling layers at CNN.', idx + 1) + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch)) + + +def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f'Rename: {old_prefix} -> {new_prefix}') + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v diff --git a/ppg2mel/utils/vc_utils.py b/ppg2mel/utils/vc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b6bf01fa070bbe4cde3ce38973eda12ea0a464 --- /dev/null +++ b/ppg2mel/utils/vc_utils.py @@ -0,0 +1,22 @@ +import torch + + +def gcd(a, b): + """Greatest common divisor.""" + a, b = (a, b) if a >=b else (b, a) + if a%b == 0: + return b + else : + return gcd(b, a%b) + +def lcm(a, b): + """Least common multiple""" + return a * b // gcd(a, b) + +def get_mask_from_lengths(lengths, max_len=None): + if max_len is None: + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) + mask = (ids < lengths.unsqueeze(1)).bool() + return mask + diff --git a/ppg2mel_train.py b/ppg2mel_train.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6a06c805109159ff40cad69668f1fc38cf1e9b --- /dev/null +++ b/ppg2mel_train.py @@ -0,0 +1,67 @@ +import sys +import torch +import argparse +import numpy as np +from utils.load_yaml import HpsYaml +from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + +# For reproducibility, comment these may speed up training +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +def main(): + # Arguments + parser = argparse.ArgumentParser(description= + 'Training PPG2Mel VC model.') + parser.add_argument('--config', type=str, + help='Path to experiment config, e.g., config/vc.yaml') + parser.add_argument('--name', default=None, type=str, help='Name for logging.') + parser.add_argument('--logdir', default='log/', type=str, + help='Logging path.', required=False) + parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str, + help='Checkpoint path.', required=False) + parser.add_argument('--outdir', default='result/', type=str, + help='Decode output path.', required=False) + parser.add_argument('--load', default=None, type=str, + help='Load pre-trained model (for training only)', required=False) + parser.add_argument('--warm_start', action='store_true', + help='Load model weights only, ignore specified layers.') + parser.add_argument('--seed', default=0, type=int, + help='Random 
seed for reproducable results.', required=False) + parser.add_argument('--njobs', default=8, type=int, + help='Number of threads for dataloader/decoding.', required=False) + parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') + parser.add_argument('--no-pin', action='store_true', + help='Disable pin-memory for dataloader') + parser.add_argument('--test', action='store_true', help='Test the model.') + parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') + parser.add_argument('--finetune', action='store_true', help='Finetune model') + parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') + parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') + parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') + + ### + + paras = parser.parse_args() + setattr(paras, 'gpu', not paras.cpu) + setattr(paras, 'pin_memory', not paras.no_pin) + setattr(paras, 'verbose', not paras.no_msg) + # Make the config dict dot visitable + config = HpsYaml(paras.config) + + np.random.seed(paras.seed) + torch.manual_seed(paras.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(paras.seed) + + print(">>> OneShot VC training ...") + mode = "train" + solver = Solver(config, paras, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/ppg_extractor/__init__.py b/ppg_extractor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42a3983c56ba94c07bddefdfa357c30ad9e48a32 --- /dev/null +++ b/ppg_extractor/__init__.py @@ -0,0 +1,102 @@ +import argparse +import torch +from pathlib import Path +import yaml + +from .frontend import DefaultFrontend +from .utterance_mvn import UtteranceMVN +from .encoder.conformer_encoder import ConformerEncoder + +_model = None # type: PPGModel +_device = None + +class PPGModel(torch.nn.Module): + def __init__( + self, + frontend, + normalizer, + encoder, + ): + super().__init__() + self.frontend = frontend + self.normalize = normalizer + self.encoder = encoder + + def forward(self, speech, speech_lengths): + """ + + Args: + speech (tensor): (B, L) + speech_lengths (tensor): (B, ) + + Returns: + bottle_neck_feats (tensor): (B, L//hop_size, 144) + + """ + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + feats, feats_lengths = self.normalize(feats, feats_lengths) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + return encoder_out + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ): + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. 
STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def extract_from_wav(self, src_wav): + src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device) + src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device) + return self(src_wav_tensor, src_wav_lengths) + + +def build_model(args): + normalizer = UtteranceMVN(**args.normalize_conf) + frontend = DefaultFrontend(**args.frontend_conf) + encoder = ConformerEncoder(input_size=80, **args.encoder_conf) + model = PPGModel(frontend, normalizer, encoder) + + return model + + +def load_model(model_file, device=None): + global _model, _device + + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + _device = device + # search a config file + model_config_fpaths = list(model_file.parent.rglob("*.yaml")) + config_file = model_config_fpaths[0] + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + + args = argparse.Namespace(**args) + + model = build_model(args) + model_state_dict = model.state_dict() + + ckpt_state_dict = torch.load(model_file, map_location=_device) + ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k} + + model_state_dict.update(ckpt_state_dict) + model.load_state_dict(model_state_dict) + + _model = model.eval().to(_device) + return _model + + diff --git a/ppg_extractor/e2e_asr_common.py b/ppg_extractor/e2e_asr_common.py new file mode 100644 index 0000000000000000000000000000000000000000..b67f9f1322250ba1a2044ce3a544aa102c07ee4f --- /dev/null +++ b/ppg_extractor/e2e_asr_common.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Common functions for ASR.""" + +import argparse +import editdistance +import json +import logging +import numpy as np +import six +import sys + +from itertools import groupby + + +def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): + """End detection. + + desribed in Eq. (50) of S. Watanabe et al + "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition" + + :param ended_hyps: + :param i: + :param M: + :param D_end: + :return: + """ + if len(ended_hyps) == 0: + return False + count = 0 + best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0] + for m in six.moves.range(M): + # get ended_hyps with their length is i - m + hyp_length = i - m + hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length] + if len(hyps_same_length) > 0: + best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0] + if best_hyp_same_length['score'] - best_hyp['score'] < D_end: + count += 1 + + if count == M: + return True + else: + return False + + +# TODO(takaaki-hori): add different smoothing methods +def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): + """Obtain label distribution for loss smoothing. 
+ + :param odim: + :param lsm_type: + :param blank: + :param transcript: + :return: + """ + if transcript is not None: + with open(transcript, 'rb') as f: + trans_json = json.load(f)['utts'] + + if lsm_type == 'unigram': + assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type + labelcount = np.zeros(odim) + for k, v in trans_json.items(): + ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()]) + # to avoid an error when there is no text in an uttrance + if len(ids) > 0: + labelcount[ids] += 1 + labelcount[odim - 1] = len(transcript) # count + labelcount[labelcount == 0] = 1 # flooring + labelcount[blank] = 0 # remove counts for blank + labeldist = labelcount.astype(np.float32) / np.sum(labelcount) + else: + logging.error( + "Error: unexpected label smoothing type: %s" % lsm_type) + sys.exit() + + return labeldist + + +def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True): + """Return the output size of the VGG frontend. + + :param in_channel: input channel size + :param out_channel: output channel size + :return: output size + :rtype int + """ + idim = idim / in_channel + if downsample: + idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling + idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling + return int(idim) * out_channel # numer of channels + + +class ErrorCalculator(object): + """Calculate CER and WER for E2E_ASR and CTC models during training. + + :param y_hats: numpy array with predicted text + :param y_pads: numpy array with true (target) text + :param char_list: + :param sym_space: + :param sym_blank: + :return: + """ + + def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False, + trans_type="char"): + """Construct an ErrorCalculator object.""" + super(ErrorCalculator, self).__init__() + + self.report_cer = report_cer + self.report_wer = report_wer + self.trans_type = trans_type + self.char_list = char_list + self.space = sym_space + self.blank = sym_blank + self.idx_blank = self.char_list.index(self.blank) + if self.space in self.char_list: + self.idx_space = self.char_list.index(self.space) + else: + self.idx_space = None + + def __call__(self, ys_hat, ys_pad, is_ctc=False): + """Calculate sentence-level WER/CER score. + + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :param bool is_ctc: calculate CER score for CTC + :return: sentence-level WER score + :rtype float + :return: sentence-level CER score + :rtype float + """ + cer, wer = None, None + if is_ctc: + return self.calculate_cer_ctc(ys_hat, ys_pad) + elif not self.report_cer and not self.report_wer: + return cer, wer + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad) + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + return cer, wer + + def calculate_cer_ctc(self, ys_hat, ys_pad): + """Calculate sentence-level CER score for CTC. 
+ + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :return: average sentence-level CER score + :rtype float + """ + cers, char_ref_lens = [], [] + for i, y in enumerate(ys_hat): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + seq_hat, seq_true = [], [] + for idx in y_hat: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_hat.append(self.char_list[int(idx)]) + + for idx in y_true: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_true.append(self.char_list[int(idx)]) + if self.trans_type == "char": + hyp_chars = "".join(seq_hat) + ref_chars = "".join(seq_true) + else: + hyp_chars = " ".join(seq_hat) + ref_chars = " ".join(seq_true) + + if len(ref_chars) > 0: + cers.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None + return cer_ctc + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. + + :param torch.Tensor seqs_hat: prediction (batch, seqlen) + :param torch.Tensor seqs_true: reference (batch, seqlen) + :return: token list of prediction + :rtype list + :return: token list of reference + :rtype list + """ + seqs_hat, seqs_true = [], [] + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + eos_true = np.where(y_true == -1)[0] + eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true) + # To avoid wrong higher WER than the one obtained from the decoding + # eos from y_true is used to mark the eos in y_hat + # because of that y_hats has not padded outs with -1. + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + # seq_hat_text = "".join(seq_hat).replace(self.space, ' ') + seq_hat_text = " ".join(seq_hat).replace(self.space, ' ') + seq_hat_text = seq_hat_text.replace(self.blank, '') + # seq_true_text = "".join(seq_true).replace(self.space, ' ') + seq_true_text = " ".join(seq_true).replace(self.space, ' ') + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level CER score + :rtype float + """ + char_eds, char_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_chars = seq_hat_text.replace(' ', '') + ref_chars = seq_true_text.replace(' ', '') + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level WER score + :rtype float + """ + word_eds, word_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + return float(sum(word_eds)) / sum(word_ref_lens) + + +class ErrorCalculatorTrans(object): + """Calculate CER and WER for transducer models. 
+ + Args: + decoder (nn.Module): decoder module + args (Namespace): argument Namespace containing options + report_cer (boolean): compute CER option + report_wer (boolean): compute WER option + + """ + + def __init__(self, decoder, args, report_cer=False, report_wer=False): + """Construct an ErrorCalculator object for transducer model.""" + super(ErrorCalculatorTrans, self).__init__() + + self.dec = decoder + + recog_args = {'beam_size': args.beam_size, + 'nbest': args.nbest, + 'space': args.sym_space, + 'score_norm_transducer': args.score_norm_transducer} + + self.recog_args = argparse.Namespace(**recog_args) + + self.char_list = args.char_list + self.space = args.sym_space + self.blank = args.sym_blank + + self.report_cer = args.report_cer + self.report_wer = args.report_wer + + def __call__(self, hs_pad, ys_pad): + """Calculate sentence-level WER/CER score for transducer models. + + Args: + hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D) + ys_pad (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): sentence-level CER score + (float): sentence-level WER score + + """ + cer, wer = None, None + + if not self.report_cer and not self.report_wer: + return cer, wer + + batchsize = int(hs_pad.size(0)) + batch_nbest = [] + + for b in six.moves.range(batchsize): + if self.recog_args.beam_size == 1: + nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args) + else: + nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args) + batch_nbest.append(nbest_hyps) + + ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest] + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu()) + + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + + return cer, wer + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. + + Args: + ys_hat (torch.Tensor): prediction (batch, seqlen) + ys_pad (torch.Tensor): reference (batch, seqlen) + + Returns: + (list): token list of prediction + (list): token list of reference + + """ + seqs_hat, seqs_true = [], [] + + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + + eos_true = np.where(y_true == -1)[0] + eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true) + + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + + seq_hat_text = "".join(seq_hat).replace(self.space, ' ') + seq_hat_text = seq_hat_text.replace(self.blank, '') + seq_true_text = "".join(seq_true).replace(self.space, ' ') + + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score for transducer model. + + Args: + seqs_hat (torch.Tensor): prediction (batch, seqlen) + seqs_true (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): average sentence-level CER score + + """ + char_eds, char_ref_lens = [], [] + + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_chars = seq_hat_text.replace(' ', '') + ref_chars = seq_true_text.replace(' ', '') + + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score for transducer model. 
+ + Args: + seqs_hat (torch.Tensor): prediction (batch, seqlen) + seqs_true (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): average sentence-level WER score + + """ + word_eds, word_ref_lens = [], [] + + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + + return float(sum(word_eds)) / sum(word_ref_lens) diff --git a/ppg_extractor/encoder/__init__.py b/ppg_extractor/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ppg_extractor/encoder/attention.py b/ppg_extractor/encoder/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7a0d5078bb40f5c2797bcf38b9c077a62ccfa4 --- /dev/null +++ b/ppg_extractor/encoder/attention.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention layer definition.""" + +import math + +import numpy +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + :param int n_head: the number of head s + :param int n_feat: the number of features + :param float dropout_rate: dropout rate + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + :param torch.Tensor query: (batch, time1, size) + :param torch.Tensor key: (batch, time2, size) + :param torch.Tensor value: (batch, time2, size) + :return torch.Tensor transformed query, key and value + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. 
+ + :param torch.Tensor value: (batch, head, time2, size) + :param torch.Tensor scores: (batch, head, time1, time2) + :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2) + :return torch.Tensor transformed `value` (batch, time1, d_model) + weighted by the attention score (batch, time1, time2) + + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float( + numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min + ) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute 'Scaled Dot Product Attention'. + + :param torch.Tensor query: (batch, time1, size) + :param torch.Tensor key: (batch, time2, size) + :param torch.Tensor value: (batch, time2, size) + :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2) + :param torch.nn.Dropout dropout: + :return torch.Tensor: attention output (batch, time1, d_model) + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + + Paper: https://arxiv.org/abs/1901.02860 + + :param int n_head: the number of head s + :param int n_feat: the number of features + :param float dropout_rate: dropout rate + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional ecoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu=False): + """Compute relative positinal encoding. + + :param torch.Tensor x: (batch, time, size) + :param bool zero_triu: return the lower triangular part of the matrix + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 
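A bare-bones restatement of the scaled dot-product attention computed by forward_qkv and forward_attention above, with explicit shapes (the batch, head, time, and d_k values are arbitrary):

import math
import torch

q = torch.randn(2, 4, 10, 64)                 # (batch, head, time1, d_k)
k = torch.randn(2, 4, 10, 64)                 # (batch, head, time2, d_k)
v = torch.randn(2, 4, 10, 64)                 # (batch, head, time2, d_k)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
attn = torch.softmax(scores, dim=-1)          # (batch, head, time1, time2)
out = torch.matmul(attn, v)                   # (batch, head, time1, d_k)
print(out.shape)                              # torch.Size([2, 4, 10, 64])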
+ + :param torch.Tensor query: (batch, time1, size) + :param torch.Tensor key: (batch, time2, size) + :param torch.Tensor value: (batch, time2, size) + :param torch.Tensor pos_emb: (batch, time1, size) + :param torch.Tensor mask: (batch, time1, time2) + :param torch.nn.Dropout dropout: + :return torch.Tensor: attention output (batch, time1, d_model) + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) diff --git a/ppg_extractor/encoder/conformer_encoder.py b/ppg_extractor/encoder/conformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d31e97a28fffec3599558109971b771c64ee2a80 --- /dev/null +++ b/ppg_extractor/encoder/conformer_encoder.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" + +import logging +import torch +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from .convolution import ConvolutionModule +from .encoder_layer import EncoderLayer +from ..nets_utils import get_activation, make_pad_mask +from .vgg import VGG2L +from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention +from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding +from .layer_norm import LayerNorm +from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d +from .positionwise_feed_forward import PositionwiseFeedForward +from .repeat import repeat +from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling + + +class ConformerEncoder(torch.nn.Module): + """Conformer encoder module. + + :param int idim: input dim + :param int attention_dim: dimention of attention + :param int attention_heads: the number of heads of multi head attention + :param int linear_units: the number of units of position-wise feed forward + :param int num_blocks: the number of decoder blocks + :param float dropout_rate: dropout rate + :param float attention_dropout_rate: dropout rate in attention + :param float positional_dropout_rate: dropout rate after adding positional encoding + :param str or torch.nn.Module input_layer: input layer type + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
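A shape-only sketch of the two score terms combined in the relative-position attention above: the content term (matrices a + c) attends over the keys, the position term (matrices b + d) over the projected positional embeddings; rel_shift is omitted here and all tensors are random stand-ins:

import math
import torch

batch, head, time1, d_k = 2, 4, 10, 64
q_u = torch.randn(batch, head, time1, d_k)          # q + pos_bias_u
q_v = torch.randn(batch, head, time1, d_k)          # q + pos_bias_v
k = torch.randn(batch, head, time1, d_k)
p = torch.randn(batch, head, time1, d_k)            # linear_pos(pos_emb), reshaped per head

matrix_ac = torch.matmul(q_u, k.transpose(-2, -1))  # (batch, head, time1, time2)
matrix_bd = torch.matmul(q_v, p.transpose(-2, -1))  # rel_shift(...) is applied to this in the real code
scores = (matrix_ac + matrix_bd) / math.sqrt(d_k)
print(scores.shape)                                 # torch.Size([2, 4, 10, 10])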
x -> x + att(x) + :param str positionwise_layer_type: linear of conv1d + :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + :param str encoder_pos_enc_layer_type: encoder positional encoding layer type + :param str encoder_attn_layer_type: encoder attention layer type + :param str activation_type: encoder activation function type + :param bool macaron_style: whether to use macaron style for positionwise layer + :param bool use_cnn_module: whether to use convolution module + :param int cnn_module_kernel: kernerl size of convolution module + :param int padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + cnn_module_kernel=31, + padding_idx=-1, + no_subsample=False, + subsample_by_2=False, + ): + """Construct an Encoder object.""" + super().__init__() + + self._output_size = attention_dim + idim = input_size + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + logging.info("Encoder input layer type: conv2d") + if no_subsample: + self.embed = Conv2dNoSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + subsample_by_2, # NOTE(Sx): added by songxiang + ) + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + 
positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + xs_pad: input tensor (B, L, D) + ilens: input lengths (B) + prev_states: Not to be used now. + Returns: + Position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)): + # print(xs_pad.shape) + xs_pad, masks = self.embed(xs_pad, masks) + # print(xs_pad[0].size()) + else: + xs_pad = self.embed(xs_pad) + xs_pad, masks = self.encoders(xs_pad, masks) + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None + + # def forward(self, xs, masks): + # """Encode input sequence. + + # :param torch.Tensor xs: input tensor + # :param torch.Tensor masks: input mask + # :return: position embedded tensor and mask + # :rtype Tuple[torch.Tensor, torch.Tensor]: + # """ + # if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + # xs, masks = self.embed(xs, masks) + # else: + # xs = self.embed(xs) + + # xs, masks = self.encoders(xs, masks) + # if isinstance(xs, tuple): + # xs = xs[0] + + # if self.normalize_before: + # xs = self.after_norm(xs) + # return xs, masks diff --git a/ppg_extractor/encoder/convolution.py b/ppg_extractor/encoder/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2c399e406aae97a6baf0f7de379a1d90a97949 --- /dev/null +++ b/ppg_extractor/encoder/convolution.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""ConvolutionModule definition.""" + +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. 
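A hypothetical usage sketch of the ConformerEncoder defined above; the import path follows this patch's layout, and the feature dimension, lengths, and hyper-parameters are assumptions chosen only to show the expected shapes:

import torch
from ppg_extractor.encoder.conformer_encoder import ConformerEncoder

enc = ConformerEncoder(
    input_size=80,                     # e.g. 80-dim log-mel features
    attention_dim=256,
    attention_heads=4,
    num_blocks=2,
    input_layer="conv2d",
    pos_enc_layer_type="rel_pos",
    selfattention_layer_type="rel_selfattn",
    macaron_style=True,
    use_cnn_module=True,
)
feats = torch.randn(2, 120, 80)        # (B, T, D)
lens = torch.tensor([120, 96])
out, out_lens, _ = enc(feats, lens)
print(out.shape, out_lens)             # roughly (2, T/4, 256) with the default conv2d subsampling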
+ + :param int channels: channels of cnn + :param int kernel_size: kernerl size of cnn + + """ + + def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward(self, x): + """Compute convolution module. + + :param torch.Tensor x: (batch, time, size) + :return torch.Tensor: convoluted `value` (batch, time, d_model) + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) diff --git a/ppg_extractor/encoder/embedding.py b/ppg_extractor/encoder/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3199cf7e3da2ed834d4781b694cf4ccb2a433c --- /dev/null +++ b/ppg_extractor/encoder/embedding.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positonal Encoding Module.""" + +import math + +import torch + + +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. 
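A shape walk-through of the convolution module above with toy sizes: the first pointwise conv doubles the channels, the GLU halves them back, and the depthwise conv preserves them, so (B, T, C) is kept end to end (batch norm, activation, and the second pointwise conv are elided):

import torch
from torch import nn

channels, kernel_size = 8, 7                      # toy sizes; the kernel must be odd
x = torch.randn(3, 50, channels)                  # (batch, time, channels)

y = nn.Conv1d(channels, 2 * channels, 1)(x.transpose(1, 2))          # (B, 2C, T)
y = nn.functional.glu(y, dim=1)                                      # (B, C, T)
y = nn.Conv1d(channels, channels, kernel_size,
              padding=(kernel_size - 1) // 2, groups=channels)(y)    # depthwise, (B, C, T)
print(y.transpose(1, 2).shape)                                       # torch.Size([3, 50, 8])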
+ + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + :param reverse: whether to reverse the input position + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class ScaledPositionalEncoding(PositionalEncoding): + """Scaled positional encoding module. + + See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + """ + super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) + self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + + def reset_parameters(self): + """Reset parameters.""" + self.alpha.data = torch.tensor(1.0) + + def forward(self, x): + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(PositionalEncoding): + """Relitive positional encoding module. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + """ + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, x): + """Compute positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + + Returns: + torch.Tensor: x. Its shape is (batch, time, ...) + torch.Tensor: pos_emb. Its shape is (1, time, ...) 
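The sinusoidal table that extend_pe builds above, written out directly for a tiny (max_len, d_model); this mirrors pe[t, 2i] = sin(t / 10000^(2i/d)) and pe[t, 2i+1] = cos(t / 10000^(2i/d)):

import math
import torch

d_model, max_len = 8, 5
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                     * -(math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe.shape)                                   # torch.Size([5, 8])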
+ + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + return self.dropout(x), self.dropout(pos_emb) diff --git a/ppg_extractor/encoder/encoder.py b/ppg_extractor/encoder/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6b92c0116ef1e207f7f7c94e9162cf0f5b86db7b --- /dev/null +++ b/ppg_extractor/encoder/encoder.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" + +import logging +import torch + +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.nets_utils import get_activation +from espnet.nets.pytorch_backend.transducer.vgg import VGG2L +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling + + +class Encoder(torch.nn.Module): + """Conformer encoder module. + + :param int idim: input dim + :param int attention_dim: dimention of attention + :param int attention_heads: the number of heads of multi head attention + :param int linear_units: the number of units of position-wise feed forward + :param int num_blocks: the number of decoder blocks + :param float dropout_rate: dropout rate + :param float attention_dropout_rate: dropout rate in attention + :param float positional_dropout_rate: dropout rate after adding positional encoding + :param str or torch.nn.Module input_layer: input layer type + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + :param str positionwise_layer_type: linear of conv1d + :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + :param str encoder_pos_enc_layer_type: encoder positional encoding layer type + :param str encoder_attn_layer_type: encoder attention layer type + :param str activation_type: encoder activation function type + :param bool macaron_style: whether to use macaron style for positionwise layer + :param bool use_cnn_module: whether to use convolution module + :param int cnn_module_kernel: kernerl size of convolution module + :param int padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + cnn_module_kernel=31, + padding_idx=-1, + ): + """Construct an Encoder object.""" + super(Encoder, self).__init__() + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = 
MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, xs, masks): + """Encode input sequence. + + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + xs, masks = self.encoders(xs, masks) + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks diff --git a/ppg_extractor/encoder/encoder_layer.py b/ppg_extractor/encoder/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..750a32e4ef22ed5c2ca74aa364d1e8a3470e4016 --- /dev/null +++ b/ppg_extractor/encoder/encoder_layer.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from .layer_norm import LayerNorm + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + :param int size: input dim + :param espnet.nets.pytorch_backend.transformer.attention. + MultiHeadedAttention self_attn: self attention module + RelPositionMultiHeadedAttention self_attn: self attention module + :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. + PositionwiseFeedForward feed_forward: + feed forward module + :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward + for macaron style + PositionwiseFeedForward feed_forward: + feed forward module + :param espnet.nets.pytorch_backend.conformer.convolution. + ConvolutionModule feed_foreard: + feed forward module + :param float dropout_rate: dropout rate + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + feed_forward_macaron, + conv_module, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """Compute encoded features. + + :param torch.Tensor x_input: encoded source features, w/o pos_emb + tuple((batch, max_time_in, size), (1, max_time_in, size)) + or (batch, max_time_in, size) + :param torch.Tensor mask: mask for x (batch, max_time_in) + :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size) + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/ppg_extractor/encoder/layer_norm.py b/ppg_extractor/encoder/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..db8be30ff70554edb179109037665e51c04510ec --- /dev/null +++ b/ppg_extractor/encoder/layer_norm.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 
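A compact restatement of the pre-norm, macaron-style ordering implemented in EncoderLayer.forward above, using plain torch modules as stand-ins for the repo's attention, feed-forward, and convolution blocks (dropout omitted; all names and sizes here are illustrative):

import torch
from torch import nn

size = 16
norm_ff_macaron, norm_mha, norm_conv, norm_ff, norm_final = [nn.LayerNorm(size) for _ in range(5)]
ffn1 = nn.Sequential(nn.Linear(size, 64), nn.ReLU(), nn.Linear(64, size))
ffn2 = nn.Sequential(nn.Linear(size, 64), nn.ReLU(), nn.Linear(64, size))
mhsa = nn.MultiheadAttention(size, num_heads=4, batch_first=True)
conv = nn.Conv1d(size, size, kernel_size=3, padding=1)

x = torch.randn(2, 10, size)
x = x + 0.5 * ffn1(norm_ff_macaron(x))                          # macaron feed-forward
h = norm_mha(x)
x = x + mhsa(h, h, h)[0]                                        # multi-headed self-attention
x = x + conv(norm_conv(x).transpose(1, 2)).transpose(1, 2)      # convolution module
x = x + 0.5 * ffn2(norm_ff(x))                                  # second feed-forward
x = norm_final(x)                                               # final layer norm
print(x.shape)                                                  # torch.Size([2, 10, 16])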
(http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer normalization module.""" + +import torch + + +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) diff --git a/ppg_extractor/encoder/multi_layer_conv.py b/ppg_extractor/encoder/multi_layer_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb7fe70810eda54c727367efc986ce02ce581cc --- /dev/null +++ b/ppg_extractor/encoder/multi_layer_conv.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" + +import torch + + +class MultiLayeredConv1d(torch.nn.Module): + """Multi-layered conv1d for Transformer block. + + This is a module of multi-leyered conv1d designed + to replace positionwise feed-forward network + in Transforner block, which is introduced in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize MultiLayeredConv1d module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Conv1d( + hidden_chans, + in_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, ..., in_chans). + + Returns: + Tensor: Batch of output tensors (B, ..., hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + + +class Conv1dLinear(torch.nn.Module): + """Conv1D + Linear for Transformer block. + + A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize Conv1dLinear module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(Conv1dLinear, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, ..., in_chans). 
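The transpose trick used by the LayerNorm wrapper above to normalize a non-last dimension, shown for a channels-first tensor (sizes are arbitrary):

import torch
from torch import nn

x = torch.randn(2, 16, 50)                        # (batch, channels, time); normalize dim=1
norm = nn.LayerNorm(16, eps=1e-12)
y = norm(x.transpose(1, -1)).transpose(1, -1)     # move channels last, normalize, move back
print(y.shape)                                    # torch.Size([2, 16, 50])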
+ + Returns: + Tensor: Batch of output tensors (B, ..., hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x)) diff --git a/ppg_extractor/encoder/positionwise_feed_forward.py b/ppg_extractor/encoder/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9237a38314e3f758f064ab78d8983b94a9eb0a --- /dev/null +++ b/ppg_extractor/encoder/positionwise_feed_forward.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + :param int idim: input dimenstion + :param int hidden_units: number of hidden units + :param float dropout_rate: dropout rate + + """ + + def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): + """Construct an PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.dropout = torch.nn.Dropout(dropout_rate) + self.activation = activation + + def forward(self, x): + """Forward funciton.""" + return self.w_2(self.dropout(self.activation(self.w_1(x)))) diff --git a/ppg_extractor/encoder/repeat.py b/ppg_extractor/encoder/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8af6ce850e930feb2bf0cd0e9bc7a8d21520e4 --- /dev/null +++ b/ppg_extractor/encoder/repeat.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Repeat the same layer definition.""" + +import torch + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential.""" + + def forward(self, *args): + """Repeat.""" + for m in self: + args = m(*args) + return args + + +def repeat(N, fn): + """Repeat module N times. + + :param int N: repeat time + :param function fn: function to generate module + :return: repeated modules + :rtype: MultiSequential + """ + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/ppg_extractor/encoder/subsampling.py b/ppg_extractor/encoder/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e754126b2ec1f2d914206ec35ec026c7b6add17f --- /dev/null +++ b/ppg_extractor/encoder/subsampling.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Subsampling layer definition.""" +import logging +import torch + +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding + + +class Conv2dSubsampling(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length or 1/2 length). 
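A toy illustration of how repeat()/MultiSequential above stack N identical blocks while threading several arguments (such as (x, mask)) through every layer; ToyLayer is a made-up stand-in and the call relies on the repeat helper defined just above:

import torch
from torch import nn

class ToyLayer(nn.Module):
    """Stand-in block with the (x, mask) -> (x, mask) interface the encoders use."""
    def __init__(self, size):
        super().__init__()
        self.lin = nn.Linear(size, size)

    def forward(self, x, mask):
        return self.lin(x), mask                  # pass the mask through unchanged

layers = repeat(3, lambda lnum: ToyLayer(8))      # MultiSequential of 3 blocks
x, mask = layers(torch.randn(2, 5, 8), torch.ones(2, 1, 5, dtype=torch.bool))
print(x.shape, mask.shape)                        # (2, 5, 8) and (2, 1, 5)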
+ + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + :param torch.nn.Module pos_enc: custom position encoding layer + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None, + subsample_by_2=False, + ): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + self.subsample_by_2 = subsample_by_2 + if subsample_by_2: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (idim // 2), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + else: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (idim // 4), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + if self.subsample_by_2: + return x, x_mask[:, :, ::2] + else: + return x, x_mask[:, :, ::2][:, :, ::2] + + def __getitem__(self, key): + """Subsample x. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dNoSubsampling(torch.nn.Module): + """Convolutional 2D without subsampling. + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + :param torch.nn.Module pos_enc: custom position encoding layer + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling object.""" + super().__init__() + logging.info("Encoder does not do down-sample on mel-spectrogram.") + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * idim, odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask + + def __getitem__(self, key): + """Subsample x. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. 
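Length bookkeeping for the conv2d front-end above: two stride-2 convolutions reduce T to roughly T/4, and the time mask is decimated the same way; the numbers below are just an example:

import torch

T = 120
mask = torch.ones(1, 1, T, dtype=torch.bool)      # (B, 1, T) time mask
sub = mask[:, :, ::2][:, :, ::2]                  # same slicing as Conv2dSubsampling.forward
print(sub.shape)                                  # torch.Size([1, 1, 30]), i.e. T // 4 here
# The feature axis becomes odim * (idim // 4) channels, flattened before Linear(..., odim).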
+ + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dSubsampling6(torch.nn.Module): + """Convolutional 2D subsampling (to 1/6 length). + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + + """ + + def __init__(self, idim, odim, dropout_rate): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling6, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim), + PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-4:3] + + +class Conv2dSubsampling8(torch.nn.Module): + """Convolutional 2D subsampling (to 1/8 length). + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + + """ + + def __init__(self, idim, odim, dropout_rate): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling8, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim), + PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. 
+ + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] diff --git a/ppg_extractor/encoder/swish.py b/ppg_extractor/encoder/swish.py new file mode 100644 index 0000000000000000000000000000000000000000..c53a7a98bfc6d983c3a308c4b40f81e315aa7875 --- /dev/null +++ b/ppg_extractor/encoder/swish.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Swish() activation function for Conformer.""" + +import torch + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x): + """Return Swich activation function.""" + return x * torch.sigmoid(x) diff --git a/ppg_extractor/encoder/vgg.py b/ppg_extractor/encoder/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca1c6551eb6ad238838011a2c98d965138fd770 --- /dev/null +++ b/ppg_extractor/encoder/vgg.py @@ -0,0 +1,77 @@ +"""VGG2L definition for transformer-transducer.""" + +import torch + + +class VGG2L(torch.nn.Module): + """VGG2L module for transformer-transducer encoder.""" + + def __init__(self, idim, odim): + """Construct a VGG2L object. + + Args: + idim (int): dimension of inputs + odim (int): dimension of outputs + + """ + super(VGG2L, self).__init__() + + self.vgg2l = torch.nn.Sequential( + torch.nn.Conv2d(1, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(64, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((3, 2)), + torch.nn.Conv2d(64, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(128, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) + + def forward(self, x, x_mask): + """VGG2L forward for x. + + Args: + x (torch.Tensor): input torch (B, T, idim) + x_mask (torch.Tensor): (B, 1, T) + + Returns: + x (torch.Tensor): input torch (B, sub(T), attention_dim) + x_mask (torch.Tensor): (B, 1, sub(T)) + + """ + x = x.unsqueeze(1) + x = self.vgg2l(x) + + b, c, t, f = x.size() + + x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if x_mask is None: + return x, None + else: + x_mask = self.create_new_mask(x_mask, x) + + return x, x_mask + + def create_new_mask(self, x_mask, x): + """Create a subsampled version of x_mask. 
+ + Args: + x_mask (torch.Tensor): (B, 1, T) + x (torch.Tensor): (B, sub(T), attention_dim) + + Returns: + x_mask (torch.Tensor): (B, 1, sub(T)) + + """ + x_t1 = x_mask.size(2) - (x_mask.size(2) % 3) + x_mask = x_mask[:, :, :x_t1][:, :, ::3] + + x_t2 = x_mask.size(2) - (x_mask.size(2) % 2) + x_mask = x_mask[:, :, :x_t2][:, :, ::2] + + return x_mask diff --git a/ppg_extractor/encoders.py b/ppg_extractor/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..526140f4fac5b0c4663e435243655ae74a4735fa --- /dev/null +++ b/ppg_extractor/encoders.py @@ -0,0 +1,298 @@ +import logging +import six + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from .e2e_asr_common import get_vgg2l_odim +from .nets_utils import make_pad_mask, to_device + + +class RNNP(torch.nn.Module): + """RNN with projection layer module + + :param int idim: dimension of inputs + :param int elayers: number of encoder layers + :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional) + :param int hdim: number of projection units + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param str typ: The RNN type + """ + + def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"): + super(RNNP, self).__init__() + bidir = typ[0] == "b" + for i in six.moves.range(elayers): + if i == 0: + inputdim = idim + else: + inputdim = hdim + rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, + batch_first=True) if "lstm" in typ \ + else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True) + setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn) + # bottleneck layer to merge + if bidir: + setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim)) + else: + setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim)) + + self.elayers = elayers + self.cdim = cdim + self.subsample = subsample + self.typ = typ + self.bidir = bidir + + def forward(self, xs_pad, ilens, prev_state=None): + """RNNP forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous RNN states + :return: batch of hidden state sequences (B, Tmax, hdim) + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens)) + elayer_states = [] + for layer in six.moves.range(self.elayers): + xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False) + rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer)) + rnn.flatten_parameters() + if prev_state is not None and rnn.bidirectional: + prev_state = reset_backward_rnn_state(prev_state) + ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer]) + elayer_states.append(states) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + sub = self.subsample[layer + 1] + if sub > 1: + ys_pad = ys_pad[:, ::sub] + ilens = [int(i + 1) // sub for i in ilens] + # (sum _utt frame_utt) x dim + projected = getattr(self, 'bt' + str(layer) + )(ys_pad.contiguous().view(-1, ys_pad.size(2))) + if layer == self.elayers - 1: + xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) + else: + xs_pad = 
torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1)) + + return xs_pad, ilens, elayer_states # x: utt list of frame x dim + + +class RNN(torch.nn.Module): + """RNN module + + :param int idim: dimension of inputs + :param int elayers: number of encoder layers + :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional) + :param int hdim: number of final projection units + :param float dropout: dropout rate + :param str typ: The RNN type + """ + + def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"): + super(RNN, self).__init__() + bidir = typ[0] == "b" + self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True, + dropout=dropout, bidirectional=bidir) if "lstm" in typ \ + else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout, + bidirectional=bidir) + if bidir: + self.l_last = torch.nn.Linear(cdim * 2, hdim) + else: + self.l_last = torch.nn.Linear(cdim, hdim) + self.typ = typ + + def forward(self, xs_pad, ilens, prev_state=None): + """RNN forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous RNN states + :return: batch of hidden state sequences (B, Tmax, eprojs) + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens)) + xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True) + self.nbrnn.flatten_parameters() + if prev_state is not None and self.nbrnn.bidirectional: + # We assume that when previous state is passed, it means that we're streaming the input + # and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction) + prev_state = reset_backward_rnn_state(prev_state) + ys, states = self.nbrnn(xs_pack, hx=prev_state) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + # (sum _utt frame_utt) x dim + projected = torch.tanh(self.l_last( + ys_pad.contiguous().view(-1, ys_pad.size(2)))) + xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) + return xs_pad, ilens, states # x: utt list of frame x dim + + +def reset_backward_rnn_state(states): + """Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs""" + if isinstance(states, (list, tuple)): + for state in states: + state[1::2] = 0. + else: + states[1::2] = 0. 
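+    # For a bidirectional nn.LSTM/GRU the state tensors have shape
+    # (num_layers * 2, B, H) with the backward direction stored at the odd
+    # indices, which is why the slice 1::2 is zeroed above: when streaming,
+    # the backward pass has no future context to carry over between chunks.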
+ return states + + +class VGG2L(torch.nn.Module): + """VGG-like module + + :param int in_channel: number of input channels + """ + + def __init__(self, in_channel=1, downsample=True): + super(VGG2L, self).__init__() + # CNN layer (VGG motivated) + self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1) + self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1) + self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1) + self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1) + + self.in_channel = in_channel + self.downsample = downsample + if downsample: + self.stride = 2 + else: + self.stride = 1 + + def forward(self, xs_pad, ilens, **kwargs): + """VGG2L forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens)) + + # x: utt x frame x dim + # xs_pad = F.pad_sequence(xs_pad) + + # x: utt x 1 (input channel num) x frame x dim + xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel, + xs_pad.size(2) // self.in_channel).transpose(1, 2) + + # NOTE: max_pool1d ? + xs_pad = F.relu(self.conv1_1(xs_pad)) + xs_pad = F.relu(self.conv1_2(xs_pad)) + if self.downsample: + xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True) + + xs_pad = F.relu(self.conv2_1(xs_pad)) + xs_pad = F.relu(self.conv2_2(xs_pad)) + if self.downsample: + xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True) + if torch.is_tensor(ilens): + ilens = ilens.cpu().numpy() + else: + ilens = np.array(ilens, dtype=np.float32) + if self.downsample: + ilens = np.array(np.ceil(ilens / 2), dtype=np.int64) + ilens = np.array( + np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist() + + # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim) + xs_pad = xs_pad.transpose(1, 2) + xs_pad = xs_pad.contiguous().view( + xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)) + return xs_pad, ilens, None # no state in this layer + + +class Encoder(torch.nn.Module): + """Encoder module + + :param str etype: type of encoder network + :param int idim: number of dimensions of encoder network + :param int elayers: number of layers of encoder network + :param int eunits: number of lstm units of encoder network + :param int eprojs: number of projection units of encoder network + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param int in_channel: number of input channels + """ + + def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1): + super(Encoder, self).__init__() + typ = etype.lstrip("vgg").rstrip("p") + if typ not in ['lstm', 'gru', 'blstm', 'bgru']: + logging.error("Error: need to specify an appropriate encoder architecture") + + if etype.startswith("vgg"): + if etype[-1] == "p": + self.enc = torch.nn.ModuleList([VGG2L(in_channel), + RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits, + eprojs, + subsample, dropout, typ=typ)]) + logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder') + else: + self.enc = torch.nn.ModuleList([VGG2L(in_channel), + RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits, + eprojs, + dropout, typ=typ)]) + logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder') + else: + if etype[-1] == 
"p": + self.enc = torch.nn.ModuleList( + [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)]) + logging.info(typ.upper() + ' with every-layer projection for encoder') + else: + self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)]) + logging.info(typ.upper() + ' without projection for encoder') + + def forward(self, xs_pad, ilens, prev_states=None): + """Encoder forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...) + :return: batch of hidden state sequences (B, Tmax, eprojs) + :rtype: torch.Tensor + """ + if prev_states is None: + prev_states = [None] * len(self.enc) + assert len(prev_states) == len(self.enc) + + current_states = [] + for module, prev_state in zip(self.enc, prev_states): + xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) + current_states.append(states) + + # make mask to remove bias value in padded part + mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1)) + + return xs_pad.masked_fill(mask, 0.0), ilens, current_states + + +def encoder_for(args, idim, subsample): + """Instantiates an encoder module given the program arguments + + :param Namespace args: The arguments + :param int or List of integer idim: dimension of input, e.g. 83, or + List of dimensions of inputs, e.g. [83,83] + :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or + List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]] + :rtype torch.nn.Module + :return: The encoder module + """ + num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility + if num_encs == 1: + # compatible with single encoder asr mode + return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate) + elif num_encs >= 1: + enc_list = torch.nn.ModuleList() + for idx in range(num_encs): + enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx], + args.dropout_rate[idx]) + enc_list.append(enc) + return enc_list + else: + raise ValueError("Number of encoders needs to be more than one. 
{}".format(num_encs)) diff --git a/ppg_extractor/frontend.py b/ppg_extractor/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..32549ed050655d79be1793a9cf04d9d52644794a --- /dev/null +++ b/ppg_extractor/frontend.py @@ -0,0 +1,115 @@ +import copy +from typing import Tuple +import numpy as np +import torch +from torch_complex.tensor import ComplexTensor + +from .log_mel import LogMel +from .stft import Stft + + +class DefaultFrontend(torch.nn.Module): + """Conventional frontend structure for ASR + + Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN + """ + + def __init__( + self, + fs: 16000, + n_fft: int = 1024, + win_length: int = 800, + hop_length: int = 160, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: int = None, + fmax: int = None, + htk: bool = False, + norm=1, + frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend), + kaldi_padding_mode=False, + downsample_rate: int = 1, + ): + super().__init__() + self.downsample_rate = downsample_rate + + # Deepcopy (In general, dict shouldn't be used as default arg) + frontend_conf = copy.deepcopy(frontend_conf) + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + kaldi_padding_mode=kaldi_padding_mode + ) + if frontend_conf is not None: + self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) + else: + self.frontend = None + + self.logmel = LogMel( + fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm, + ) + self.n_mels = n_mels + + def output_size(self) -> int: + return self.n_mels + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Domain-conversion: e.g. Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # Change torch.Tensor to ComplexTensor + # input_stft: (..., F, 2) -> (..., F) + input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1]) + + # 2. [Option] Speech enhancement + if self.frontend is not None: + assert isinstance(input_stft, ComplexTensor), type(input_stft) + # input_stft: (Batch, Length, [Channel], Freq) + input_stft, _, mask = self.frontend(input_stft, feats_lens) + + # 3. [Multi channel case]: Select a channel + if input_stft.dim() == 4: + # h: (B, T, C, F) -> h: (B, T, F) + if self.training: + # Select 1ch randomly + ch = np.random.randint(input_stft.size(2)) + input_stft = input_stft[:, :, ch, :] + else: + # Use the first channel + input_stft = input_stft[:, :, 0, :] + + # 4. STFT -> Power spectrum + # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) + input_power = input_stft.real ** 2 + input_stft.imag ** 2 + + # 5. Feature transform e.g. 
Stft -> Log-Mel-Fbank + # input_power: (Batch, [Channel,] Length, Freq) + # -> input_feats: (Batch, Length, Dim) + input_feats, _ = self.logmel(input_power, feats_lens) + + # NOTE(sx): pad + max_len = input_feats.size(1) + if self.downsample_rate > 1 and max_len % self.downsample_rate != 0: + padding = self.downsample_rate - max_len % self.downsample_rate + # print("Logmel: ", input_feats.size()) + input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding), + "constant", 0) + # print("Logmel(after padding): ",input_feats.size()) + feats_lens[torch.argmax(feats_lens)] = max_len + padding + + return input_feats, feats_lens diff --git a/ppg_extractor/log_mel.py b/ppg_extractor/log_mel.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3b87d7ec73516ad79ee6eb1943cffb70bb52fa --- /dev/null +++ b/ppg_extractor/log_mel.py @@ -0,0 +1,74 @@ +import librosa +import numpy as np +import torch +from typing import Tuple + +from .nets_utils import make_pad_mask + + +class LogMel(torch.nn.Module): + """Convert STFT to fbank feats + + The arguments is same as librosa.filters.mel + + Args: + fs: number > 0 [scalar] sampling rate of the incoming signal + n_fft: int > 0 [scalar] number of FFT components + n_mels: int > 0 [scalar] number of Mel bands to generate + fmin: float >= 0 [scalar] lowest frequency (in Hz) + fmax: float >= 0 [scalar] highest frequency (in Hz). + If `None`, use `fmax = fs / 2.0` + htk: use HTK formula instead of Slaney + norm: {None, 1, np.inf} [scalar] + if 1, divide the triangular mel weights by the width of the mel band + (area normalization). Otherwise, leave all the triangles aiming for + a peak value of 1.0 + + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 512, + n_mels: int = 80, + fmin: float = None, + fmax: float = None, + htk: bool = False, + norm=1, + ): + super().__init__() + + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + _mel_options = dict( + sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm + ) + self.mel_options = _mel_options + + # Note(kamo): The mel matrix of librosa is different from kaldi. + melmat = librosa.filters.mel(**_mel_options) + # melmat: (D2, D1) -> (D1, D2) + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + inv_mel = np.linalg.pinv(melmat) + self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float()) + + def extra_repr(self): + return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) + + def forward( + self, feat: torch.Tensor, ilens: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) + mel_feat = torch.matmul(feat, self.melmat) + + logmel_feat = (mel_feat + 1e-20).log() + # Zero padding + if ilens is not None: + logmel_feat = logmel_feat.masked_fill( + make_pad_mask(ilens, logmel_feat, 1), 0.0 + ) + else: + ilens = feat.new_full( + [feat.size(0)], fill_value=feat.size(1), dtype=torch.long + ) + return logmel_feat, ilens diff --git a/ppg_extractor/nets_utils.py b/ppg_extractor/nets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6db064b7a829ad7c45dd17e9f5a4fc92c95a72f4 --- /dev/null +++ b/ppg_extractor/nets_utils.py @@ -0,0 +1,465 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. 
+ x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + assert isinstance(m, torch.nn.Module) + device = next(m.parameters()).device + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, :xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
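+    (Here length_dim selects which axis of xs is treated as the length axis;
+    positions at or beyond the corresponding length along that axis are
+    marked as padding, as the two calls below illustrate.)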
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError('length_dim cannot be 0: {}'.format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple(slice(None) if i in (0, length_dim) else None + for i in range(xs.dim())) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), + pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == 'c': + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if 'real' not in x or 'imag' not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x['real'], x['imag']) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ("x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x))) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the training args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == 'transformer': + return np.array([1]) + + elif mode == 'mt' and arch == 'rnn': + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + logging.warning('Subsampling is not performed for machine translation.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \ + (mode == 'mt' and arch == 'rnn') or \ + (mode == 'st' and arch == 'rnn'): + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mix': + subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. 
It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mulenc': + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) + if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Encoder %d: Subsampling is not performed for vgg*. ' + 'It is performed in max pooling layers at CNN.', idx + 1) + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch)) + + +def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f'Rename: {old_prefix} -> {new_prefix}') + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from .encoder.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": Swish, + } + + return activation_funcs[act]() diff --git a/ppg_extractor/stft.py b/ppg_extractor/stft.py new file mode 100644 index 0000000000000000000000000000000000000000..06b879e3cd810fc85b93fb4e3c118e38dfbce5f0 --- /dev/null +++ b/ppg_extractor/stft.py @@ -0,0 +1,118 @@ +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from .nets_utils import make_pad_mask + + +class Stft(torch.nn.Module): + def __init__( + self, + n_fft: int = 512, + win_length: Union[int, None] = 512, + hop_length: int = 128, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: bool = True, + kaldi_padding_mode=False, + ): + super().__init__() + self.n_fft = n_fft + if win_length is None: + self.win_length = n_fft + else: + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.pad_mode = pad_mode + self.normalized = normalized + self.onesided = onesided + self.kaldi_padding_mode = kaldi_padding_mode + if self.kaldi_padding_mode: + self.win_length = 400 + + def extra_repr(self): + return ( + f"n_fft={self.n_fft}, " + f"win_length={self.win_length}, " + f"hop_length={self.hop_length}, " + f"center={self.center}, " + f"pad_mode={self.pad_mode}, " + f"normalized={self.normalized}, " + f"onesided={self.onesided}" + ) + + def forward( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """STFT forward function. 
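+        The returned olens follow the length bookkeeping at the end of this
+        method: with center=True,
+        olens = (ilens + 2 * (win_length // 2) - win_length) // hop_length + 1,
+        so with the DefaultFrontend settings (win_length=800, hop_length=160)
+        a 16000-sample input maps to 101 frames.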
+ + Args: + input: (Batch, Nsamples) or (Batch, Nsample, Channels) + ilens: (Batch) + Returns: + output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) + + """ + bs = input.size(0) + if input.dim() == 3: + multi_channel = True + # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) + input = input.transpose(1, 2).reshape(-1, input.size(1)) + else: + multi_channel = False + + # output: (Batch, Freq, Frames, 2=real_imag) + # or (Batch, Channel, Freq, Frames, 2=real_imag) + if not self.kaldi_padding_mode: + output = torch.stft( + input, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=self.center, + pad_mode=self.pad_mode, + normalized=self.normalized, + onesided=self.onesided, + return_complex=False + ) + else: + # NOTE(sx): Use Kaldi-fasion padding, maybe wrong + num_pads = self.n_fft - self.win_length + input = torch.nn.functional.pad(input, (num_pads, 0)) + output = torch.stft( + input, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=False, + pad_mode=self.pad_mode, + normalized=self.normalized, + onesided=self.onesided, + return_complex=False + ) + + # output: (Batch, Freq, Frames, 2=real_imag) + # -> (Batch, Frames, Freq, 2=real_imag) + output = output.transpose(1, 2) + if multi_channel: + # output: (Batch * Channel, Frames, Freq, 2=real_imag) + # -> (Batch, Frame, Channel, Freq, 2=real_imag) + output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( + 1, 2 + ) + + if ilens is not None: + if self.center: + pad = self.win_length // 2 + ilens = ilens + 2 * pad + olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1 + # olens = ilens - self.win_length // self.hop_length + 1 + output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) + else: + olens = None + + return output, olens diff --git a/ppg_extractor/utterance_mvn.py b/ppg_extractor/utterance_mvn.py new file mode 100644 index 0000000000000000000000000000000000000000..37fb0c1b918bff60d0c6b5fef883b2f735e7cd79 --- /dev/null +++ b/ppg_extractor/utterance_mvn.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch + +from .nets_utils import make_pad_mask + + +class UtteranceMVN(torch.nn.Module): + def __init__( + self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20, + ): + super().__init__() + self.norm_means = norm_means + self.norm_vars = norm_vars + self.eps = eps + + def extra_repr(self): + return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" + + def forward( + self, x: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function + + Args: + x: (B, L, ...) 
+ ilens: (B,) + + """ + return utterance_mvn( + x, + ilens, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + eps=self.eps, + ) + + +def utterance_mvn( + x: torch.Tensor, + ilens: torch.Tensor = None, + norm_means: bool = True, + norm_vars: bool = False, + eps: float = 1.0e-20, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply utterance mean and variance normalization + + Args: + x: (B, T, D), assumed zero padded + ilens: (B,) + norm_means: + norm_vars: + eps: + + """ + if ilens is None: + ilens = x.new_full([x.size(0)], x.size(1)) + ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)]) + # Zero padding + if x.requires_grad: + x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0) + else: + x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) + # mean: (B, 1, D) + mean = x.sum(dim=1, keepdim=True) / ilens_ + + if norm_means: + x -= mean + + if norm_vars: + var = x.pow(2).sum(dim=1, keepdim=True) / ilens_ + std = torch.clamp(var.sqrt(), min=eps) + x = x / std.sqrt() + return x, ilens + else: + if norm_vars: + y = x - mean + y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0) + var = y.pow(2).sum(dim=1, keepdim=True) / ilens_ + std = torch.clamp(var.sqrt(), min=eps) + x /= std + return x, ilens diff --git a/pre.py b/pre.py new file mode 100644 index 0000000000000000000000000000000000000000..17fd0f710153bfb71b717678998a853e364c8cd8 --- /dev/null +++ b/pre.py @@ -0,0 +1,76 @@ +from synthesizer.preprocess import create_embeddings +from utils.argutils import print_args +from pathlib import Path +import argparse + +from synthesizer.preprocess import preprocess_dataset +from synthesizer.hparams import hparams +from utils.argutils import print_args +from pathlib import Path +import argparse + +recognized_datasets = [ + "aidatatang_200zh", + "magicdata", + "aishell3", + "data_aishell" +] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, encodes them as mel spectrograms " + "and writes them to the disk. Audio files are also saved, to be used by the " + "vocoder for training.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your datasets.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms, the audios and the " + "embeds. Defaults to /SV2TTS/synthesizer/") + parser.add_argument("-n", "--n_processes", type=int, default=1, help=\ + "Number of processes in parallel.") + parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + "Whether to overwrite existing files with the same name. Useful if the preprocessing was " + "interrupted. 
") + parser.add_argument("--hparams", type=str, default="", help=\ + "Hyperparameter overrides as a comma-separated list of name-value pairs") + parser.add_argument("--no_trim", action="store_true", help=\ + "Preprocess audio without trimming silences (not recommended).") + parser.add_argument("--no_alignments", action="store_true", help=\ + "Use this option when dataset does not include alignments\ + (these are used to split long audio files into sub-utterances.)") + parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\ + "Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.") + parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\ + "Path your trained encoder model.") + parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\ + "Number of processes in parallel.An encoder is created for each, so you may need to lower " + "this value on GPUs with low memory. Set it to 1 if CUDA is unhappy") + args = parser.parse_args() + + # Process the arguments + if not hasattr(args, "out_dir"): + args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer") + assert args.dataset in recognized_datasets, 'is not supported, please vote for it in https://github.com/babysor/MockingBird/issues/10' + # Create directories + assert args.datasets_root.exists() + args.out_dir.mkdir(exist_ok=True, parents=True) + + # Verify webrtcvad is available + if not args.no_trim: + try: + import webrtcvad + except: + raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables " + "noise removal and is recommended. Please install and try again. If installation fails, " + "use --no_trim to disable this error message.") + encoder_model_fpath = args.encoder_model_fpath + del args.no_trim, args.encoder_model_fpath + + args.hparams = hparams.parse(args.hparams) + n_processes_embed = args.n_processes_embed + del args.n_processes_embed + preprocess_dataset(**vars(args)) + + create_embeddings(synthesizer_root=args.out_dir, n_processes=n_processes_embed, encoder_model_fpath=encoder_model_fpath) diff --git a/pre4ppg.py b/pre4ppg.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfa0fa0e655966b2959c1b25e4a7ae358cb69a9 --- /dev/null +++ b/pre4ppg.py @@ -0,0 +1,49 @@ +from pathlib import Path +import argparse + +from ppg2mel.preprocess import preprocess_dataset +from pathlib import Path +import argparse + +recognized_datasets = [ + "aidatatang_200zh", + "aidatatang_200zh_s", # sample +] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, to be used by the " + "ppg2mel model for training.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your datasets.") + parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\ + "Name of the dataset to process, allowing values: aidatatang_200zh.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms, the audios and the " + "embeds. 
Defaults to /PPGVC/ppg2mel/") + parser.add_argument("-n", "--n_processes", type=int, default=8, help=\ + "Number of processes in parallel.") + # parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + # "Whether to overwrite existing files with the same name. Useful if the preprocessing was " + # "interrupted. ") + # parser.add_argument("--hparams", type=str, default="", help=\ + # "Hyperparameter overrides as a comma-separated list of name-value pairs") + # parser.add_argument("--no_trim", action="store_true", help=\ + # "Preprocess audio without trimming silences (not recommended).") + parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\ + "Path your trained ppg encoder model.") + parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\ + "Path your trained speaker encoder model.") + args = parser.parse_args() + + assert args.dataset in recognized_datasets, 'is not supported, file a issue to propose a new one' + + # Create directories + assert args.datasets_root.exists() + if not hasattr(args, "out_dir"): + args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel") + args.out_dir.mkdir(exist_ok=True, parents=True) + + preprocess_dataset(**vars(args)) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a73eabd5d7dfa5072692cc1cdb510496aa929a69 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +umap-learn +visdom +librosa==0.8.1 +matplotlib>=3.3.0 +numpy==1.19.3; platform_system == "Windows" +numpy==1.19.4; platform_system != "Windows" +scipy>=1.0.0 +tqdm +sounddevice +SoundFile +Unidecode +inflect +PyQt5 +multiprocess +numba +webrtcvad; platform_system != "Windows" +pypinyin +flask +flask_wtf +flask_cors==3.0.10 +gevent==21.8.0 +flask_restx +tensorboard +streamlit==1.8.0 +PyYAML==5.4.1 +torch_complex +espnet +PyWavelets +webrtcvad-wheels +pyaudio diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..170f9db25b17d5cf4d7fded16d7f912249e2a365 --- /dev/null +++ b/run.py @@ -0,0 +1,142 @@ +import time +import os +import argparse +import torch +import numpy as np +import glob +from pathlib import Path +from tqdm import tqdm +from ppg_extractor import load_model +import librosa +import soundfile as sf +from utils.load_yaml import HpsYaml + +from encoder.audio import preprocess_wav +from encoder import inference as speacker_encoder +from vocoder.hifigan import inference as vocoder +from ppg2mel import MelDecoderMOLv2 +from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv + + +def _build_ppg2mel_model(model_config, model_file, device): + ppg2mel_model = MelDecoderMOLv2( + **model_config["model"] + ).to(device) + ckpt = torch.load(model_file, map_location=device) + ppg2mel_model.load_state_dict(ckpt["model"]) + ppg2mel_model.eval() + return ppg2mel_model + + +@torch.no_grad() +def convert(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1] + + # Build models + print("Load PPG-model, PPG2Mel-model, Vocoder-model...") + ppg_model = load_model( + Path('./ppg_extractor/saved_models/24epoch.pt'), + device, + ) + ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device) + # 
vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json") + vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt') + # Data related + ref_wav_path = args.ref_wav_path + ref_wav = preprocess_wav(ref_wav_path) + ref_fid = os.path.basename(ref_wav_path)[:-4] + + # TODO: specify encoder + speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt")) + ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav) + ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device) + ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav))) + + source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav")) + print(f"Number of source utterances: {len(source_file_list)}.") + + total_rtf = 0.0 + cnt = 0 + for src_wav_path in tqdm(source_file_list): + # Load the audio to a numpy array: + src_wav, _ = librosa.load(src_wav_path, sr=16000) + src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device) + src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device) + ppg = ppg_model(src_wav_tensor, src_wav_lengths) + + lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True) + min_len = min(ppg.shape[1], len(lf0_uv)) + + ppg = ppg[:, :min_len] + lf0_uv = lf0_uv[:min_len] + + start = time.time() + _, mel_pred, att_ws = ppg2mel_model.inference( + ppg, + logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), + spembs=ref_spk_dvec, + ) + src_fid = os.path.basename(src_wav_path)[:-4] + wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav" + mel_len = mel_pred.shape[0] + rtf = (time.time() - start) / (0.01 * mel_len) + total_rtf += rtf + cnt += 1 + # continue + mel_pred= mel_pred.transpose(0, 1) + y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu()) + sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16") + + print("RTF:") + print(total_rtf / cnt) + + +def get_parser(): + parser = argparse.ArgumentParser(description="Conversion from wave input") + parser.add_argument( + "--wav_dir", + type=str, + default=None, + required=True, + help="Source wave directory.", + ) + parser.add_argument( + "--ref_wav_path", + type=str, + required=True, + help="Reference wave file path.", + ) + parser.add_argument( + "--ppg2mel_model_train_config", "-c", + type=str, + default=None, + required=True, + help="Training config file (yaml file)", + ) + parser.add_argument( + "--ppg2mel_model_file", "-m", + type=str, + default=None, + required=True, + help="ppg2mel model checkpoint file path" + ) + parser.add_argument( + "--output_dir", "-o", + type=str, + default="vc_gens_vctk_oneshot", + help="Output folder to save the converted wave." 
+ ) + + return parser + +def main(): + parser = get_parser() + args = parser.parse_args() + convert(args) + +if __name__ == "__main__": + main() diff --git a/samples/T0055G0013S0005.wav b/samples/T0055G0013S0005.wav new file mode 100644 index 0000000000000000000000000000000000000000..4fcc65cd1a9bcd3ac3adbfe593696585a8ce5f36 Binary files /dev/null and b/samples/T0055G0013S0005.wav differ diff --git a/synthesizer/LICENSE.txt b/synthesizer/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..3337d453404ea63d5a5919d3922045374bea3da1 --- /dev/null +++ b/synthesizer/LICENSE.txt @@ -0,0 +1,24 @@ +MIT License + +Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah) +Original work Copyright (c) 2019 fatchord (https://github.com/fatchord) +Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ) +Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
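For orientation, the end-to-end voice-conversion flow that run.py (above) implements condenses to the sketch below. It is illustrative only: the checkpoint paths are the ones hard-coded in run.py, ppg2mel_model is assumed to have been built with _build_ppg2mel_model() exactly as in that script, and ref.wav / source.wav / converted.wav are placeholder file names.

import torch
import librosa
import soundfile as sf
from pathlib import Path
from ppg_extractor import load_model
from encoder.audio import preprocess_wav
from encoder import inference as speaker_encoder
from vocoder.hifigan import inference as vocoder
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the three pretrained components used by run.py
# (run the inference part under torch.no_grad(), as run.py does).
ppg_model = load_model(Path("ppg_extractor/saved_models/24epoch.pt"), device)
vocoder.load_model("vocoder/saved_models/24k/g_02830000.pt")
speaker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))

# Reference speaker: d-vector and log-F0 statistics.
ref_wav = preprocess_wav("ref.wav")
spk_dvec = torch.from_numpy(speaker_encoder.embed_utterance(ref_wav)).unsqueeze(0).to(device)
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))

# Source utterance: extract PPGs and a converted lF0/UV contour.
src_wav, _ = librosa.load("source.wav", sr=16000)
ppg = ppg_model(torch.from_numpy(src_wav).unsqueeze(0).float().to(device),
                torch.LongTensor([len(src_wav)]).to(device))
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))

# PPG + prosody + speaker embedding -> mel, then mel -> waveform via HiFi-GAN.
# ppg2mel_model: MelDecoderMOLv2 built via _build_ppg2mel_model(), see run.py.
_, mel_pred, _ = ppg2mel_model.inference(
    ppg[:, :min_len],
    logf0_uv=torch.from_numpy(lf0_uv[:min_len]).unsqueeze(0).float().to(device),
    spembs=spk_dvec)
y, output_sample_rate = vocoder.infer_waveform(mel_pred.transpose(0, 1).cpu())
sf.write("converted.wav", y.squeeze(), output_sample_rate, "PCM_16")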
diff --git a/synthesizer/__init__.py b/synthesizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4287ca8617970fa8fc025b75cb319c7032706910 --- /dev/null +++ b/synthesizer/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/synthesizer/audio.py b/synthesizer/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..2e03ae5eecdf50bd88b1a76c6bff59f8d4947291 --- /dev/null +++ b/synthesizer/audio.py @@ -0,0 +1,206 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile +import soundfile as sf + + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + sf.write(path, wav.astype(np.float32), sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py +def start_and_end_indices(quantized, silence_threshold=2): + for start in range(quantized.size): + if abs(quantized[start] - 127) > silence_threshold: + break + for end in range(quantized.size - 1, 1, -1): + if abs(quantized[end] - 127) > silence_threshold: + break + + assert abs(quantized[start] - 127) > silence_threshold + assert abs(quantized[end] - 127) > silence_threshold + + return start, end + +def get_hop_size(hparams): + hop_size = hparams.hop_size + if hop_size is None: + assert hparams.frame_shift_ms is not None + hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) + return hop_size + +def linearspectrogram(wav, hparams): + D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams) + S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db + + if hparams.signal_normalization: + return _normalize(S, hparams) + return S + +def melspectrogram(wav, hparams): + D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams) + S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db + + if hparams.signal_normalization: + return _normalize(S, hparams) + return S + +def inv_linear_spectrogram(linear_spectrogram, hparams): + """Converts linear spectrogram to waveform using librosa""" + if hparams.signal_normalization: + D = _denormalize(linear_spectrogram, hparams) + else: + D = linear_spectrogram + + S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear + + if hparams.use_lws: + processor = _lws_processor(hparams) + D = processor.run_lws(S.astype(np.float64).T ** hparams.power) + y = processor.istft(D).astype(np.float32) + return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize) + else: + return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize) + +def inv_mel_spectrogram(mel_spectrogram, hparams): + """Converts mel spectrogram to waveform using librosa""" + if hparams.signal_normalization: + D = _denormalize(mel_spectrogram, hparams) + else: + D = mel_spectrogram + + S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear + + if hparams.use_lws: + processor = _lws_processor(hparams) + D = processor.run_lws(S.astype(np.float64).T ** hparams.power) + y = 
processor.istft(D).astype(np.float32) + return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize) + else: + return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize) + +def _lws_processor(hparams): + import lws + return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech") + +def _griffin_lim(S, hparams): + """librosa implementation of Griffin-Lim + Based on https://github.com/librosa/librosa/issues/434 + """ + angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) + S_complex = np.abs(S).astype(np.complex) + y = _istft(S_complex * angles, hparams) + for i in range(hparams.griffin_lim_iters): + angles = np.exp(1j * np.angle(_stft(y, hparams))) + y = _istft(S_complex * angles, hparams) + return y + +def _stft(y, hparams): + if hparams.use_lws: + return _lws_processor(hparams).stft(y).T + else: + return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size) + +def _istft(y, hparams): + return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) +def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None +_inv_mel_basis = None + +def _linear_to_mel(spectogram, hparams): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis(hparams) + return np.dot(_mel_basis, spectogram) + +def _mel_to_linear(mel_spectrogram, hparams): + global _inv_mel_basis + if _inv_mel_basis is None: + _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams)) + return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) + +def _build_mel_basis(hparams): + assert hparams.fmax <= hparams.sample_rate // 2 + return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels, + fmin=hparams.fmin, fmax=hparams.fmax) + +def _amp_to_db(x, hparams): + min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S, hparams): + if hparams.allow_clipping_in_normalization: + if hparams.symmetric_mels: + return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value, + -hparams.max_abs_value, hparams.max_abs_value) + else: + return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value) + + assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0 + if hparams.symmetric_mels: + return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value + else: + return hparams.max_abs_value * ((S - hparams.min_level_db) / 
(-hparams.min_level_db)) + +def _denormalize(D, hparams): + if hparams.allow_clipping_in_normalization: + if hparams.symmetric_mels: + return (((np.clip(D, -hparams.max_abs_value, + hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + + hparams.min_level_db) + else: + return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) + + if hparams.symmetric_mels: + return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db) + else: + return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) diff --git a/synthesizer/gst_hyperparameters.py b/synthesizer/gst_hyperparameters.py new file mode 100644 index 0000000000000000000000000000000000000000..1403144651853135489c4a42d3c0f52bd0f87664 --- /dev/null +++ b/synthesizer/gst_hyperparameters.py @@ -0,0 +1,13 @@ +class GSTHyperparameters(): + E = 512 + + # reference encoder + ref_enc_filters = [32, 32, 64, 64, 128, 128] + + # style token layer + token_num = 10 + # token_emb_size = 256 + num_heads = 8 + + n_mels = 256 # Number of Mel banks to generate + diff --git a/synthesizer/hparams.py b/synthesizer/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcdb635a90a7700d4e133410268a897d3fd4a8c --- /dev/null +++ b/synthesizer/hparams.py @@ -0,0 +1,110 @@ +import ast +import pprint +import json + +class HParams(object): + def __init__(self, **kwargs): self.__dict__.update(kwargs) + def __setitem__(self, key, value): setattr(self, key, value) + def __getitem__(self, key): return getattr(self, key) + def __repr__(self): return pprint.pformat(self.__dict__) + + def parse(self, string): + # Overrides hparams from a comma-separated string of name=value pairs + if len(string) > 0: + overrides = [s.split("=") for s in string.split(",")] + keys, values = zip(*overrides) + keys = list(map(str.strip, keys)) + values = list(map(str.strip, values)) + for k in keys: + self.__dict__[k] = ast.literal_eval(values[keys.index(k)]) + return self + + def loadJson(self, dict): + print("\Loading the json with %s\n", dict) + for k in dict.keys(): + if k not in ["tts_schedule", "tts_finetune_layers"]: + self.__dict__[k] = dict[k] + return self + + def dumpJson(self, fp): + print("\Saving the json with %s\n", fp) + with fp.open("w", encoding="utf-8") as f: + json.dump(self.__dict__, f) + return self + +hparams = HParams( + ### Signal Processing (used in both synthesizer and vocoder) + sample_rate = 16000, + n_fft = 800, + num_mels = 80, + hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125) + win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050) + fmin = 55, + min_level_db = -100, + ref_level_db = 20, + max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small. + preemphasis = 0.97, # Filter coefficient to use if preemphasize is True + preemphasize = True, + + ### Tacotron Text-to-Speech (TTS) + tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs + tts_encoder_dims = 256, + tts_decoder_dims = 128, + tts_postnet_dims = 512, + tts_encoder_K = 5, + tts_lstm_dims = 1024, + tts_postnet_K = 5, + tts_num_highways = 4, + tts_dropout = 0.5, + tts_cleaner_names = ["basic_cleaners"], + tts_stop_threshold = -3.4, # Value below which audio generation ends. 
+ # For example, for a range of [-4, 4], this + # will terminate the sequence at the first + # frame that has all values < -3.4 + + ### Tacotron Training + tts_schedule = [(2, 1e-3, 10_000, 12), # Progressive training schedule + (2, 5e-4, 15_000, 12), # (r, lr, step, batch_size) + (2, 2e-4, 20_000, 12), # (r, lr, step, batch_size) + (2, 1e-4, 30_000, 12), # + (2, 5e-5, 40_000, 12), # + (2, 1e-5, 60_000, 12), # + (2, 5e-6, 160_000, 12), # r = reduction factor (# of mel frames + (2, 3e-6, 320_000, 12), # synthesized for each decoder iteration) + (2, 1e-6, 640_000, 12)], # lr = learning rate + + tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed + tts_eval_interval = 500, # Number of steps between model evaluation (sample generation) + # Set to -1 to generate after completing epoch, or 0 to disable + tts_eval_num_samples = 1, # Makes this number of samples + + ## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj + tts_finetune_layers = [], + + ### Data Preprocessing + max_mel_frames = 900, + rescale = True, + rescaling_max = 0.9, + synthesis_batch_size = 16, # For vocoder preprocessing and inference. + + ### Mel Visualization and Griffin-Lim + signal_normalization = True, + power = 1.5, + griffin_lim_iters = 60, + + ### Audio processing options + fmax = 7600, # Should not exceed (sample_rate // 2) + allow_clipping_in_normalization = True, # Used when signal_normalization = True + clip_mels_length = True, # If true, discards samples exceeding max_mel_frames + use_lws = False, # "Fast spectrogram phase recovery using local weighted sums" + symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True, + # and [0, max_abs_value] if False + trim_silence = True, # Use with sample_rate of 16000 for best results + + ### SV2TTS + speaker_embedding_size = 256, # Dimension for the speaker embedding + silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split + utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded + use_gst = True, # Whether to use global style token + use_ser_for_gst = True, # Whether to use speaker embedding referenced for global style token + ) diff --git a/synthesizer/inference.py b/synthesizer/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff856bdf4d43fd09dec1061818590937fef6512 --- /dev/null +++ b/synthesizer/inference.py @@ -0,0 +1,187 @@ +import torch +from synthesizer import audio +from synthesizer.hparams import hparams +from synthesizer.models.tacotron import Tacotron +from synthesizer.utils.symbols import symbols +from synthesizer.utils.text import text_to_sequence +from vocoder.display import simple_table +from pathlib import Path +from typing import Union, List +import numpy as np +import librosa +from utils import logmmse +import json +from pypinyin import lazy_pinyin, Style + +class Synthesizer: + sample_rate = hparams.sample_rate + hparams = hparams + + def __init__(self, model_fpath: Path, verbose=True): + """ + The model isn't instantiated and loaded in memory until needed or until load() is called. 
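+        Typical usage is to construct the Synthesizer with the checkpoint path
+        and then call synthesize_spectrograms() with texts and speaker
+        embeddings; the weights are loaded lazily on the first call, or
+        explicitly via load().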
+ + :param model_fpath: path to the trained model file + :param verbose: if False, prints less information when using the model + """ + self.model_fpath = model_fpath + self.verbose = verbose + + # Check for GPU + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + if self.verbose: + print("Synthesizer using device:", self.device) + + # Tacotron model will be instantiated later on first use. + self._model = None + + def is_loaded(self): + """ + Whether the model is loaded in memory. + """ + return self._model is not None + + def load(self): + # Try to scan config file + model_config_fpaths = list(self.model_fpath.parent.rglob("*.json")) + if len(model_config_fpaths)>0 and model_config_fpaths[0].exists(): + with model_config_fpaths[0].open("r", encoding="utf-8") as f: + hparams.loadJson(json.load(f)) + """ + Instantiates and loads the model given the weights file that was passed in the constructor. + """ + self._model = Tacotron(embed_dims=hparams.tts_embed_dims, + num_chars=len(symbols), + encoder_dims=hparams.tts_encoder_dims, + decoder_dims=hparams.tts_decoder_dims, + n_mels=hparams.num_mels, + fft_bins=hparams.num_mels, + postnet_dims=hparams.tts_postnet_dims, + encoder_K=hparams.tts_encoder_K, + lstm_dims=hparams.tts_lstm_dims, + postnet_K=hparams.tts_postnet_K, + num_highways=hparams.tts_num_highways, + dropout=hparams.tts_dropout, + stop_threshold=hparams.tts_stop_threshold, + speaker_embedding_size=hparams.speaker_embedding_size).to(self.device) + + self._model.load(self.model_fpath, self.device) + self._model.eval() + + if self.verbose: + print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"])) + + def synthesize_spectrograms(self, texts: List[str], + embeddings: Union[np.ndarray, List[np.ndarray]], + return_alignments=False, style_idx=0, min_stop_token=5, steps=2000): + """ + Synthesizes mel spectrograms from texts and speaker embeddings. + + :param texts: a list of N text prompts to be synthesized + :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256) + :param return_alignments: if True, a matrix representing the alignments between the + characters + and each decoder output step will be returned for each spectrogram + :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the + sequence length of spectrogram i, and possibly the alignments. + """ + # Load the model on the first request. 
+ if not self.is_loaded(): + self.load() + + # Print some info about the model when it is loaded + tts_k = self._model.get_step() // 1000 + + simple_table([("Tacotron", str(tts_k) + "k"), + ("r", self._model.r)]) + + print("Read " + str(texts)) + texts = [" ".join(lazy_pinyin(v, style=Style.TONE3, neutral_tone_with_five=True)) for v in texts] + print("Synthesizing " + str(texts)) + # Preprocess text inputs + inputs = [text_to_sequence(text, hparams.tts_cleaner_names) for text in texts] + if not isinstance(embeddings, list): + embeddings = [embeddings] + + # Batch inputs + batched_inputs = [inputs[i:i+hparams.synthesis_batch_size] + for i in range(0, len(inputs), hparams.synthesis_batch_size)] + batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size] + for i in range(0, len(embeddings), hparams.synthesis_batch_size)] + + specs = [] + for i, batch in enumerate(batched_inputs, 1): + if self.verbose: + print(f"\n| Generating {i}/{len(batched_inputs)}") + + # Pad texts so they are all the same length + text_lens = [len(text) for text in batch] + max_text_len = max(text_lens) + chars = [pad1d(text, max_text_len) for text in batch] + chars = np.stack(chars) + + # Stack speaker embeddings into 2D array for batch processing + speaker_embeds = np.stack(batched_embeds[i-1]) + + # Convert to tensor + chars = torch.tensor(chars).long().to(self.device) + speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device) + + # Inference + _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token, steps=steps) + mels = mels.detach().cpu().numpy() + for m in mels: + # Trim silence from end of each spectrogram + while np.max(m[:, -1]) < hparams.tts_stop_threshold: + m = m[:, :-1] + specs.append(m) + + if self.verbose: + print("\n\nDone.\n") + return (specs, alignments) if return_alignments else specs + + @staticmethod + def load_preprocess_wav(fpath): + """ + Loads and preprocesses an audio file under the same conditions the audio files were used to + train the synthesizer. + """ + wav = librosa.load(path=str(fpath), sr=hparams.sample_rate)[0] + if hparams.rescale: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + # denoise + if len(wav) > hparams.sample_rate*(0.3+0.1): + noise_wav = np.concatenate([wav[:int(hparams.sample_rate*0.15)], + wav[-int(hparams.sample_rate*0.15):]]) + profile = logmmse.profile_noise(noise_wav, hparams.sample_rate) + wav = logmmse.denoise(wav, profile) + return wav + + @staticmethod + def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]): + """ + Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that + were fed to the synthesizer when training. + """ + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav = Synthesizer.load_preprocess_wav(fpath_or_wav) + else: + wav = fpath_or_wav + + mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32) + return mel_spectrogram + + @staticmethod + def griffin_lim(mel): + """ + Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built + with the same parameters present in hparams.py. 
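Putting the pieces of this class together, a rough end-to-end sketch: embed a reference voice, synthesize a mel spectrogram for one sentence, then invert it with Griffin-Lim. All file paths below are placeholders, and the encoder calls mirror the ones used in synthesizer/preprocess.py later in this diff; in the full toolchain a trained neural vocoder would normally replace the Griffin-Lim step.

from pathlib import Path
from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from synthesizer import audio

encoder.load_model(Path("encoder/saved_models/pretrained.pt"))      # placeholder path
synth = Synthesizer(Path("synthesizer/saved_models/mandarin.pt"))   # placeholder path

# Reference audio -> speaker embedding
ref_wav = Synthesizer.load_preprocess_wav("reference.wav")          # placeholder file
embed = encoder.embed_utterance(encoder.preprocess_wav(ref_wav))

# Text -> mel spectrogram (pinyin conversion happens inside) -> waveform
specs = synth.synthesize_spectrograms(["欢迎使用语音克隆工具"], [embed])
wav = Synthesizer.griffin_lim(specs[0])
audio.save_wav(wav, "output.wav", sr=Synthesizer.sample_rate)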
+ """ + return audio.inv_mel_spectrogram(mel, hparams) + + +def pad1d(x, max_len, pad_value=0): + return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value) diff --git a/synthesizer/models/base.py b/synthesizer/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..13b32a1b1aff4b5b6024e7574810e76452283b8d --- /dev/null +++ b/synthesizer/models/base.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import imp +import numpy as np + +class Base(nn.Module): + def __init__(self, stop_threshold): + super().__init__() + + self.init_model() + self.num_params() + + self.register_buffer("step", torch.zeros(1, dtype=torch.long)) + self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32)) + + @property + def r(self): + return self.decoder.r.item() + + @r.setter + def r(self, value): + self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False) + + def init_model(self): + for p in self.parameters(): + if p.dim() > 1: nn.init.xavier_uniform_(p) + + def finetune_partial(self, whitelist_layers): + self.zero_grad() + for name, child in self.named_children(): + if name in whitelist_layers: + print("Trainable Layer: %s" % name) + print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()])) + for param in child.parameters(): + param.requires_grad = False + + def get_step(self): + return self.step.data.item() + + def reset_step(self): + # assignment to parameters or buffers is overloaded, updates internal dict entry + self.step = self.step.data.new_tensor(1) + + def log(self, path, msg): + with open(path, "a") as f: + print(msg, file=f) + + def load(self, path, device, optimizer=None): + # Use device of model params as location for loaded state + checkpoint = torch.load(str(path), map_location=device) + self.load_state_dict(checkpoint["model_state"], strict=False) + + if "optimizer_state" in checkpoint and optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer_state"]) + + def save(self, path, optimizer=None): + if optimizer is not None: + torch.save({ + "model_state": self.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, str(path)) + else: + torch.save({ + "model_state": self.state_dict(), + }, str(path)) + + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print("Trainable Parameters: %.3fM" % parameters) + return parameters diff --git a/synthesizer/models/sublayer/__init__.py b/synthesizer/models/sublayer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4287ca8617970fa8fc025b75cb319c7032706910 --- /dev/null +++ b/synthesizer/models/sublayer/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/synthesizer/models/sublayer/cbhg.py b/synthesizer/models/sublayer/cbhg.py new file mode 100644 index 0000000000000000000000000000000000000000..10eb6bb85dd2a1711fe7c92ec77bbaaf786f7a53 --- /dev/null +++ b/synthesizer/models/sublayer/cbhg.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +from .common.batch_norm_conv import BatchNormConv +from .common.highway_network import HighwayNetwork + +class CBHG(nn.Module): + def __init__(self, K, in_channels, channels, proj_channels, num_highways): + super().__init__() + + # List of all rnns to call `flatten_parameters()` on + self._to_flatten = [] + + self.bank_kernels = [i for i in range(1, K + 1)] + 
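A note on the finetune_partial helper in Base above: the usual partial-finetuning pattern disables gradients for every child module that is not whitelisted, so that only the listed layers keep updating. The freeze_except function below is a generic sketch of that pattern, not part of this codebase; the layer names in the usage comment are the ones listed for tts_finetune_layers in hparams.py.

import torch.nn as nn

def freeze_except(model: nn.Module, whitelist):
    # Keep gradients only for the whitelisted top-level children.
    for name, child in model.named_children():
        trainable = name in whitelist
        for p in child.parameters():
            p.requires_grad = trainable
        print(("Trainable" if trainable else "Frozen") + " layer: " + name)

# e.g. freeze_except(tacotron_model, ["decoder", "postnet", "post_proj"])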
self.conv1d_bank = nn.ModuleList() + for k in self.bank_kernels: + conv = BatchNormConv(in_channels, channels, k) + self.conv1d_bank.append(conv) + + self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) + + self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3) + self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False) + + # Fix the highway input if necessary + if proj_channels[-1] != channels: + self.highway_mismatch = True + self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False) + else: + self.highway_mismatch = False + + self.highways = nn.ModuleList() + for i in range(num_highways): + hn = HighwayNetwork(channels) + self.highways.append(hn) + + self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True) + self._to_flatten.append(self.rnn) + + # Avoid fragmentation of RNN parameters and associated warning + self._flatten_parameters() + + def forward(self, x): + # Although we `_flatten_parameters()` on init, when using DataParallel + # the model gets replicated, making it no longer guaranteed that the + # weights are contiguous in GPU memory. Hence, we must call it again + self.rnn.flatten_parameters() + + # Save these for later + residual = x + seq_len = x.size(-1) + conv_bank = [] + + # Convolution Bank + for conv in self.conv1d_bank: + c = conv(x) # Convolution + conv_bank.append(c[:, :, :seq_len]) + + # Stack along the channel axis + conv_bank = torch.cat(conv_bank, dim=1) + + # dump the last padding to fit residual + x = self.maxpool(conv_bank)[:, :, :seq_len] + + # Conv1d projections + x = self.conv_project1(x) + x = self.conv_project2(x) + + # Residual Connect + x = x + residual + + # Through the highways + x = x.transpose(1, 2) + if self.highway_mismatch is True: + x = self.pre_highway(x) + for h in self.highways: x = h(x) + + # And then the RNN + x, _ = self.rnn(x) + return x + + def _flatten_parameters(self): + """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used + to improve efficiency and avoid PyTorch yelling at us.""" + [m.flatten_parameters() for m in self._to_flatten] + diff --git a/synthesizer/models/sublayer/common/batch_norm_conv.py b/synthesizer/models/sublayer/common/batch_norm_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..0d07a4a9495657bcd434111ec0b6f16ca35211c2 --- /dev/null +++ b/synthesizer/models/sublayer/common/batch_norm_conv.py @@ -0,0 +1,14 @@ +import torch.nn as nn +import torch.nn.functional as F + +class BatchNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel, relu=True): + super().__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False) + self.bnorm = nn.BatchNorm1d(out_channels) + self.relu = relu + + def forward(self, x): + x = self.conv(x) + x = F.relu(x) if self.relu is True else x + return self.bnorm(x) \ No newline at end of file diff --git a/synthesizer/models/sublayer/common/highway_network.py b/synthesizer/models/sublayer/common/highway_network.py new file mode 100644 index 0000000000000000000000000000000000000000..d311c6924db6dfc247f69cc266d6c1975b6e03cd --- /dev/null +++ b/synthesizer/models/sublayer/common/highway_network.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class HighwayNetwork(nn.Module): + def __init__(self, size): + super().__init__() + self.W1 = nn.Linear(size, size) + self.W2 = nn.Linear(size, size) + self.W1.bias.data.fill_(0.) 
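A quick shape sanity-check for the CBHG block above, using the encoder defaults from hparams.py (encoder_dims=256, K=5, four highway layers); batch size and sequence length are arbitrary.

import torch
from synthesizer.models.sublayer.cbhg import CBHG

cbhg = CBHG(K=5, in_channels=256, channels=256, proj_channels=[256, 256], num_highways=4)
x = torch.randn(2, 256, 50)   # [batch, channels, time], as handed over by the (transposed) pre-net output
y = cbhg(x)
print(y.shape)                # torch.Size([2, 50, 256]) -- time-major, ready for attention/decoding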
+ + def forward(self, x): + x1 = self.W1(x) + x2 = self.W2(x) + g = torch.sigmoid(x2) + y = g * F.relu(x1) + (1. - g) * x + return y diff --git a/synthesizer/models/sublayer/global_style_token.py b/synthesizer/models/sublayer/global_style_token.py new file mode 100644 index 0000000000000000000000000000000000000000..21ce07e7056ee575ee37e3855e1489d6cea7ccae --- /dev/null +++ b/synthesizer/models/sublayer/global_style_token.py @@ -0,0 +1,145 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as tFunctional +from synthesizer.gst_hyperparameters import GSTHyperparameters as hp +from synthesizer.hparams import hparams + + +class GlobalStyleToken(nn.Module): + """ + inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel] + speaker_embedding: speaker mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + def __init__(self, speaker_embedding_dim=None): + + super().__init__() + self.encoder = ReferenceEncoder() + self.stl = STL(speaker_embedding_dim) + + def forward(self, inputs, speaker_embedding=None): + enc_out = self.encoder(inputs) + # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py + if hparams.use_ser_for_gst and speaker_embedding is not None: + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + style_embed = self.stl(enc_out) + + return style_embed + + +class ReferenceEncoder(nn.Module): + ''' + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + ''' + + def __init__(self): + + super().__init__() + K = len(hp.ref_enc_filters) + filters = [1] + hp.ref_enc_filters + convs = [nn.Conv2d(in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)) for i in range(K)] + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)]) + + out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K) + self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels, + hidden_size=hp.E // 2, + batch_first=True) + + def forward(self, inputs): + N = inputs.size(0) + out = inputs.view(N, 1, -1, hp.n_mels) # [N, 1, Ty, n_mels] + for conv, bn in zip(self.convs, self.bns): + out = conv(out) + out = bn(out) + out = tFunctional.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, E//2] + + return out.squeeze(0) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class STL(nn.Module): + ''' + inputs --- [N, E//2] + ''' + + def __init__(self, speaker_embedding_dim=None): + + super().__init__() + self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads)) + d_q = hp.E // 2 + d_k = hp.E // hp.num_heads + # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v) + if hparams.use_ser_for_gst and speaker_embedding_dim is not None: + d_q += speaker_embedding_dim + self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads) + + init.normal_(self.embed, mean=0, std=0.5) + + def forward(self, inputs): + N = inputs.size(0) + query = inputs.unsqueeze(1) # [N, 1, E//2] + keys = 
torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads] + style_embed = self.attention(query, keys) + + return style_embed + + +class MultiHeadAttention(nn.Module): + ''' + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + ''' + + def __init__(self, query_dim, key_dim, num_units, num_heads): + + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query, key): + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, num_units] + values = self.W_value(key) + + split_size = self.num_units // self.num_heads + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k ** 0.5)) + scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim ** 0.5) + scores = tFunctional.softmax(scores, dim=3) + + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out diff --git a/synthesizer/models/sublayer/lsa.py b/synthesizer/models/sublayer/lsa.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2dfa52d629793b11a2460be10d17a726ab5303 --- /dev/null +++ b/synthesizer/models/sublayer/lsa.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class LSA(nn.Module): + def __init__(self, attn_dim, kernel_size=31, filters=32): + super().__init__() + self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True) + self.L = nn.Linear(filters, attn_dim, bias=False) + self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term + self.v = nn.Linear(attn_dim, 1, bias=False) + self.cumulative = None + self.attention = None + + def init_attention(self, encoder_seq_proj): + device = encoder_seq_proj.device # use same device as parameters + b, t, c = encoder_seq_proj.size() + self.cumulative = torch.zeros(b, t, device=device) + self.attention = torch.zeros(b, t, device=device) + + def forward(self, encoder_seq_proj, query, times, chars): + + if times == 0: self.init_attention(encoder_seq_proj) + + processed_query = self.W(query).unsqueeze(1) + + location = self.cumulative.unsqueeze(1) + processed_loc = self.L(self.conv(location).transpose(1, 2)) + + u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc)) + u = u.squeeze(-1) + + # Mask zero padding chars + u = u * (chars != 0).float() + + # Smooth Attention + # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True) + scores = F.softmax(u, dim=1) + self.attention = scores + self.cumulative = self.cumulative + self.attention + + return scores.unsqueeze(-1).transpose(1, 2) diff --git a/synthesizer/models/sublayer/pre_net.py b/synthesizer/models/sublayer/pre_net.py new file mode 100644 index 
0000000000000000000000000000000000000000..886646a154c68298deeec09dbad736d617f73155 --- /dev/null +++ b/synthesizer/models/sublayer/pre_net.py @@ -0,0 +1,27 @@ +import torch.nn as nn +import torch.nn.functional as F + +class PreNet(nn.Module): + def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5): + super().__init__() + self.fc1 = nn.Linear(in_dims, fc1_dims) + self.fc2 = nn.Linear(fc1_dims, fc2_dims) + self.p = dropout + + def forward(self, x): + """forward + + Args: + x (3D tensor with size `[batch_size, num_chars, tts_embed_dims]`): input texts list + + Returns: + 3D tensor with size `[batch_size, num_chars, encoder_dims]` + + """ + x = self.fc1(x) + x = F.relu(x) + x = F.dropout(x, self.p, training=True) + x = self.fc2(x) + x = F.relu(x) + x = F.dropout(x, self.p, training=True) + return x diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b01bbae0e6dc95d68bbb983c70706d76e1d990 --- /dev/null +++ b/synthesizer/models/tacotron.py @@ -0,0 +1,298 @@ +import torch +import torch.nn as nn +from .sublayer.global_style_token import GlobalStyleToken +from .sublayer.pre_net import PreNet +from .sublayer.cbhg import CBHG +from .sublayer.lsa import LSA +from .base import Base +from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp +from synthesizer.hparams import hparams + +class Encoder(nn.Module): + def __init__(self, num_chars, embed_dims=512, encoder_dims=256, K=5, num_highways=4, dropout=0.5): + """ Encoder for SV2TTS + + Args: + num_chars (int): length of symbols + embed_dims (int, optional): embedding dim for input texts. Defaults to 512. + encoder_dims (int, optional): output dim for encoder. Defaults to 256. + K (int, optional): _description_. Defaults to 5. + num_highways (int, optional): _description_. Defaults to 4. + dropout (float, optional): _description_. Defaults to 0.5. 
+ """ + super().__init__() + self.embedding = nn.Embedding(num_chars, embed_dims) + self.pre_net = PreNet(embed_dims, fc1_dims=encoder_dims, fc2_dims=encoder_dims, + dropout=dropout) + self.cbhg = CBHG(K=K, in_channels=encoder_dims, channels=encoder_dims, + proj_channels=[encoder_dims, encoder_dims], + num_highways=num_highways) + + def forward(self, x): + """forward pass for encoder + + Args: + x (2D tensor with size `[batch_size, text_num_chars]`): input texts list + + Returns: + 3D tensor with size `[batch_size, text_num_chars, encoder_dims]` + + """ + x = self.embedding(x) # return: [batch_size, text_num_chars, tts_embed_dims] + x = self.pre_net(x) # return: [batch_size, text_num_chars, encoder_dims] + x.transpose_(1, 2) # return: [batch_size, encoder_dims, text_num_chars] + return self.cbhg(x) # return: [batch_size, text_num_chars, encoder_dims] + +class Decoder(nn.Module): + # Class variable because its value doesn't change between classes + # yet ought to be scoped by class because its a property of a Decoder + max_r = 20 + def __init__(self, n_mels, input_dims, decoder_dims, lstm_dims, + dropout, speaker_embedding_size): + super().__init__() + self.register_buffer("r", torch.tensor(1, dtype=torch.int)) + self.n_mels = n_mels + self.prenet = PreNet(n_mels, fc1_dims=decoder_dims * 2, fc2_dims=decoder_dims * 2, + dropout=dropout) + self.attn_net = LSA(decoder_dims) + if hparams.use_gst: + speaker_embedding_size += gst_hp.E + self.attn_rnn = nn.GRUCell(input_dims + decoder_dims * 2, decoder_dims) + self.rnn_input = nn.Linear(input_dims + decoder_dims, lstm_dims) + self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims) + self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims) + self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False) + self.stop_proj = nn.Linear(input_dims + lstm_dims, 1) + + def zoneout(self, prev, current, device, p=0.1): + mask = torch.zeros(prev.size(),device=device).bernoulli_(p) + return prev * mask + current * (1 - mask) + + def forward(self, encoder_seq, encoder_seq_proj, prenet_in, + hidden_states, cell_states, context_vec, times, chars): + """_summary_ + + Args: + encoder_seq (3D tensor `[batch_size, text_num_chars, project_dim(default to 512)]`): _description_ + encoder_seq_proj (3D tensor `[batch_size, text_num_chars, decoder_dims(default to 128)]`): _description_ + prenet_in (2D tensor `[batch_size, n_mels]`): _description_ + hidden_states (_type_): _description_ + cell_states (_type_): _description_ + context_vec (2D tensor `[batch_size, project_dim(default to 512)]`): _description_ + times (int): the number of times runned + chars (2D tensor with size `[batch_size, text_num_chars]`): original texts list input + + """ + # Need this for reshaping mels + batch_size = encoder_seq.size(0) + device = encoder_seq.device + # Unpack the hidden and cell states + attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states + rnn1_cell, rnn2_cell = cell_states + + # PreNet for the Attention RNN + prenet_out = self.prenet(prenet_in) # return: `[batch_size, decoder_dims * 2(256)]` + + # Compute the Attention RNN hidden state + attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) # `[batch_size, project_dim + decoder_dims * 2 (768)]` + attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) # `[batch_size, decoder_dims (128)]` + + # Compute the attention scores + scores = self.attn_net(encoder_seq_proj, attn_hidden, times, chars) + + # Dot product to create the context vector + context_vec = scores @ encoder_seq + context_vec = context_vec.squeeze(1) + + # 
Concat Attention RNN output w. Context Vector & project + x = torch.cat([context_vec, attn_hidden], dim=1) # `[batch_size, project_dim + decoder_dims (630)]` + x = self.rnn_input(x) # `[batch_size, lstm_dims(1024)]` + + # Compute first Residual RNN, training with fixed zoneout rate 0.1 + rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) # `[batch_size, lstm_dims(1024)]` + if self.training: + rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device) + else: + rnn1_hidden = rnn1_hidden_next + x = x + rnn1_hidden + + # Compute second Residual RNN + rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) # `[batch_size, lstm_dims(1024)]` + if self.training: + rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device) + else: + rnn2_hidden = rnn2_hidden_next + x = x + rnn2_hidden + + # Project Mels + mels = self.mel_proj(x) # `[batch_size, 1600]` + mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] # `[batch_size, n_mels, r]` + hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) + cell_states = (rnn1_cell, rnn2_cell) + + # Stop token prediction + s = torch.cat((x, context_vec), dim=1) + s = self.stop_proj(s) + stop_tokens = torch.sigmoid(s) + + return mels, scores, hidden_states, cell_states, context_vec, stop_tokens + +class Tacotron(Base): + def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, + fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways, + dropout, stop_threshold, speaker_embedding_size): + super().__init__(stop_threshold) + self.n_mels = n_mels + self.lstm_dims = lstm_dims + self.encoder_dims = encoder_dims + self.decoder_dims = decoder_dims + self.speaker_embedding_size = speaker_embedding_size + self.encoder = Encoder(num_chars, embed_dims, encoder_dims, + encoder_K, num_highways, dropout) + self.project_dims = encoder_dims + speaker_embedding_size + if hparams.use_gst: + self.project_dims += gst_hp.E + self.encoder_proj = nn.Linear(self.project_dims, decoder_dims, bias=False) + if hparams.use_gst: + self.gst = GlobalStyleToken(speaker_embedding_size) + self.decoder = Decoder(n_mels, self.project_dims, decoder_dims, lstm_dims, + dropout, speaker_embedding_size) + self.postnet = CBHG(postnet_K, n_mels, postnet_dims, + [postnet_dims, fft_bins], num_highways) + self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False) + + @staticmethod + def _concat_speaker_embedding(outputs, speaker_embeddings): + speaker_embeddings_ = speaker_embeddings.expand( + outputs.size(0), outputs.size(1), -1) + outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) + return outputs + + @staticmethod + def _add_speaker_embedding(x, speaker_embedding): + """Add speaker embedding + This concats the speaker embedding for each char in the encoder output + Args: + x (3D tensor with size `[batch_size, text_num_chars, encoder_dims]`): the encoder output + speaker_embedding (2D tensor `[batch_size, speaker_embedding_size]`): the speaker embedding + + Returns: + 3D tensor with size `[batch_size, text_num_chars, encoder_dims+speaker_embedding_size]` + """ + # Save the dimensions as human-readable names + batch_size = x.size()[0] + text_num_chars = x.size()[1] + + # Start by making a copy of each speaker embedding to match the input text length + # The output of this has size (batch_size, text_num_chars * speaker_embedding_size) + speaker_embedding_size = speaker_embedding.size()[1] + e = speaker_embedding.repeat_interleave(text_num_chars, dim=1) + + # Reshape it and transpose 
+ e = e.reshape(batch_size, speaker_embedding_size, text_num_chars) + e = e.transpose(1, 2) + + # Concatenate the tiled speaker embedding with the encoder output + x = torch.cat((x, e), 2) + return x + + def forward(self, texts, mels, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5): + """Forward pass for Tacotron + + Args: + texts (`[batch_size, text_num_chars]`): input texts list + mels (`[batch_size, varied_mel_lengths, steps]`): mels for comparison (training only) + speaker_embedding (`[batch_size, speaker_embedding_size(default to 256)]`): referring embedding. + steps (int, optional): . Defaults to 2000. + style_idx (int, optional): GST style selected. Defaults to 0. + min_stop_token (int, optional): decoder min_stop_token. Defaults to 5. + """ + device = texts.device # use same device as parameters + + if self.training: + self.step += 1 + batch_size, _, steps = mels.size() + else: + batch_size, _ = texts.size() + + # Initialise all hidden states and pack into tuple + attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) + rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) + rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) + hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) + + # Initialise all lstm cell states and pack into tuple + rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) + rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) + cell_states = (rnn1_cell, rnn2_cell) + + # Frame for start of decoder loop + go_frame = torch.zeros(batch_size, self.n_mels, device=device) + + # SV2TTS: Run the encoder with the speaker embedding + # The projection avoids unnecessary matmuls in the decoder loop + encoder_seq = self.encoder(texts) + + encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding) + + if hparams.use_gst and self.gst is not None: + if self.training: + style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced + # style_embed = style_embed.expand_as(encoder_seq) + # encoder_seq = torch.cat((encoder_seq, style_embed), 2) + elif style_idx >= 0 and style_idx < 10: + query = torch.zeros(1, 1, self.gst.stl.attention.num_units) + if device.type == 'cuda': + query = query.cuda() + gst_embed = torch.tanh(self.gst.stl.embed) + key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1) + style_embed = self.gst.stl.attention(query, key) + else: + speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device) + style_embed = self.gst(speaker_embedding_style, speaker_embedding) + encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed) # return: [batch_size, text_num_chars, project_dims] + + encoder_seq_proj = self.encoder_proj(encoder_seq) # return: [batch_size, text_num_chars, decoder_dims] + + # Need a couple of lists for outputs + mel_outputs, attn_scores, stop_outputs = [], [], [] + + # Need an initial context vector + context_vec = torch.zeros(batch_size, self.project_dims, device=device) + + # Run the decoder loop + for t in range(0, steps, self.r): + if self.training: + prenet_in = mels[:, :, t -1] if t > 0 else go_frame + else: + prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame + mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \ + self.decoder(encoder_seq, encoder_seq_proj, prenet_in, + hidden_states, cell_states, context_vec, t, texts) + mel_outputs.append(mel_frames) + 
attn_scores.append(scores) + stop_outputs.extend([stop_tokens] * self.r) + if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break + + # Concat the mel outputs into sequence + mel_outputs = torch.cat(mel_outputs, dim=2) + + # Post-Process for Linear Spectrograms + postnet_out = self.postnet(mel_outputs) + linear = self.post_proj(postnet_out) + linear = linear.transpose(1, 2) + + # For easy visualisation + attn_scores = torch.cat(attn_scores, 1) + # attn_scores = attn_scores.cpu().data.numpy() + stop_outputs = torch.cat(stop_outputs, 1) + + if self.training: + self.train() + + return mel_outputs, linear, attn_scores, stop_outputs + + def generate(self, x, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5): + self.eval() + mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token) + return mel_outputs, linear, attn_scores diff --git a/synthesizer/preprocess.py b/synthesizer/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..dc305e40e217ca6693a8023f49f9bc8ce5d45d57 --- /dev/null +++ b/synthesizer/preprocess.py @@ -0,0 +1,120 @@ +from multiprocessing.pool import Pool + +from functools import partial +from itertools import chain +from pathlib import Path +from tqdm import tqdm +import numpy as np +from encoder import inference as encoder +from synthesizer.preprocess_speaker import preprocess_speaker_general +from synthesizer.preprocess_transcript import preprocess_transcript_aishell3, preprocess_transcript_magicdata + +data_info = { + "aidatatang_200zh": { + "subfolders": ["corpus/train"], + "trans_filepath": "transcript/aidatatang_200_zh_transcript.txt", + "speak_func": preprocess_speaker_general + }, + "magicdata": { + "subfolders": ["train"], + "trans_filepath": "train/TRANS.txt", + "speak_func": preprocess_speaker_general, + "transcript_func": preprocess_transcript_magicdata, + }, + "aishell3":{ + "subfolders": ["train/wav"], + "trans_filepath": "train/content.txt", + "speak_func": preprocess_speaker_general, + "transcript_func": preprocess_transcript_aishell3, + }, + "data_aishell":{ + "subfolders": ["wav/train"], + "trans_filepath": "transcript/aishell_transcript_v0.8.txt", + "speak_func": preprocess_speaker_general + } +} + +def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, + skip_existing: bool, hparams, no_alignments: bool, + dataset: str): + dataset_info = data_info[dataset] + # Gather the input directories + dataset_root = datasets_root.joinpath(dataset) + input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in dataset_info["subfolders"]] + print("\n ".join(map(str, ["Using data from:"] + input_dirs))) + assert all(input_dir.exists() for input_dir in input_dirs) + + # Create the output directories for each output file type + out_dir.joinpath("mels").mkdir(exist_ok=True) + out_dir.joinpath("audio").mkdir(exist_ok=True) + + # Create a metadata file + metadata_fpath = out_dir.joinpath("train.txt") + metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8") + + # Preprocess the dataset + dict_info = {} + transcript_dirs = dataset_root.joinpath(dataset_info["trans_filepath"]) + assert transcript_dirs.exists(), str(transcript_dirs)+" not exist." 
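preprocess_dataset is normally driven by a command-line wrapper that is not part of this diff; calling it directly looks roughly like the sketch below. All paths are placeholders, the output folder must already exist, and the input layout has to match the data_info entry for the chosen dataset (e.g. <datasets_root>/aidatatang_200zh/corpus/train plus the transcript file).

from pathlib import Path
from synthesizer.hparams import hparams
from synthesizer.preprocess import preprocess_dataset

preprocess_dataset(datasets_root=Path("/data"),               # placeholder: contains aidatatang_200zh/
                   out_dir=Path("/data/SV2TTS/synthesizer"),  # placeholder: must already exist
                   n_processes=4,
                   skip_existing=True,
                   hparams=hparams,
                   no_alignments=False,
                   dataset="aidatatang_200zh")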
+ with open(transcript_dirs, "r", encoding="utf-8") as dict_transcript: + # process with specific function for your dataset + if "transcript_func" in dataset_info: + dataset_info["transcript_func"](dict_info, dict_transcript) + else: + for v in dict_transcript: + if not v: + continue + v = v.strip().replace("\n","").replace("\t"," ").split(" ") + dict_info[v[0]] = " ".join(v[1:]) + + speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs)) + func = partial(dataset_info["speak_func"], out_dir=out_dir, skip_existing=skip_existing, + hparams=hparams, dict_info=dict_info, no_alignments=no_alignments) + job = Pool(n_processes).imap(func, speaker_dirs) + for speaker_metadata in tqdm(job, dataset, len(speaker_dirs), unit="speakers"): + for metadatum in speaker_metadata: + metadata_file.write("|".join(str(x) for x in metadatum) + "\n") + metadata_file.close() + + # Verify the contents of the metadata file + with metadata_fpath.open("r", encoding="utf-8") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + mel_frames = sum([int(m[4]) for m in metadata]) + timesteps = sum([int(m[3]) for m in metadata]) + sample_rate = hparams.sample_rate + hours = (timesteps / sample_rate) / 3600 + print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." % + (len(metadata), mel_frames, timesteps, hours)) + print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata)) + print("Max mel frames length: %d" % max(int(m[4]) for m in metadata)) + print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata)) + +def embed_utterance(fpaths, encoder_model_fpath): + if not encoder.is_loaded(): + encoder.load_model(encoder_model_fpath) + + # Compute the speaker embedding of the utterance + wav_fpath, embed_fpath = fpaths + wav = np.load(wav_fpath) + wav = encoder.preprocess_wav(wav) + embed = encoder.embed_utterance(wav) + np.save(embed_fpath, embed, allow_pickle=False) + + +def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int): + wav_dir = synthesizer_root.joinpath("audio") + metadata_fpath = synthesizer_root.joinpath("train.txt") + assert wav_dir.exists() and metadata_fpath.exists() + embed_dir = synthesizer_root.joinpath("embeds") + embed_dir.mkdir(exist_ok=True) + + # Gather the input wave filepath and the target output embed filepath + with metadata_fpath.open("r", encoding="utf-8") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata] + + # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here. 
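For orientation, this is the layout of one train.txt row as written by the speaker preprocessing below and read back by the verification step and by SynthesizerDataset; the concrete values in this row are invented for illustration.

# index 0: audio .npy   1: mel .npy    2: embed .npy
# index 3: wav samples  4: mel frames  5: text (pinyin with tone numbers)
row = "audio-X0001.wav_00.npy|mel-X0001.wav_00.npy|embed-X0001.wav_00.npy|48000|240|jin1 tian1 tian1 qi4 bu2 cuo4"
m = row.split("|")
print(int(m[3]) / 16000)        # 3.0  -> seconds of audio at the 16 kHz sample rate
print(int(m[4]) * 200 / 16000)  # 3.0  -> roughly the same duration expressed in 12.5 ms mel frames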
+ # Embed the utterances in separate threads + func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath) + job = Pool(n_processes).imap(func, fpaths) + list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) diff --git a/synthesizer/preprocess_speaker.py b/synthesizer/preprocess_speaker.py new file mode 100644 index 0000000000000000000000000000000000000000..28ddad4f113a6543b94e7d03e332af23cf2436bb --- /dev/null +++ b/synthesizer/preprocess_speaker.py @@ -0,0 +1,99 @@ +import librosa +import numpy as np + +from encoder import inference as encoder +from utils import logmmse +from synthesizer import audio +from pathlib import Path +from pypinyin import Style +from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin +from pypinyin.converter import DefaultConverter +from pypinyin.core import Pinyin + +class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter): + pass + +pinyin = Pinyin(PinyinConverter()).pinyin + + +def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str, + skip_existing: bool, hparams): + ## FOR REFERENCE: + # For you not to lose your head if you ever wish to change things here or implement your own + # synthesizer. + # - Both the audios and the mel spectrograms are saved as numpy arrays + # - There is no processing done to the audios that will be saved to disk beyond volume + # normalization (in split_on_silences) + # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This + # is why we re-apply it on the audio on the side of the vocoder. + # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved + # without extra padding. This means that you won't have an exact relation between the length + # of the wav and of the mel spectrogram. See the vocoder data loader. + + + # Skip existing utterances if needed + mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename) + wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename) + if skip_existing and mel_fpath.exists() and wav_fpath.exists(): + return None + + # Trim silence + if hparams.trim_silence: + wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True) + + # Skip utterances that are too short + if len(wav) < hparams.utterance_min_duration * hparams.sample_rate: + return None + + # Compute the mel spectrogram + mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32) + mel_frames = mel_spectrogram.shape[1] + + # Skip utterances that are too long + if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length: + return None + + # Write the spectrogram, embed and audio to disk + np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False) + np.save(wav_fpath, wav, allow_pickle=False) + + # Return a tuple describing this training example + return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text + + +def _split_on_silences(wav_fpath, words, hparams): + # Load the audio waveform + wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate) + wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0] + if hparams.rescale: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + # denoise, we may not need it here. 
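A back-of-the-envelope check of the length filters applied in _process_utterance above, using the signal settings from hparams.py (sample_rate=16000, hop_size=200):

sample_rate, hop_size = 16000, 200
print(900 * hop_size / sample_rate)   # 11.25 -> max_mel_frames=900 caps clips at about 11.25 s
print(1.6 * sample_rate / hop_size)   # 128.0 -> utterance_min_duration=1.6 s is about 128 mel frames
print(1000 * hop_size / sample_rate)  # 12.5  -> each mel frame advances 12.5 ms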
+ if len(wav) > hparams.sample_rate*(0.3+0.1): + noise_wav = np.concatenate([wav[:int(hparams.sample_rate*0.15)], + wav[-int(hparams.sample_rate*0.15):]]) + profile = logmmse.profile_noise(noise_wav, hparams.sample_rate) + wav = logmmse.denoise(wav, profile, eta=0) + + resp = pinyin(words, style=Style.TONE3) + res = [v[0] for v in resp if v[0].strip()] + res = " ".join(res) + + return wav, res + +def preprocess_speaker_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool): + metadata = [] + extensions = ["*.wav", "*.flac", "*.mp3"] + for extension in extensions: + wav_fpath_list = speaker_dir.glob(extension) + # Iterate over each wav + for wav_fpath in wav_fpath_list: + words = dict_info.get(wav_fpath.name.split(".")[0]) + words = dict_info.get(wav_fpath.name) if not words else words # try with wav + if not words: + print("no wordS") + continue + sub_basename = "%s_%02d" % (wav_fpath.name, 0) + wav, text = _split_on_silences(wav_fpath, words, hparams) + metadata.append(_process_utterance(wav, text, out_dir, sub_basename, + skip_existing, hparams)) + return [m for m in metadata if m is not None] diff --git a/synthesizer/preprocess_transcript.py b/synthesizer/preprocess_transcript.py new file mode 100644 index 0000000000000000000000000000000000000000..7a26672bb7d39b88cbbe11888f5c0b444235bb7e --- /dev/null +++ b/synthesizer/preprocess_transcript.py @@ -0,0 +1,18 @@ +def preprocess_transcript_aishell3(dict_info, dict_transcript): + for v in dict_transcript: + if not v: + continue + v = v.strip().replace("\n","").replace("\t"," ").split(" ") + transList = [] + for i in range(2, len(v), 2): + transList.append(v[i]) + dict_info[v[0]] = " ".join(transList) + + +def preprocess_transcript_magicdata(dict_info, dict_transcript): + for v in dict_transcript: + if not v: + continue + v = v.strip().replace("\n","").replace("\t"," ").split(" ") + dict_info[v[0]] = " ".join(v[2:]) + \ No newline at end of file diff --git "a/synthesizer/saved_models/\346\231\256\351\200\232\350\257\235.pt" "b/synthesizer/saved_models/\346\231\256\351\200\232\350\257\235.pt" new file mode 100644 index 0000000000000000000000000000000000000000..3c1c3c66b251eb5ff38f579dd2b058603ccc9848 --- /dev/null +++ "b/synthesizer/saved_models/\346\231\256\351\200\232\350\257\235.pt" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71d047e92cebd45112eb68f79f39cd0368f7d01954de474131364a526a19b4b6 +size 526153469 diff --git a/synthesizer/synthesize.py b/synthesizer/synthesize.py new file mode 100644 index 0000000000000000000000000000000000000000..49a06b01983ae54c57840a62fa18f7a8508948ee --- /dev/null +++ b/synthesizer/synthesize.py @@ -0,0 +1,97 @@ +import torch +from torch.utils.data import DataLoader +from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer +from synthesizer.models.tacotron import Tacotron +from synthesizer.utils.text import text_to_sequence +from synthesizer.utils.symbols import symbols +import numpy as np +from pathlib import Path +from tqdm import tqdm +import sys + + +def run_synthesis(in_dir, out_dir, model_dir, hparams): + # This generates ground truth-aligned mels for vocoder training + synth_dir = Path(out_dir).joinpath("mels_gta") + synth_dir.mkdir(parents=True, exist_ok=True) + print(str(hparams)) + + # Check for GPU + if torch.cuda.is_available(): + device = torch.device("cuda") + if hparams.synthesis_batch_size % torch.cuda.device_count() != 0: + raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible 
by n_gpus!") + else: + device = torch.device("cpu") + print("Synthesizer using device:", device) + + # Instantiate Tacotron model + model = Tacotron(embed_dims=hparams.tts_embed_dims, + num_chars=len(symbols), + encoder_dims=hparams.tts_encoder_dims, + decoder_dims=hparams.tts_decoder_dims, + n_mels=hparams.num_mels, + fft_bins=hparams.num_mels, + postnet_dims=hparams.tts_postnet_dims, + encoder_K=hparams.tts_encoder_K, + lstm_dims=hparams.tts_lstm_dims, + postnet_K=hparams.tts_postnet_K, + num_highways=hparams.tts_num_highways, + dropout=0., # Use zero dropout for gta mels + stop_threshold=hparams.tts_stop_threshold, + speaker_embedding_size=hparams.speaker_embedding_size).to(device) + + # Load the weights + model_dir = Path(model_dir) + model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt") + print("\nLoading weights at %s" % model_fpath) + model.load(model_fpath, device) + print("Tacotron weights loaded from step %d" % model.step) + + # Synthesize using same reduction factor as the model is currently trained + r = np.int32(model.r) + + # Set model to eval mode (disable gradient and zoneout) + model.eval() + + # Initialize the dataset + in_dir = Path(in_dir) + metadata_fpath = in_dir.joinpath("train.txt") + mel_dir = in_dir.joinpath("mels") + embed_dir = in_dir.joinpath("embeds") + num_workers = 0 if sys.platform.startswith("win") else 2; + dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams) + data_loader = DataLoader(dataset, + collate_fn=lambda batch: collate_synthesizer(batch), + batch_size=hparams.synthesis_batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True) + + # Generate GTA mels + meta_out_fpath = Path(out_dir).joinpath("synthesized.txt") + with open(meta_out_fpath, "w") as file: + for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)): + texts = texts.to(device) + mels = mels.to(device) + embeds = embeds.to(device) + + # Parallelize model onto GPUS using workaround due to python bug + if device.type == "cuda" and torch.cuda.device_count() > 1: + _, mels_out, _ , _ = data_parallel_workaround(model, texts, mels, embeds) + else: + _, mels_out, _, _ = model(texts, mels, embeds) + + for j, k in enumerate(idx): + # Note: outputs mel-spectrogram files and target ones have same names, just different folders + mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1]) + mel_out = mels_out[j].detach().cpu().numpy().T + + # Use the length of the ground truth mel to remove padding from the generated mels + mel_out = mel_out[:int(dataset.metadata[k][4])] + + # Write the spectrogram to disk + np.save(mel_filename, mel_out, allow_pickle=False) + + # Write metadata into the synthesized file + file.write("|".join(dataset.metadata[k])) diff --git a/synthesizer/synthesizer_dataset.py b/synthesizer/synthesizer_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5ed581e11d7809c6321cd395c8ebacf0303e9d --- /dev/null +++ b/synthesizer/synthesizer_dataset.py @@ -0,0 +1,93 @@ +import torch +from torch.utils.data import Dataset +import numpy as np +from pathlib import Path +from synthesizer.utils.text import text_to_sequence + + +class SynthesizerDataset(Dataset): + def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams): + print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir)) + + with metadata_fpath.open("r", encoding="utf-8") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + + mel_fnames = [x[1] 
for x in metadata if int(x[4])] + mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames] + embed_fnames = [x[2] for x in metadata if int(x[4])] + embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames] + self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths)) + self.samples_texts = [x[5].strip() for x in metadata if int(x[4])] + self.metadata = metadata + self.hparams = hparams + + print("Found %d samples" % len(self.samples_fpaths)) + + def __getitem__(self, index): + # Sometimes index may be a list of 2 (not sure why this happens) + # If that is the case, return a single item corresponding to first element in index + if index is list: + index = index[0] + + mel_path, embed_path = self.samples_fpaths[index] + mel = np.load(mel_path).T.astype(np.float32) + + # Load the embed + embed = np.load(embed_path) + + # Get the text and clean it + text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names) + + # Convert the list returned by text_to_sequence to a numpy array + text = np.asarray(text).astype(np.int32) + + return text, mel.astype(np.float32), embed.astype(np.float32), index + + def __len__(self): + return len(self.samples_fpaths) + + +def collate_synthesizer(batch): + # Text + x_lens = [len(x[0]) for x in batch] + max_x_len = max(x_lens) + + chars = [pad1d(x[0], max_x_len) for x in batch] + chars = np.stack(chars) + + # Mel spectrogram + spec_lens = [x[1].shape[-1] for x in batch] + max_spec_len = max(spec_lens) + 1 + if max_spec_len % 2 != 0: # FIXIT: Hardcoded due to incompatibility with Windows (no lambda) + max_spec_len += 2 - max_spec_len % 2 + + # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence + # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence. 
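Wiring this dataset and collate function into a DataLoader follows the same pattern used later in synthesizer/train.py; a minimal sketch, where the data root is a placeholder and the batch size is arbitrary.

from pathlib import Path
from torch.utils.data import DataLoader
from synthesizer.hparams import hparams
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer

root = Path("/data/SV2TTS/synthesizer")   # placeholder: output of the preprocessing + embedding steps
dataset = SynthesizerDataset(root / "train.txt", root / "mels", root / "embeds", hparams)
loader = DataLoader(dataset, batch_size=12, shuffle=True, collate_fn=collate_synthesizer)

texts, mels, embeds, indices = next(iter(loader))
print(texts.shape, mels.shape, embeds.shape)   # e.g. [12, max_chars], [12, 80, max_frames], [12, 256]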
+ # if hparams.symmetric_mels: + # mel_pad_value = -1 * hparams.max_abs_value + # else: + # mel_pad_value = 0 + mel_pad_value = -4 # FIXIT: Hardcoded due to incompatibility with Windows (no lambda) + mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch] + mel = np.stack(mel) + + # Speaker embedding (SV2TTS) + embeds = [x[2] for x in batch] + embeds = np.stack(embeds) + + # Index (for vocoder preprocessing) + indices = [x[3] for x in batch] + + + # Convert all to tensor + chars = torch.tensor(chars).long() + mel = torch.tensor(mel) + embeds = torch.tensor(embeds) + + return chars, mel, embeds, indices + +def pad1d(x, max_len, pad_value=0): + return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value) + +def pad2d(x, max_len, pad_value=0): + return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value) diff --git a/synthesizer/train.py b/synthesizer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1f8a0cf7aab7cfa7c00205d8368cad7570005f --- /dev/null +++ b/synthesizer/train.py @@ -0,0 +1,317 @@ +import torch +import torch.nn.functional as F +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from synthesizer import audio +from synthesizer.models.tacotron import Tacotron +from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer +from synthesizer.utils import ValueWindow, data_parallel_workaround +from synthesizer.utils.plot import plot_spectrogram, plot_spectrogram_and_trace +from synthesizer.utils.symbols import symbols +from synthesizer.utils.text import sequence_to_text +from vocoder.display import * +from datetime import datetime +import json +import numpy as np +from pathlib import Path +import time +import os + +def np_now(x: torch.Tensor): return x.detach().cpu().numpy() + +def time_string(): + return datetime.now().strftime("%Y-%m-%d %H:%M") + +def train(run_id: str, syn_dir: str, models_dir: str, save_every: int, + backup_every: int, log_every:int, force_restart:bool, hparams): + + syn_dir = Path(syn_dir) + models_dir = Path(models_dir) + models_dir.mkdir(exist_ok=True) + + model_dir = models_dir.joinpath(run_id) + plot_dir = model_dir.joinpath("plots") + wav_dir = model_dir.joinpath("wavs") + mel_output_dir = model_dir.joinpath("mel-spectrograms") + meta_folder = model_dir.joinpath("metas") + model_dir.mkdir(exist_ok=True) + plot_dir.mkdir(exist_ok=True) + wav_dir.mkdir(exist_ok=True) + mel_output_dir.mkdir(exist_ok=True) + meta_folder.mkdir(exist_ok=True) + + weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt") + metadata_fpath = syn_dir.joinpath("train.txt") + + print("Checkpoint path: {}".format(weights_fpath)) + print("Loading training data from: {}".format(metadata_fpath)) + print("Using model: Tacotron") + + # Book keeping + step = 0 + time_window = ValueWindow(100) + loss_window = ValueWindow(100) + + + # From WaveRNN/train_tacotron.py + if torch.cuda.is_available(): + device = torch.device("cuda") + + for session in hparams.tts_schedule: + _, _, _, batch_size = session + if batch_size % torch.cuda.device_count() != 0: + raise ValueError("`batch_size` must be evenly divisible by n_gpus!") + else: + device = torch.device("cpu") + print("Using device:", device) + + # Instantiate Tacotron Model + print("\nInitialising Tacotron Model...\n") + num_chars = len(symbols) + if weights_fpath.exists(): + # for compatibility purpose, change symbols accordingly: + loaded_shape = 
torch.load(str(weights_fpath), map_location=device)["model_state"]["encoder.embedding.weight"].shape + if num_chars != loaded_shape[0]: + print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`") + num_chars != loaded_shape[0] + # Try to scan config file + model_config_fpaths = list(weights_fpath.parent.rglob("*.json")) + if len(model_config_fpaths)>0 and model_config_fpaths[0].exists(): + with model_config_fpaths[0].open("r", encoding="utf-8") as f: + hparams.loadJson(json.load(f)) + else: # save a config + hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json")) + + + model = Tacotron(embed_dims=hparams.tts_embed_dims, + num_chars=num_chars, + encoder_dims=hparams.tts_encoder_dims, + decoder_dims=hparams.tts_decoder_dims, + n_mels=hparams.num_mels, + fft_bins=hparams.num_mels, + postnet_dims=hparams.tts_postnet_dims, + encoder_K=hparams.tts_encoder_K, + lstm_dims=hparams.tts_lstm_dims, + postnet_K=hparams.tts_postnet_K, + num_highways=hparams.tts_num_highways, + dropout=hparams.tts_dropout, + stop_threshold=hparams.tts_stop_threshold, + speaker_embedding_size=hparams.speaker_embedding_size).to(device) + + # Initialize the optimizer + optimizer = optim.Adam(model.parameters(), amsgrad=True) + + # Load the weights + if force_restart or not weights_fpath.exists(): + print("\nStarting the training of Tacotron from scratch\n") + model.save(weights_fpath) + + # Embeddings metadata + char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv") + with open(char_embedding_fpath, "w", encoding="utf-8") as f: + for symbol in symbols: + if symbol == " ": + symbol = "\\s" # For visual purposes, swap space with \s + + f.write("{}\n".format(symbol)) + + else: + print("\nLoading weights at %s" % weights_fpath) + model.load(weights_fpath, device, optimizer) + print("Tacotron weights loaded from step %d" % model.step) + + # Initialize the dataset + metadata_fpath = syn_dir.joinpath("train.txt") + mel_dir = syn_dir.joinpath("mels") + embed_dir = syn_dir.joinpath("embeds") + dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams) + test_loader = DataLoader(dataset, + batch_size=1, + shuffle=True, + pin_memory=True) + + # tracing training step + sw = SummaryWriter(log_dir=model_dir.joinpath("logs")) + + for i, session in enumerate(hparams.tts_schedule): + current_step = model.get_step() + + r, lr, max_step, batch_size = session + + training_steps = max_step - current_step + + # Do we need to change to the next session? + if current_step >= max_step: + # Are there no further sessions than the current one? + if i == len(hparams.tts_schedule) - 1: + # We have completed training. 
Save the model and exit + model.save(weights_fpath, optimizer) + break + else: + # There is a following session, go to it + continue + + model.r = r + # Begin the training + simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"), + ("Batch Size", batch_size), + ("Learning Rate", lr), + ("Outputs/Step (r)", model.r)]) + + for p in optimizer.param_groups: + p["lr"] = lr + if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0: + model.finetune_partial(hparams.tts_finetune_layers) + + data_loader = DataLoader(dataset, + collate_fn=collate_synthesizer, + batch_size=batch_size, #change if you got graphic card OOM + num_workers=2, + shuffle=True, + pin_memory=True) + + total_iters = len(dataset) + steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32) + epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32) + + for epoch in range(1, epochs+1): + for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1): + start_time = time.time() + + # Generate stop tokens for training + stop = torch.ones(mels.shape[0], mels.shape[2]) + for j, k in enumerate(idx): + stop[j, :int(dataset.metadata[k][4])-1] = 0 + + texts = texts.to(device) + mels = mels.to(device) + embeds = embeds.to(device) + stop = stop.to(device) + + # Forward pass + # Parallelize model onto GPUS using workaround due to python bug + if device.type == "cuda" and torch.cuda.device_count() > 1: + m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, + mels, embeds) + else: + m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds) + + # Backward pass + m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels) + m2_loss = F.mse_loss(m2_hat, mels) + stop_loss = F.binary_cross_entropy(stop_pred, stop) + + loss = m1_loss + m2_loss + stop_loss + + optimizer.zero_grad() + loss.backward() + + if hparams.tts_clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm) + if np.isnan(grad_norm.cpu()): + print("grad_norm was NaN!") + + optimizer.step() + + time_window.append(time.time() - start_time) + loss_window.append(loss.item()) + + step = model.get_step() + k = step // 1000 + + + msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | " + stream(msg) + + if log_every != 0 and step % log_every == 0 : + sw.add_scalar("training/loss", loss_window.average, step) + + # Backup or save model as appropriate + if backup_every != 0 and step % backup_every == 0 : + backup_fpath = Path("{}/{}_{}.pt".format(str(weights_fpath.parent), run_id, step)) + model.save(backup_fpath, optimizer) + + if save_every != 0 and step % save_every == 0 : + # Must save latest optimizer state to ensure that resuming training + # doesn't produce artifacts + model.save(weights_fpath, optimizer) + + + # Evaluate model to generate samples + epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done + step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps + if epoch_eval or step_eval: + for sample_idx in range(hparams.tts_eval_num_samples): + # At most, generate samples equal to number in the batch + if sample_idx + 1 <= len(texts): + # Remove padding from mels using frame length in metadata + mel_length = int(dataset.metadata[idx[sample_idx]][4]) + mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length] + target_spectrogram = 
np_now(mels[sample_idx]).T[:mel_length] + attention_len = mel_length // model.r + # eval_loss = F.mse_loss(mel_prediction, target_spectrogram) + # sw.add_scalar("validing/loss", eval_loss.item(), step) + eval_model(attention=np_now(attention[sample_idx][:, :attention_len]), + mel_prediction=mel_prediction, + target_spectrogram=target_spectrogram, + input_seq=np_now(texts[sample_idx]), + step=step, + plot_dir=plot_dir, + mel_output_dir=mel_output_dir, + wav_dir=wav_dir, + sample_num=sample_idx + 1, + loss=loss, + hparams=hparams, + sw=sw) + MAX_SAVED_COUNT = 20 + if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0: + # clean up and save last MAX_SAVED_COUNT; + plots = next(os.walk(plot_dir), (None, None, []))[2] + for plot in plots[-MAX_SAVED_COUNT:]: + os.remove(plot_dir.joinpath(plot)) + mel_files = next(os.walk(mel_output_dir), (None, None, []))[2] + for mel_file in mel_files[-MAX_SAVED_COUNT:]: + os.remove(mel_output_dir.joinpath(mel_file)) + wavs = next(os.walk(wav_dir), (None, None, []))[2] + for w in wavs[-MAX_SAVED_COUNT:]: + os.remove(wav_dir.joinpath(w)) + + # Break out of loop to update training schedule + if step >= max_step: + break + + # Add line break after every epoch + print("") + +def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step, + plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, sw): + # Save some results for evaluation + attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num))) + # save_attention(attention, attention_path) + save_and_trace_attention(attention, attention_path, sw, step) + + # save predicted mel spectrogram to disk (debug) + mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num)) + np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False) + + # save griffin lim inverted wav for debug (mel -> wav) + wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams) + wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num)) + audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate) + + # save real and predicted mel-spectrogram plot to disk (control purposes) + spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num)) + title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss) + # plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str, + # target_spectrogram=target_spectrogram, + # max_len=target_spectrogram.size // hparams.num_mels) + plot_spectrogram_and_trace( + mel_prediction, + str(spec_fpath), + title=title_str, + target_spectrogram=target_spectrogram, + max_len=target_spectrogram.size // hparams.num_mels, + sw=sw, + step=step) + print("Input at step {}: {}".format(step, sequence_to_text(input_seq))) diff --git a/synthesizer/utils/__init__.py b/synthesizer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae3e48110e61231acf1e666e5fa76af5e4ebdcd --- /dev/null +++ b/synthesizer/utils/__init__.py @@ -0,0 +1,45 @@ +import torch + + +_output_ref = None +_replicas_ref = None + +def data_parallel_workaround(model, *input): + global _output_ref + global _replicas_ref + device_ids = list(range(torch.cuda.device_count())) + output_device = device_ids[0] + replicas = torch.nn.parallel.replicate(model, device_ids) + # input.shape = (num_args, batch, ...) + inputs = torch.nn.parallel.scatter(input, device_ids) + # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...) 
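+    # (added note) torch.nn.parallel.scatter splits each argument along the batch dimension,
+    # one chunk per visible GPU; when the batch holds fewer samples than there are GPUs it
+    # returns fewer chunks, which is why the replica list is trimmed to len(inputs) below.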
+ replicas = replicas[:len(inputs)] + outputs = torch.nn.parallel.parallel_apply(replicas, inputs) + y_hat = torch.nn.parallel.gather(outputs, output_device) + _output_ref = outputs + _replicas_ref = replicas + return y_hat + + +class ValueWindow(): + def __init__(self, window_size=100): + self._window_size = window_size + self._values = [] + + def append(self, x): + self._values = self._values[-(self._window_size - 1):] + [x] + + @property + def sum(self): + return sum(self._values) + + @property + def count(self): + return len(self._values) + + @property + def average(self): + return self.sum / max(1, self.count) + + def reset(self): + self._values = [] diff --git a/synthesizer/utils/_cmudict.py b/synthesizer/utils/_cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..2cef1f896d4fb78478884fe8e810956998d5e3b3 --- /dev/null +++ b/synthesizer/utils/_cmudict.py @@ -0,0 +1,62 @@ +import re + +valid_symbols = [ + "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2", + "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2", + "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY", + "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", + "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0", + "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", + "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH" +] + +_valid_symbol_set = set(valid_symbols) + + +class CMUDict: + """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + + def __len__(self): + return len(self._entries) + + + def lookup(self, word): + """Returns list of ARPAbet pronunciations of the given word.""" + return self._entries.get(word.upper()) + + + +_alt_re = re.compile(r"\([0-9]+\)") + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(" ") + for part in parts: + if part not in _valid_symbol_set: + return None + return " ".join(parts) diff --git a/synthesizer/utils/cleaners.py b/synthesizer/utils/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..eab63f05c9cc7cc0b583992eac94058097f3c191 --- /dev/null +++ b/synthesizer/utils/cleaners.py @@ -0,0 +1,88 @@ +""" +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You"ll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. 
"basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +""" + +import re +from unidecode import unidecode +from .numbers import normalize_numbers + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + """lowercase input tokens.""" + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including number and abbreviation expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/synthesizer/utils/numbers.py b/synthesizer/utils/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..75020a0bd732830f603d7c7d250c9e087033cc24 --- /dev/null +++ b/synthesizer/utils/numbers.py @@ -0,0 +1,68 @@ +import re +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 
and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/synthesizer/utils/plot.py b/synthesizer/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..efdb5670b4f26f2110988e818ff8ad9ff7238cef --- /dev/null +++ b/synthesizer/utils/plot.py @@ -0,0 +1,115 @@ +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + + +def split_title_line(title_text, max_words=5): + """ + A function that splits any string based on specific character + (returning it with the string), with maximum number of words on it + """ + seq = title_text.split() + return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)]) + +def plot_alignment(alignment, path, title=None, split_title=False, max_len=None): + if max_len is not None: + alignment = alignment[:, :max_len] + + fig = plt.figure(figsize=(8, 6)) + ax = fig.add_subplot(111) + + im = ax.imshow( + alignment, + aspect="auto", + origin="lower", + interpolation="none") + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + + if split_title: + title = split_title_line(title) + + plt.xlabel(xlabel) + plt.title(title) + plt.ylabel("Encoder timestep") + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() + + +def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False): + if max_len is not None: + target_spectrogram = target_spectrogram[:max_len] + pred_spectrogram = pred_spectrogram[:max_len] + + if split_title: + title = split_title_line(title) + + fig = plt.figure(figsize=(10, 8)) + # Set common labels + fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16) + + #target spectrogram subplot + if target_spectrogram is not None: + ax1 = fig.add_subplot(311) + ax2 = fig.add_subplot(312) + + if auto_aspect: + im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none") + else: + im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none") + ax1.set_title("Target Mel-Spectrogram") + fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1) + ax2.set_title("Predicted Mel-Spectrogram") + else: + ax2 = fig.add_subplot(211) + + if auto_aspect: + im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none") + else: + im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none") + fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2) + + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() + + +def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0): + if max_len is not None: + target_spectrogram = target_spectrogram[:max_len] + pred_spectrogram = pred_spectrogram[:max_len] 
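+    # (added note) This function mirrors plot_spectrogram above, but the finished figure is also
+    # pushed to TensorBoard via sw.add_figure at the end. The caller is therefore expected to pass
+    # a live torch.utils.tensorboard SummaryWriter as sw (as done from synthesizer/train.py);
+    # leaving the default sw=None would fail at that call.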
+ + if split_title: + title = split_title_line(title) + + fig = plt.figure(figsize=(10, 8)) + # Set common labels + fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16) + + #target spectrogram subplot + if target_spectrogram is not None: + ax1 = fig.add_subplot(311) + ax2 = fig.add_subplot(312) + + if auto_aspect: + im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none") + else: + im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none") + ax1.set_title("Target Mel-Spectrogram") + fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1) + ax2.set_title("Predicted Mel-Spectrogram") + else: + ax2 = fig.add_subplot(211) + + if auto_aspect: + im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none") + else: + im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none") + fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2) + + plt.tight_layout() + plt.savefig(path, format="png") + sw.add_figure("spectrogram", fig, step) + plt.close() \ No newline at end of file diff --git a/synthesizer/utils/symbols.py b/synthesizer/utils/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..2036dded914cc5490d556a2022b40e57e584b742 --- /dev/null +++ b/synthesizer/utils/symbols.py @@ -0,0 +1,18 @@ +""" +Defines the set of symbols used in text input to the model. + +The default is a set of ASCII characters that works well for English or text that has been run +through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. +""" +# from . import cmudict + +_pad = "_" +_eos = "~" +_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? ' + +#_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old one if you want to train old model +# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): +#_arpabet = ["@' + s for s in cmudict.valid_symbols] + +# Export all symbols: +symbols = [_pad, _eos] + list(_characters) #+ _arpabet diff --git a/synthesizer/utils/text.py b/synthesizer/utils/text.py new file mode 100644 index 0000000000000000000000000000000000000000..29372174aec95cd2eac1ea40096fcc148f532b07 --- /dev/null +++ b/synthesizer/utils/text.py @@ -0,0 +1,74 @@ +from .symbols import symbols +from . import cleaners +import re + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +# Regular expression matching text enclosed in curly braces: +_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
+ + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) + break + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) + + # Append EOS token + sequence.append(_symbol_to_id["~"]) + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] + result += s + return result.replace("}{", " ") + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + return text + + +def _symbols_to_sequence(symbols): + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + + +def _arpabet_to_sequence(text): + return _symbols_to_sequence(["@" + s for s in text.split()]) + + +def _should_keep_symbol(s): + return s in _symbol_to_id and s not in ("_", "~") diff --git a/synthesizer_preprocess_audio.py b/synthesizer_preprocess_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..51d92f91a485ea853957127bec9166420daed934 --- /dev/null +++ b/synthesizer_preprocess_audio.py @@ -0,0 +1,65 @@ +from synthesizer.preprocess import preprocess_dataset +from synthesizer.hparams import hparams +from utils.argutils import print_args +from pathlib import Path +import argparse + + +recognized_datasets = [ + "aidatatang_200zh", + "magicdata", + "aishell3" +] + +if __name__ == "__main__": + print("This method is deprecaded and will not be longer supported, please use 'pre.py'") + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, encodes them as mel spectrograms " + "and writes them to the disk. Audio files are also saved, to be used by the " + "vocoder for training.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your LibriSpeech/TTS datasets.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms, the audios and the " + "embeds. Defaults to /SV2TTS/synthesizer/") + parser.add_argument("-n", "--n_processes", type=int, default=None, help=\ + "Number of processes in parallel.") + parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + "Whether to overwrite existing files with the same name. 
Useful if the preprocessing was "
+        "interrupted.")
+    parser.add_argument("--hparams", type=str, default="", help=\
+        "Hyperparameter overrides as a comma-separated list of name-value pairs")
+    parser.add_argument("--no_trim", action="store_true", help=\
+        "Preprocess audio without trimming silences (not recommended).")
+    parser.add_argument("--no_alignments", action="store_true", help=\
+        "Use this option when the dataset does not include alignments\
+    (these are used to split long audio files into sub-utterances.)")
+    parser.add_argument("--dataset", type=str, default="aidatatang_200zh", help=\
+        "Name of the dataset to process, allowed values: aidatatang_200zh, magicdata, aishell3.")
+    args = parser.parse_args()
+
+    # Process the arguments
+    if not hasattr(args, "out_dir"):
+        args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer")
+    assert args.dataset in recognized_datasets, 'This dataset is not supported yet, please vote for it in https://github.com/babysor/MockingBird/issues/10'
+    # Create directories
+    assert args.datasets_root.exists()
+    args.out_dir.mkdir(exist_ok=True, parents=True)
+
+    # Verify webrtcvad is available
+    if not args.no_trim:
+        try:
+            import webrtcvad
+        except:
+            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
+                "noise removal and is recommended. Please install and try again. If installation fails, "
+                "use --no_trim to disable this error message.")
+    del args.no_trim
+
+    # Preprocess the dataset
+    print_args(args, parser)
+    args.hparams = hparams.parse(args.hparams)
+
+    preprocess_dataset(**vars(args))
\ No newline at end of file
diff --git a/synthesizer_preprocess_embeds.py b/synthesizer_preprocess_embeds.py
new file mode 100644
index 0000000000000000000000000000000000000000..7276626f5c870020ee5fda5168897dded0174dd8
--- /dev/null
+++ b/synthesizer_preprocess_embeds.py
@@ -0,0 +1,26 @@
+from synthesizer.preprocess import create_embeddings
+from utils.argutils import print_args
+from pathlib import Path
+import argparse
+
+
+if __name__ == "__main__":
+    print("This method is deprecated and will no longer be supported, please use 'pre.py'")
+    parser = argparse.ArgumentParser(
+        description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("synthesizer_root", type=Path, help=\
+        "Path to the synthesizer training data that contains the audios and the train.txt file. "
+        "If everything is left as default, it should be /SV2TTS/synthesizer/.")
+    parser.add_argument("-e", "--encoder_model_fpath", type=Path,
+                        default="encoder/saved_models/pretrained.pt", help=\
+        "Path to your trained encoder model.")
+    parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
+        "Number of parallel processes. An encoder is created for each, so you may need to lower "
+        "this value on GPUs with low memory. Set it to 1 if CUDA is unhappy.")
+    args = parser.parse_args()
+
+    # Preprocess the dataset
+    print_args(args, parser)
+    create_embeddings(**vars(args))
diff --git a/synthesizer_train.py b/synthesizer_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f0b5985dcc02393bda576f7b30d6ade4427fc29
--- /dev/null
+++ b/synthesizer_train.py
@@ -0,0 +1,37 @@
+from synthesizer.hparams import hparams
+from synthesizer.train import train
+from utils.argutils import print_args
+import argparse
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("run_id", type=str, help= \
+        "Name for this model instance. 
If a model state from the same run ID was previously " + "saved, the training will restart from there. Pass -f to overwrite saved states and " + "restart from scratch.") + parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \ + "Path to the synthesizer directory that contains the ground truth mel spectrograms, " + "the wavs and the embeds.") + parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\ + "Path to the output directory that will contain the saved model weights and the logs.") + parser.add_argument("-s", "--save_every", type=int, default=1000, help= \ + "Number of steps between updates of the model on the disk. Set to 0 to never save the " + "model.") + parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \ + "Number of steps between backups of the model. Set to 0 to never make backups of the " + "model.") + parser.add_argument("-l", "--log_every", type=int, default=200, help= \ + "Number of steps between summary the training info in tensorboard") + parser.add_argument("-f", "--force_restart", action="store_true", help= \ + "Do not load any saved model and restart from scratch.") + parser.add_argument("--hparams", default="", + help="Hyperparameter overrides as a comma-separated list of name=value " + "pairs") + args = parser.parse_args() + print_args(args, parser) + + args.hparams = hparams.parse(args.hparams) + + # Run the training + train(**vars(args)) diff --git a/toolbox/__init__.py b/toolbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b51164f3537a6b19cb2a00fb44b38855c4ba1c49 --- /dev/null +++ b/toolbox/__init__.py @@ -0,0 +1,476 @@ +from toolbox.ui import UI +from encoder import inference as encoder +from synthesizer.inference import Synthesizer +from vocoder.wavernn import inference as rnn_vocoder +from vocoder.hifigan import inference as gan_vocoder +from vocoder.fregan import inference as fgan_vocoder +from pathlib import Path +from time import perf_counter as timer +from toolbox.utterance import Utterance +import numpy as np +import traceback +import sys +import torch +import re + +# 默认使用wavernn +vocoder = rnn_vocoder + +# Use this directory structure for your datasets, or modify it to fit your needs +recognized_datasets = [ + "LibriSpeech/dev-clean", + "LibriSpeech/dev-other", + "LibriSpeech/test-clean", + "LibriSpeech/test-other", + "LibriSpeech/train-clean-100", + "LibriSpeech/train-clean-360", + "LibriSpeech/train-other-500", + "LibriTTS/dev-clean", + "LibriTTS/dev-other", + "LibriTTS/test-clean", + "LibriTTS/test-other", + "LibriTTS/train-clean-100", + "LibriTTS/train-clean-360", + "LibriTTS/train-other-500", + "LJSpeech-1.1", + "VoxCeleb1/wav", + "VoxCeleb1/test_wav", + "VoxCeleb2/dev/aac", + "VoxCeleb2/test/aac", + "VCTK-Corpus/wav48", + "aidatatang_200zh/corpus/dev", + "aidatatang_200zh/corpus/test", + "aishell3/test/wav", + "magicdata/train", +] + +#Maximum of generated wavs to keep on memory +MAX_WAVES = 15 + +class Toolbox: + def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode): + self.no_mp3_support = no_mp3_support + self.vc_mode = vc_mode + sys.excepthook = self.excepthook + self.datasets_root = datasets_root + self.utterances = set() + self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav + + self.synthesizer = None # type: Synthesizer + + # for ppg-based voice conversion + self.extractor = None + 
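+        # (added note) In voice-conversion mode, extractor holds the PPG extractor and convertor
+        # (below) the ppg2mel model; both start as None and are loaded lazily by init_extractor()
+        # and init_convertor() the first time convert() needs them.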
self.convertor = None # ppg2mel + + self.current_wav = None + self.waves_list = [] + self.waves_count = 0 + self.waves_namelist = [] + + # Check for webrtcvad (enables removal of silences in vocoder output) + try: + import webrtcvad + self.trim_silences = True + except: + self.trim_silences = False + + # Initialize the events and the interface + self.ui = UI(vc_mode) + self.style_idx = 0 + self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed) + self.setup_events() + self.ui.start() + + def excepthook(self, exc_type, exc_value, exc_tb): + traceback.print_exception(exc_type, exc_value, exc_tb) + self.ui.log("Exception: %s" % exc_value) + + def setup_events(self): + # Dataset, speaker and utterance selection + self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser()) + random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root, + recognized_datasets, + level) + self.ui.random_dataset_button.clicked.connect(random_func(0)) + self.ui.random_speaker_button.clicked.connect(random_func(1)) + self.ui.random_utterance_button.clicked.connect(random_func(2)) + self.ui.dataset_box.currentIndexChanged.connect(random_func(1)) + self.ui.speaker_box.currentIndexChanged.connect(random_func(2)) + + # Model selection + self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder) + def func(): + self.synthesizer = None + if self.vc_mode: + self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor) + else: + self.ui.synthesizer_box.currentIndexChanged.connect(func) + + self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder) + + # Utterance selection + func = lambda: self.load_from_browser(self.ui.browse_file()) + self.ui.browser_browse_button.clicked.connect(func) + func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current") + self.ui.utterance_history.currentIndexChanged.connect(func) + func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate) + self.ui.play_button.clicked.connect(func) + self.ui.stop_button.clicked.connect(self.ui.stop) + self.ui.record_button.clicked.connect(self.record) + + # Source Utterance selection + if self.vc_mode: + func = lambda: self.load_soruce_button(self.ui.selected_utterance) + self.ui.load_soruce_button.clicked.connect(func) + + #Audio + self.ui.setup_audio_devices(Synthesizer.sample_rate) + + #Wav playback & save + func = lambda: self.replay_last_wav() + self.ui.replay_wav_button.clicked.connect(func) + func = lambda: self.export_current_wave() + self.ui.export_wav_button.clicked.connect(func) + self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav) + + # Generation + self.ui.vocode_button.clicked.connect(self.vocode) + self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox) + + if self.vc_mode: + func = lambda: self.convert() or self.vocode() + self.ui.convert_button.clicked.connect(func) + else: + func = lambda: self.synthesize() or self.vocode() + self.ui.generate_button.clicked.connect(func) + self.ui.synthesize_button.clicked.connect(self.synthesize) + + # UMAP legend + self.ui.clear_button.clicked.connect(self.clear_utterances) + + def set_current_wav(self, index): + self.current_wav = self.waves_list[index] + + def export_current_wave(self): + self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate) + + def replay_last_wav(self): + self.ui.play(self.current_wav, Synthesizer.sample_rate) + + def reset_ui(self, encoder_models_dir, synthesizer_models_dir, 
vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed): + self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True) + self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode) + self.ui.populate_gen_options(seed, self.trim_silences) + + def load_from_browser(self, fpath=None): + if fpath is None: + fpath = Path(self.datasets_root, + self.ui.current_dataset_name, + self.ui.current_speaker_name, + self.ui.current_utterance_name) + name = str(fpath.relative_to(self.datasets_root)) + speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name + + # Select the next utterance + if self.ui.auto_next_checkbox.isChecked(): + self.ui.browser_select_next() + elif fpath == "": + return + else: + name = fpath.name + speaker_name = fpath.parent.name + + if fpath.suffix.lower() == ".mp3" and self.no_mp3_support: + self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used") + return + + # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for + # playback, so as to have a fair comparison with the generated audio + wav = Synthesizer.load_preprocess_wav(fpath) + self.ui.log("Loaded %s" % name) + + self.add_real_utterance(wav, name, speaker_name) + + def load_soruce_button(self, utterance: Utterance): + self.selected_source_utterance = utterance + + def record(self): + wav = self.ui.record_one(encoder.sampling_rate, 5) + if wav is None: + return + self.ui.play(wav, encoder.sampling_rate) + + speaker_name = "user01" + name = speaker_name + "_rec_%05d" % np.random.randint(100000) + self.add_real_utterance(wav, name, speaker_name) + + def add_real_utterance(self, wav, name, speaker_name): + # Compute the mel spectrogram + spec = Synthesizer.make_spectrogram(wav) + self.ui.draw_spec(spec, "current") + + # Compute the embedding + if not encoder.is_loaded(): + self.init_encoder() + encoder_wav = encoder.preprocess_wav(wav) + embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + # Add the utterance + utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False) + self.utterances.add(utterance) + self.ui.register_utterance(utterance, self.vc_mode) + + # Plot it + self.ui.draw_embed(embed, name, "current") + self.ui.draw_umap_projections(self.utterances) + + def clear_utterances(self): + self.utterances.clear() + self.ui.draw_umap_projections(self.utterances) + + def synthesize(self): + self.ui.log("Generating the mel spectrogram...") + self.ui.set_loading(1) + + # Update the synthesizer random seed + if self.ui.random_seed_checkbox.isChecked(): + seed = int(self.ui.seed_textbox.text()) + self.ui.populate_gen_options(seed, self.trim_silences) + else: + seed = None + + if seed is not None: + torch.manual_seed(seed) + + # Synthesize the spectrogram + if self.synthesizer is None or seed is not None: + self.init_synthesizer() + + texts = self.ui.text_prompt.toPlainText().split("\n") + punctuation = '!,。、,' # punctuate and split/clean text + processed_texts = [] + for text in texts: + for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'): + if processed_text: + processed_texts.append(processed_text.strip()) + texts = processed_texts + embed = self.ui.selected_utterance.embed + embeds = [embed] * len(texts) + min_token = int(self.ui.token_slider.value()) + specs = self.synthesizer.synthesize_spectrograms(texts, embeds, 
style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200) + breaks = [spec.shape[1] for spec in specs] + spec = np.concatenate(specs, axis=1) + + self.ui.draw_spec(spec, "generated") + self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None) + self.ui.set_loading(0) + + def vocode(self): + speaker_name, spec, breaks, _ = self.current_generated + assert spec is not None + + # Initialize the vocoder model and make it determinstic, if user provides a seed + if self.ui.random_seed_checkbox.isChecked(): + seed = int(self.ui.seed_textbox.text()) + self.ui.populate_gen_options(seed, self.trim_silences) + else: + seed = None + + if seed is not None: + torch.manual_seed(seed) + + # Synthesize the waveform + if not vocoder.is_loaded() or seed is not None: + self.init_vocoder() + + def vocoder_progress(i, seq_len, b_size, gen_rate): + real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000 + line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \ + % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor) + self.ui.log(line, "overwrite") + self.ui.set_loading(i, seq_len) + if self.ui.current_vocoder_fpath is not None: + self.ui.log("") + wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress) + else: + self.ui.log("Waveform generation with Griffin-Lim... ") + wav = Synthesizer.griffin_lim(spec) + self.ui.set_loading(0) + self.ui.log(" Done!", "append") + + # Add breaks + b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size) + b_starts = np.concatenate(([0], b_ends[:-1])) + wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)] + breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks) + wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)]) + + # Trim excessive silences + if self.ui.trim_silences_checkbox.isChecked(): + wav = encoder.preprocess_wav(wav) + + # Play it + wav = wav / np.abs(wav).max() * 0.97 + self.ui.play(wav, sample_rate) + + # Name it (history displayed in combobox) + # TODO better naming for the combobox items? 
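+        # (added note) The block below keeps a rolling history of generated waveforms, newest
+        # first at index 0 and capped at MAX_WAVES entries, so the "Toolbox Output" combobox can
+        # replay or export recent results.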
+ wav_name = str(self.waves_count + 1) + + #Update waves combobox + self.waves_count += 1 + if self.waves_count > MAX_WAVES: + self.waves_list.pop() + self.waves_namelist.pop() + self.waves_list.insert(0, wav) + self.waves_namelist.insert(0, wav_name) + + self.ui.waves_cb.disconnect() + self.ui.waves_cb_model.setStringList(self.waves_namelist) + self.ui.waves_cb.setCurrentIndex(0) + self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav) + + # Update current wav + self.set_current_wav(0) + + #Enable replay and save buttons: + self.ui.replay_wav_button.setDisabled(False) + self.ui.export_wav_button.setDisabled(False) + + # Compute the embedding + # TODO: this is problematic with different sampling rates, gotta fix it + if not encoder.is_loaded(): + self.init_encoder() + encoder_wav = encoder.preprocess_wav(wav) + embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + # Add the utterance + name = speaker_name + "_gen_%05d" % np.random.randint(100000) + utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True) + self.utterances.add(utterance) + + # Plot it + self.ui.draw_embed(embed, name, "generated") + self.ui.draw_umap_projections(self.utterances) + + def convert(self): + self.ui.log("Extract PPG and Converting...") + self.ui.set_loading(1) + + # Init + if self.convertor is None: + self.init_convertor() + if self.extractor is None: + self.init_extractor() + + src_wav = self.selected_source_utterance.wav + + # Compute the ppg + if not self.extractor is None: + ppg = self.extractor.extract_from_wav(src_wav) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ref_wav = self.ui.selected_utterance.wav + # Import necessary dependency of Voice Conversion + from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv + ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav))) + lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True) + min_len = min(ppg.shape[1], len(lf0_uv)) + ppg = ppg[:, :min_len] + lf0_uv = lf0_uv[:min_len] + _, mel_pred, att_ws = self.convertor.inference( + ppg, + logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), + spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device), + ) + mel_pred= mel_pred.transpose(0, 1) + breaks = [mel_pred.shape[1]] + mel_pred= mel_pred.detach().cpu().numpy() + self.ui.draw_spec(mel_pred, "generated") + self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None) + self.ui.set_loading(0) + + def init_extractor(self): + if self.ui.current_extractor_fpath is None: + return + model_fpath = self.ui.current_extractor_fpath + self.ui.log("Loading the extractor %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + import ppg_extractor as extractor + self.extractor = extractor.load_model(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_convertor(self): + if self.ui.current_convertor_fpath is None: + return + model_fpath = self.ui.current_convertor_fpath + self.ui.log("Loading the convertor %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + import ppg2mel as convertor + self.convertor = convertor.load_model( model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_encoder(self): + model_fpath = self.ui.current_encoder_fpath + + self.ui.log("Loading the encoder %s... 
" % model_fpath) + self.ui.set_loading(1) + start = timer() + encoder.load_model(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_synthesizer(self): + model_fpath = self.ui.current_synthesizer_fpath + + self.ui.log("Loading the synthesizer %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + self.synthesizer = Synthesizer(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_vocoder(self): + + global vocoder + model_fpath = self.ui.current_vocoder_fpath + # Case of Griffin-lim + if model_fpath is None: + return + # Sekect vocoder based on model name + model_config_fpath = None + if model_fpath.name is not None and model_fpath.name.find("hifigan") > -1: + vocoder = gan_vocoder + self.ui.log("set hifigan as vocoder") + # search a config file + model_config_fpaths = list(model_fpath.parent.rglob("*.json")) + if self.vc_mode and self.ui.current_extractor_fpath is None: + return + if len(model_config_fpaths) > 0: + model_config_fpath = model_config_fpaths[0] + elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1: + vocoder = fgan_vocoder + self.ui.log("set fregan as vocoder") + # search a config file + model_config_fpaths = list(model_fpath.parent.rglob("*.json")) + if self.vc_mode and self.ui.current_extractor_fpath is None: + return + if len(model_config_fpaths) > 0: + model_config_fpath = model_config_fpaths[0] + else: + vocoder = rnn_vocoder + self.ui.log("set wavernn as vocoder") + + self.ui.log("Loading the vocoder %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + vocoder.load_model(model_fpath, model_config_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def update_seed_textbox(self): + self.ui.update_seed_textbox() diff --git a/toolbox/assets/mb.png b/toolbox/assets/mb.png new file mode 100644 index 0000000000000000000000000000000000000000..abd804cab48147cdfafc4a385cf501322bca6e1c Binary files /dev/null and b/toolbox/assets/mb.png differ diff --git a/toolbox/ui.py b/toolbox/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..fe51e73bc1ea7d46c85ae8471604ffdc4ad05e80 --- /dev/null +++ b/toolbox/ui.py @@ -0,0 +1,699 @@ +from PyQt5.QtCore import Qt, QStringListModel +from PyQt5 import QtGui +from PyQt5.QtWidgets import * +import matplotlib.pyplot as plt +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from encoder.inference import plot_embedding_as_heatmap +from toolbox.utterance import Utterance +from pathlib import Path +from typing import List, Set +import sounddevice as sd +import soundfile as sf +import numpy as np +# from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP +from time import sleep +import umap +import sys +from warnings import filterwarnings, warn +filterwarnings("ignore") + + +colormap = np.array([ + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [97, 142, 151], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + [76, 255, 0], +], dtype=np.float) / 255 + +default_text = \ + "欢迎使用工具箱, 现已支持中文输入!" 
+ + + +class UI(QDialog): + min_umap_points = 4 + max_log_lines = 5 + max_saved_utterances = 20 + + def draw_utterance(self, utterance: Utterance, which): + self.draw_spec(utterance.spec, which) + self.draw_embed(utterance.embed, utterance.name, which) + + def draw_embed(self, embed, name, which): + embed_ax, _ = self.current_ax if which == "current" else self.gen_ax + embed_ax.figure.suptitle("" if embed is None else name) + + ## Embedding + # Clear the plot + if len(embed_ax.images) > 0: + embed_ax.images[0].colorbar.remove() + embed_ax.clear() + + # Draw the embed + if embed is not None: + plot_embedding_as_heatmap(embed, embed_ax) + embed_ax.set_title("embedding") + embed_ax.set_aspect("equal", "datalim") + embed_ax.set_xticks([]) + embed_ax.set_yticks([]) + embed_ax.figure.canvas.draw() + + def draw_spec(self, spec, which): + _, spec_ax = self.current_ax if which == "current" else self.gen_ax + + ## Spectrogram + # Draw the spectrogram + spec_ax.clear() + if spec is not None: + im = spec_ax.imshow(spec, aspect="auto", interpolation="none") + # spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal", + # spec_ax=spec_ax) + spec_ax.set_title("mel spectrogram") + + spec_ax.set_xticks([]) + spec_ax.set_yticks([]) + spec_ax.figure.canvas.draw() + if which != "current": + self.vocode_button.setDisabled(spec is None) + + def draw_umap_projections(self, utterances: Set[Utterance]): + self.umap_ax.clear() + + speakers = np.unique([u.speaker_name for u in utterances]) + colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)} + embeds = [u.embed for u in utterances] + + # Display a message if there aren't enough points + if len(utterances) < self.min_umap_points: + self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" % + (self.min_umap_points - len(utterances)), + horizontalalignment='center', fontsize=15) + self.umap_ax.set_title("") + + # Compute the projections + else: + if not self.umap_hot: + self.log( + "Drawing UMAP projections for the first time, this will take a few seconds.") + self.umap_hot = True + + reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine") + # reducer = TSNE() + projections = reducer.fit_transform(embeds) + + speakers_done = set() + for projection, utterance in zip(projections, utterances): + color = colors[utterance.speaker_name] + mark = "x" if "_gen_" in utterance.name else "o" + label = None if utterance.speaker_name in speakers_done else utterance.speaker_name + speakers_done.add(utterance.speaker_name) + self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark, + label=label) + # self.umap_ax.set_title("UMAP projections") + self.umap_ax.legend(prop={'size': 10}) + + # Draw the plot + self.umap_ax.set_aspect("equal", "datalim") + self.umap_ax.set_xticks([]) + self.umap_ax.set_yticks([]) + self.umap_ax.figure.canvas.draw() + + def save_audio_file(self, wav, sample_rate): + dialog = QFileDialog() + dialog.setDefaultSuffix(".wav") + fpath, _ = dialog.getSaveFileName( + parent=self, + caption="Select a path to save the audio file", + filter="Audio Files (*.flac *.wav)" + ) + if fpath: + #Default format is wav + if Path(fpath).suffix == "": + fpath += ".wav" + sf.write(fpath, wav, sample_rate) + + def setup_audio_devices(self, sample_rate): + input_devices = [] + output_devices = [] + for device in sd.query_devices(): + # Check if valid input + try: + sd.check_input_settings(device=device["name"], samplerate=sample_rate) + input_devices.append(device["name"]) + except: 
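+                # (added note) check_input_settings raises for devices that cannot record at this
+                # sample rate, so such devices are simply skipped as inputs.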
+ pass + + # Check if valid output + try: + sd.check_output_settings(device=device["name"], samplerate=sample_rate) + output_devices.append(device["name"]) + except Exception as e: + # Log a warning only if the device is not an input + if not device["name"] in input_devices: + warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e))) + + if len(input_devices) == 0: + self.log("No audio input device detected. Recording may not work.") + self.audio_in_device = None + else: + self.audio_in_device = input_devices[0] + + if len(output_devices) == 0: + self.log("No supported output audio devices were found! Audio output may not work.") + self.audio_out_devices_cb.addItems(["None"]) + self.audio_out_devices_cb.setDisabled(True) + else: + self.audio_out_devices_cb.clear() + self.audio_out_devices_cb.addItems(output_devices) + self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device) + + self.set_audio_device() + + def set_audio_device(self): + + output_device = self.audio_out_devices_cb.currentText() + if output_device == "None": + output_device = None + + # If None, sounddevice queries portaudio + sd.default.device = (self.audio_in_device, output_device) + + def play(self, wav, sample_rate): + try: + sd.stop() + sd.play(wav, sample_rate) + except Exception as e: + print(e) + self.log("Error in audio playback. Try selecting a different audio output device.") + self.log("Your device must be connected before you start the toolbox.") + + def stop(self): + sd.stop() + + def record_one(self, sample_rate, duration): + self.record_button.setText("Recording...") + self.record_button.setDisabled(True) + + self.log("Recording %d seconds of audio" % duration) + sd.stop() + try: + wav = sd.rec(duration * sample_rate, sample_rate, 1) + except Exception as e: + print(e) + self.log("Could not record anything. Is your recording device enabled?") + self.log("Your device must be connected before you start the toolbox.") + return None + + for i in np.arange(0, duration, 0.1): + self.set_loading(i, duration) + sleep(0.1) + self.set_loading(duration, duration) + sd.wait() + + self.log("Done recording.") + self.record_button.setText("Record") + self.record_button.setDisabled(False) + + return wav.squeeze() + + @property + def current_dataset_name(self): + return self.dataset_box.currentText() + + @property + def current_speaker_name(self): + return self.speaker_box.currentText() + + @property + def current_utterance_name(self): + return self.utterance_box.currentText() + + def browse_file(self): + fpath = QFileDialog().getOpenFileName( + parent=self, + caption="Select an audio file", + filter="Audio Files (*.mp3 *.flac *.wav *.m4a)" + ) + return Path(fpath[0]) if fpath[0] != "" else "" + + @staticmethod + def repopulate_box(box, items, random=False): + """ + Resets a box and adds a list of items. 
Pass a list of (item, data) pairs instead to join + data to the items + """ + box.blockSignals(True) + box.clear() + for item in items: + item = list(item) if isinstance(item, tuple) else [item] + box.addItem(str(item[0]), *item[1:]) + if len(items) > 0: + box.setCurrentIndex(np.random.randint(len(items)) if random else 0) + box.setDisabled(len(items) == 0) + box.blockSignals(False) + + def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int, + random=True): + # Select a random dataset + if level <= 0: + if datasets_root is not None: + datasets = [datasets_root.joinpath(d) for d in recognized_datasets] + datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()] + self.browser_load_button.setDisabled(len(datasets) == 0) + if datasets_root is None or len(datasets) == 0: + msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \ + if datasets_root is None else "o not have any of the recognized datasets" \ + " in %s" % datasets_root) + self.log(msg) + msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \ + "can still use the toolbox by recording samples yourself." % \ + ("\n\t".join(recognized_datasets)) + print(msg, file=sys.stderr) + + self.random_utterance_button.setDisabled(True) + self.random_speaker_button.setDisabled(True) + self.random_dataset_button.setDisabled(True) + self.utterance_box.setDisabled(True) + self.speaker_box.setDisabled(True) + self.dataset_box.setDisabled(True) + self.browser_load_button.setDisabled(True) + self.auto_next_checkbox.setDisabled(True) + return + self.repopulate_box(self.dataset_box, datasets, random) + + # Select a random speaker + if level <= 1: + speakers_root = datasets_root.joinpath(self.current_dataset_name) + speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()] + self.repopulate_box(self.speaker_box, speaker_names, random) + + # Select a random utterance + if level <= 2: + utterances_root = datasets_root.joinpath( + self.current_dataset_name, + self.current_speaker_name + ) + utterances = [] + for extension in ['mp3', 'flac', 'wav', 'm4a']: + utterances.extend(Path(utterances_root).glob("**/*.%s" % extension)) + utterances = [fpath.relative_to(utterances_root) for fpath in utterances] + self.repopulate_box(self.utterance_box, utterances, random) + + def browser_select_next(self): + index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box) + self.utterance_box.setCurrentIndex(index) + + @property + def current_encoder_fpath(self): + return self.encoder_box.itemData(self.encoder_box.currentIndex()) + + @property + def current_synthesizer_fpath(self): + return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex()) + + @property + def current_vocoder_fpath(self): + return self.vocoder_box.itemData(self.vocoder_box.currentIndex()) + + @property + def current_extractor_fpath(self): + return self.extractor_box.itemData(self.extractor_box.currentIndex()) + + @property + def current_convertor_fpath(self): + return self.convertor_box.itemData(self.convertor_box.currentIndex()) + + def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path, + vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool): + # Encoder + encoder_fpaths = list(encoder_models_dir.glob("*.pt")) + if len(encoder_fpaths) == 0: + raise Exception("No encoder models found in %s" % encoder_models_dir) + self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths]) + + if 
vc_mode: + # Extractor + extractor_fpaths = list(extractor_models_dir.glob("*.pt")) + if len(extractor_fpaths) == 0: + self.log("No extractor models found in %s" % extractor_fpaths) + self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths]) + + # Convertor + convertor_fpaths = list(convertor_models_dir.glob("*.pth")) + if len(convertor_fpaths) == 0: + self.log("No convertor models found in %s" % convertor_fpaths) + self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths]) + else: + # Synthesizer + synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt")) + if len(synthesizer_fpaths) == 0: + raise Exception("No synthesizer models found in %s" % synthesizer_models_dir) + self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths]) + + # Vocoder + vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt")) + vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)] + self.repopulate_box(self.vocoder_box, vocoder_items) + + @property + def selected_utterance(self): + return self.utterance_history.itemData(self.utterance_history.currentIndex()) + + def register_utterance(self, utterance: Utterance, vc_mode): + self.utterance_history.blockSignals(True) + self.utterance_history.insertItem(0, utterance.name, utterance) + self.utterance_history.setCurrentIndex(0) + self.utterance_history.blockSignals(False) + + if len(self.utterance_history) > self.max_saved_utterances: + self.utterance_history.removeItem(self.max_saved_utterances) + + self.play_button.setDisabled(False) + if vc_mode: + self.convert_button.setDisabled(False) + else: + self.generate_button.setDisabled(False) + self.synthesize_button.setDisabled(False) + + def log(self, line, mode="newline"): + if mode == "newline": + self.logs.append(line) + if len(self.logs) > self.max_log_lines: + del self.logs[0] + elif mode == "append": + self.logs[-1] += line + elif mode == "overwrite": + self.logs[-1] = line + log_text = '\n'.join(self.logs) + + self.log_window.setText(log_text) + self.app.processEvents() + + def set_loading(self, value, maximum=1): + self.loading_bar.setValue(value * 100) + self.loading_bar.setMaximum(maximum * 100) + self.loading_bar.setTextVisible(value != 0) + self.app.processEvents() + + def populate_gen_options(self, seed, trim_silences): + if seed is not None: + self.random_seed_checkbox.setChecked(True) + self.seed_textbox.setText(str(seed)) + self.seed_textbox.setEnabled(True) + else: + self.random_seed_checkbox.setChecked(False) + self.seed_textbox.setText(str(0)) + self.seed_textbox.setEnabled(False) + + if not trim_silences: + self.trim_silences_checkbox.setChecked(False) + self.trim_silences_checkbox.setDisabled(True) + + def update_seed_textbox(self): + if self.random_seed_checkbox.isChecked(): + self.seed_textbox.setEnabled(True) + else: + self.seed_textbox.setEnabled(False) + + def reset_interface(self, vc_mode): + self.draw_embed(None, None, "current") + self.draw_embed(None, None, "generated") + self.draw_spec(None, "current") + self.draw_spec(None, "generated") + self.draw_umap_projections(set()) + self.set_loading(0) + self.play_button.setDisabled(True) + if vc_mode: + self.convert_button.setDisabled(True) + else: + self.generate_button.setDisabled(True) + self.synthesize_button.setDisabled(True) + self.vocode_button.setDisabled(True) + self.replay_wav_button.setDisabled(True) + self.export_wav_button.setDisabled(True) + [self.log("") for _ in range(self.max_log_lines)] + + def __init__(self, vc_mode): + ## 
Initialize the application + self.app = QApplication(sys.argv) + super().__init__(None) + self.setWindowTitle("MockingBird GUI") + self.setWindowIcon(QtGui.QIcon('toolbox\\assets\\mb.png')) + self.setWindowFlag(Qt.WindowMinimizeButtonHint, True) + self.setWindowFlag(Qt.WindowMaximizeButtonHint, True) + + + ## Main layouts + # Root + root_layout = QGridLayout() + self.setLayout(root_layout) + + # Browser + browser_layout = QGridLayout() + root_layout.addLayout(browser_layout, 0, 0, 1, 8) + + # Generation + gen_layout = QVBoxLayout() + root_layout.addLayout(gen_layout, 0, 8) + + # Visualizations + vis_layout = QVBoxLayout() + root_layout.addLayout(vis_layout, 1, 0, 2, 8) + + # Output + output_layout = QGridLayout() + vis_layout.addLayout(output_layout, 0) + + # Projections + self.projections_layout = QVBoxLayout() + root_layout.addLayout(self.projections_layout, 1, 8, 2, 2) + + ## Projections + # UMap + fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0") + fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98) + self.projections_layout.addWidget(FigureCanvas(fig)) + self.umap_hot = False + self.clear_button = QPushButton("Clear") + self.projections_layout.addWidget(self.clear_button) + + + ## Browser + # Dataset, speaker and utterance selection + i = 0 + + source_groupbox = QGroupBox('Source(源音频)') + source_layout = QGridLayout() + source_groupbox.setLayout(source_layout) + browser_layout.addWidget(source_groupbox, i, 0, 1, 5) + + self.dataset_box = QComboBox() + source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0) + source_layout.addWidget(self.dataset_box, i, 1) + self.random_dataset_button = QPushButton("Random") + source_layout.addWidget(self.random_dataset_button, i, 2) + i += 1 + self.speaker_box = QComboBox() + source_layout.addWidget(QLabel("Speaker(说话者)"), i, 0) + source_layout.addWidget(self.speaker_box, i, 1) + self.random_speaker_button = QPushButton("Random") + source_layout.addWidget(self.random_speaker_button, i, 2) + i += 1 + self.utterance_box = QComboBox() + source_layout.addWidget(QLabel("Utterance(音频):"), i, 0) + source_layout.addWidget(self.utterance_box, i, 1) + self.random_utterance_button = QPushButton("Random") + source_layout.addWidget(self.random_utterance_button, i, 2) + + i += 1 + source_layout.addWidget(QLabel("Use(使用):"), i, 0) + self.browser_load_button = QPushButton("Load Above(加载上面)") + source_layout.addWidget(self.browser_load_button, i, 1, 1, 2) + self.auto_next_checkbox = QCheckBox("Auto select next") + self.auto_next_checkbox.setChecked(True) + source_layout.addWidget(self.auto_next_checkbox, i+1, 1) + self.browser_browse_button = QPushButton("Browse(打开本地)") + source_layout.addWidget(self.browser_browse_button, i, 3) + self.record_button = QPushButton("Record(录音)") + source_layout.addWidget(self.record_button, i+1, 3) + + i += 2 + # Utterance box + browser_layout.addWidget(QLabel("Current(当前):"), i, 0) + self.utterance_history = QComboBox() + browser_layout.addWidget(self.utterance_history, i, 1) + self.play_button = QPushButton("Play(播放)") + browser_layout.addWidget(self.play_button, i, 2) + self.stop_button = QPushButton("Stop(暂停)") + browser_layout.addWidget(self.stop_button, i, 3) + if vc_mode: + self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)") + browser_layout.addWidget(self.load_soruce_button, i, 4) + + i += 1 + model_groupbox = QGroupBox('Models(模型选择)') + model_layout = QHBoxLayout() + model_groupbox.setLayout(model_layout) + browser_layout.addWidget(model_groupbox, i, 0, 2, 5) + + # Model and audio 
output selection + self.encoder_box = QComboBox() + model_layout.addWidget(QLabel("Encoder:")) + model_layout.addWidget(self.encoder_box) + self.synthesizer_box = QComboBox() + if vc_mode: + self.extractor_box = QComboBox() + model_layout.addWidget(QLabel("Extractor:")) + model_layout.addWidget(self.extractor_box) + self.convertor_box = QComboBox() + model_layout.addWidget(QLabel("Convertor:")) + model_layout.addWidget(self.convertor_box) + else: + model_layout.addWidget(QLabel("Synthesizer:")) + model_layout.addWidget(self.synthesizer_box) + self.vocoder_box = QComboBox() + model_layout.addWidget(QLabel("Vocoder:")) + model_layout.addWidget(self.vocoder_box) + + #Replay & Save Audio + i = 0 + output_layout.addWidget(QLabel("Toolbox Output:"), i, 0) + self.waves_cb = QComboBox() + self.waves_cb_model = QStringListModel() + self.waves_cb.setModel(self.waves_cb_model) + self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting") + output_layout.addWidget(self.waves_cb, i, 1) + self.replay_wav_button = QPushButton("Replay") + self.replay_wav_button.setToolTip("Replay last generated vocoder") + output_layout.addWidget(self.replay_wav_button, i, 2) + self.export_wav_button = QPushButton("Export") + self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file") + output_layout.addWidget(self.export_wav_button, i, 3) + self.audio_out_devices_cb=QComboBox() + i += 1 + output_layout.addWidget(QLabel("Audio Output"), i, 0) + output_layout.addWidget(self.audio_out_devices_cb, i, 1) + + ## Embed & spectrograms + vis_layout.addStretch() + # TODO: add spectrograms for source + gridspec_kw = {"width_ratios": [1, 4]} + fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0", + gridspec_kw=gridspec_kw) + fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8) + vis_layout.addWidget(FigureCanvas(fig)) + + fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0", + gridspec_kw=gridspec_kw) + fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8) + vis_layout.addWidget(FigureCanvas(fig)) + + for ax in self.current_ax.tolist() + self.gen_ax.tolist(): + ax.set_facecolor("#F0F0F0") + for side in ["top", "right", "bottom", "left"]: + ax.spines[side].set_visible(False) + + ## Generation + self.text_prompt = QPlainTextEdit(default_text) + gen_layout.addWidget(self.text_prompt, stretch=1) + + if vc_mode: + layout = QHBoxLayout() + self.convert_button = QPushButton("Extract and Convert") + layout.addWidget(self.convert_button) + gen_layout.addLayout(layout) + else: + self.generate_button = QPushButton("Synthesize and vocode") + gen_layout.addWidget(self.generate_button) + layout = QHBoxLayout() + self.synthesize_button = QPushButton("Synthesize only") + layout.addWidget(self.synthesize_button) + + self.vocode_button = QPushButton("Vocode only") + layout.addWidget(self.vocode_button) + gen_layout.addLayout(layout) + + + layout_seed = QGridLayout() + self.random_seed_checkbox = QCheckBox("Random seed:") + self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.") + layout_seed.addWidget(self.random_seed_checkbox, 0, 0) + self.seed_textbox = QLineEdit() + self.seed_textbox.setMaximumWidth(80) + layout_seed.addWidget(self.seed_textbox, 0, 1) + self.trim_silences_checkbox = QCheckBox("Enhance vocoder output") + self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output." 
+ " This feature requires `webrtcvad` to be installed.") + layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2) + self.style_slider = QSlider(Qt.Horizontal) + self.style_slider.setTickInterval(1) + self.style_slider.setFocusPolicy(Qt.NoFocus) + self.style_slider.setSingleStep(1) + self.style_slider.setRange(-1, 9) + self.style_value_label = QLabel("-1") + self.style_slider.setValue(-1) + layout_seed.addWidget(QLabel("Style:"), 1, 0) + + self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s)) + layout_seed.addWidget(self.style_value_label, 1, 1) + layout_seed.addWidget(self.style_slider, 1, 3) + + self.token_slider = QSlider(Qt.Horizontal) + self.token_slider.setTickInterval(1) + self.token_slider.setFocusPolicy(Qt.NoFocus) + self.token_slider.setSingleStep(1) + self.token_slider.setRange(3, 9) + self.token_value_label = QLabel("5") + self.token_slider.setValue(4) + layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0) + + self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s)) + layout_seed.addWidget(self.token_value_label, 2, 1) + layout_seed.addWidget(self.token_slider, 2, 3) + + self.length_slider = QSlider(Qt.Horizontal) + self.length_slider.setTickInterval(1) + self.length_slider.setFocusPolicy(Qt.NoFocus) + self.length_slider.setSingleStep(1) + self.length_slider.setRange(1, 10) + self.length_value_label = QLabel("2") + self.length_slider.setValue(2) + layout_seed.addWidget(QLabel("MaxLength(最大句长):"), 3, 0) + + self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s)) + layout_seed.addWidget(self.length_value_label, 3, 1) + layout_seed.addWidget(self.length_slider, 3, 3) + + gen_layout.addLayout(layout_seed) + + self.loading_bar = QProgressBar() + gen_layout.addWidget(self.loading_bar) + + self.log_window = QLabel() + self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft) + gen_layout.addWidget(self.log_window) + self.logs = [] + gen_layout.addStretch() + + + ## Set the size of the window and of the elements + max_size = QDesktopWidget().availableGeometry(self).size() * 0.5 + self.resize(max_size) + + ## Finalize the display + self.reset_interface(vc_mode) + self.show() + + def start(self): + self.app.exec_() diff --git a/toolbox/utterance.py b/toolbox/utterance.py new file mode 100644 index 0000000000000000000000000000000000000000..844c8a2adb0c8eba2992eaf5ea357d7add3c1896 --- /dev/null +++ b/toolbox/utterance.py @@ -0,0 +1,5 @@ +from collections import namedtuple + +Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth") +Utterance.__eq__ = lambda x, y: x.name == y.name +Utterance.__hash__ = lambda x: hash(x.name) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6a06c805109159ff40cad69668f1fc38cf1e9b --- /dev/null +++ b/train.py @@ -0,0 +1,67 @@ +import sys +import torch +import argparse +import numpy as np +from utils.load_yaml import HpsYaml +from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + +# For reproducibility, comment these may speed up training +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +def main(): + # Arguments + parser = argparse.ArgumentParser(description= + 'Training PPG2Mel VC model.') + parser.add_argument('--config', type=str, + help='Path to experiment config, e.g., config/vc.yaml') + parser.add_argument('--name', default=None, type=str, help='Name for logging.') + parser.add_argument('--logdir', default='log/', 
type=str, + help='Logging path.', required=False) + parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str, + help='Checkpoint path.', required=False) + parser.add_argument('--outdir', default='result/', type=str, + help='Decode output path.', required=False) + parser.add_argument('--load', default=None, type=str, + help='Load pre-trained model (for training only)', required=False) + parser.add_argument('--warm_start', action='store_true', + help='Load model weights only, ignore specified layers.') + parser.add_argument('--seed', default=0, type=int, + help='Random seed for reproducable results.', required=False) + parser.add_argument('--njobs', default=8, type=int, + help='Number of threads for dataloader/decoding.', required=False) + parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') + parser.add_argument('--no-pin', action='store_true', + help='Disable pin-memory for dataloader') + parser.add_argument('--test', action='store_true', help='Test the model.') + parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') + parser.add_argument('--finetune', action='store_true', help='Finetune model') + parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') + parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') + parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') + + ### + + paras = parser.parse_args() + setattr(paras, 'gpu', not paras.cpu) + setattr(paras, 'pin_memory', not paras.no_pin) + setattr(paras, 'verbose', not paras.no_msg) + # Make the config dict dot visitable + config = HpsYaml(paras.config) + + np.random.seed(paras.seed) + torch.manual_seed(paras.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(paras.seed) + + print(">>> OneShot VC training ...") + mode = "train" + solver = Solver(config, paras, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/argutils.py b/utils/argutils.py new file mode 100644 index 0000000000000000000000000000000000000000..db41683027173517c910e3b259f8da48207dcb38 --- /dev/null +++ b/utils/argutils.py @@ -0,0 +1,40 @@ +from pathlib import Path +import numpy as np +import argparse + +_type_priorities = [ # In decreasing order + Path, + str, + int, + float, + bool, +] + +def _priority(o): + p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None) + if p is not None: + return p + p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None) + if p is not None: + return p + return len(_type_priorities) + +def print_args(args: argparse.Namespace, parser=None): + args = vars(args) + if parser is None: + priorities = list(map(_priority, args.values())) + else: + all_params = [a.dest for g in parser._action_groups for a in g._group_actions ] + priority = lambda p: all_params.index(p) if p in all_params else len(all_params) + priorities = list(map(priority, args.keys())) + + pad = max(map(len, args.keys())) + 3 + indices = np.lexsort((list(args.keys()), priorities)) + items = list(args.items()) + + print("Arguments:") + for i in indices: + param, value = items[i] + print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value)) + print("") + \ No newline at end of 
file diff --git a/utils/audio_utils.py b/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbeddbc65d2048fd90b348db6ff15a420a70f2b --- /dev/null +++ b/utils/audio_utils.py @@ -0,0 +1,60 @@ + +import torch +import torch.utils.data +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + +def _dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _spectral_normalize_torch(magnitudes): + output = _dynamic_range_compression_torch(magnitudes) + return output + +mel_basis = {} +hann_window = {} + +def mel_spectrogram( + y, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False, + output_energy=False, +): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + mel_spec = _spectral_normalize_torch(mel_spec) + if output_energy: + energy = torch.norm(spec, dim=1) + return mel_spec, energy + else: + return mel_spec diff --git a/utils/data_load.py b/utils/data_load.py new file mode 100644 index 0000000000000000000000000000000000000000..37723cff7de26a4e0b85368170531970498917fa --- /dev/null +++ b/utils/data_load.py @@ -0,0 +1,214 @@ +import random +import numpy as np +import torch +from utils.f0_utils import get_cont_lf0 +import resampy +from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram +from librosa.util import normalize +import os + + +SAMPLE_RATE=16000 + +def read_fids(fid_list_f): + with open(fid_list_f, 'r') as f: + fids = [l.strip().split()[0] for l in f if l.strip()] + return fids + +class OneshotVcDataset(torch.utils.data.Dataset): + def __init__( + self, + meta_file: str, + vctk_ppg_dir: str, + libri_ppg_dir: str, + vctk_f0_dir: str, + libri_f0_dir: str, + vctk_wav_dir: str, + libri_wav_dir: str, + vctk_spk_dvec_dir: str, + libri_spk_dvec_dir: str, + min_max_norm_mel: bool = False, + mel_min: float = None, + mel_max: float = None, + ppg_file_ext: str = "ling_feat.npy", + f0_file_ext: str = "f0.npy", + wav_file_ext: str = "wav", + ): + self.fid_list = read_fids(meta_file) + self.vctk_ppg_dir = vctk_ppg_dir + self.libri_ppg_dir = libri_ppg_dir + self.vctk_f0_dir = vctk_f0_dir + self.libri_f0_dir = libri_f0_dir + self.vctk_wav_dir = vctk_wav_dir + self.libri_wav_dir = libri_wav_dir + self.vctk_spk_dvec_dir = vctk_spk_dvec_dir + self.libri_spk_dvec_dir = libri_spk_dvec_dir + + self.ppg_file_ext = ppg_file_ext + self.f0_file_ext = f0_file_ext + self.wav_file_ext = wav_file_ext + + self.min_max_norm_mel = min_max_norm_mel + if min_max_norm_mel: + print("[INFO] Min-Max normalize Melspec.") 
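+ # NOTE: mel_min / mel_max are the per-frequency-bin statistics that bin_level_min_max_norm() below uses to rescale each mel frame into the [-4, 4] range.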
+ assert mel_min is not None + assert mel_max is not None + self.mel_max = mel_max + self.mel_min = mel_min + + random.seed(1234) + random.shuffle(self.fid_list) + print(f'[INFO] Got {len(self.fid_list)} samples.') + + def __len__(self): + return len(self.fid_list) + + def get_spk_dvec(self, fid): + spk_name = fid + if spk_name.startswith("p"): + spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy" + else: + spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy" + return torch.from_numpy(np.load(spk_dvec_path)) + + def compute_mel(self, wav_path): + audio, sr = load_wav(wav_path) + if sr != SAMPLE_RATE: + audio = resampy.resample(audio, sr, SAMPLE_RATE) + audio = audio / MAX_WAV_VALUE + audio = normalize(audio) * 0.95 + audio = torch.FloatTensor(audio).unsqueeze(0) + melspec = mel_spectrogram( + audio, + n_fft=1024, + num_mels=80, + sampling_rate=SAMPLE_RATE, + hop_size=160, + win_size=1024, + fmin=80, + fmax=8000, + ) + return melspec.squeeze(0).numpy().T + + def bin_level_min_max_norm(self, melspec): + # frequency bin level min-max normalization to [-4, 4] + mel = (melspec - self.mel_min) / (self.mel_max - self.mel_min) * 8.0 - 4.0 + return np.clip(mel, -4., 4.) + + def __getitem__(self, index): + fid = self.fid_list[index] + + # 1. Load features + if fid.startswith("p"): + # vctk + sub = fid.split("_")[0] + ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}") + f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}") + mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}") + else: + # aidatatang + sub = fid[5:10] + ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}") + f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}") + mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}") + if self.min_max_norm_mel: + mel = self.bin_level_min_max_norm(mel) + + f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid) + spk_dvec = self.get_spk_dvec(fid) + + # 2. Convert f0 to continuous log-f0 and u/v flags + uv, cont_lf0 = get_cont_lf0(f0, 10.0, False) + # cont_lf0 = (cont_lf0 - np.amin(cont_lf0)) / (np.amax(cont_lf0) - np.amin(cont_lf0)) + # cont_lf0 = self.utt_mvn(cont_lf0) + lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1) + + # uv, cont_f0 = convert_continuous_f0(f0) + # cont_f0 = (cont_f0 - np.amin(cont_f0)) / (np.amax(cont_f0) - np.amin(cont_f0)) + # lf0_uv = np.concatenate([cont_f0[:, np.newaxis], uv[:, np.newaxis]], axis=1) + + # 3. 
Convert numpy array to torch.tensor + ppg = torch.from_numpy(ppg) + lf0_uv = torch.from_numpy(lf0_uv) + mel = torch.from_numpy(mel) + + return (ppg, lf0_uv, mel, spk_dvec, fid) + + def check_lengths(self, f0, ppg, mel, fid): + LEN_THRESH = 10 + assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \ + f"{abs(len(ppg) - len(f0))}: for file {fid}" + assert abs(len(mel) - len(f0)) <= LEN_THRESH, \ + f"{abs(len(mel) - len(f0))}: for file {fid}" + + def _adjust_lengths(self, f0, ppg, mel, fid): + self.check_lengths(f0, ppg, mel, fid) + min_len = min( + len(f0), + len(ppg), + len(mel), + ) + f0 = f0[:min_len] + ppg = ppg[:min_len] + mel = mel[:min_len] + return f0, ppg, mel + +class MultiSpkVcCollate(): + """Zero-pads model inputs and targets based on number of frames per step + """ + def __init__(self, n_frames_per_step=1, give_uttids=False, + f02ppg_length_ratio=1, use_spk_dvec=False): + self.n_frames_per_step = n_frames_per_step + self.give_uttids = give_uttids + self.f02ppg_length_ratio = f02ppg_length_ratio + self.use_spk_dvec = use_spk_dvec + + def __call__(self, batch): + batch_size = len(batch) + # Prepare different features + ppgs = [x[0] for x in batch] + lf0_uvs = [x[1] for x in batch] + mels = [x[2] for x in batch] + fids = [x[-1] for x in batch] + if len(batch[0]) == 5: + spk_ids = [x[3] for x in batch] + if self.use_spk_dvec: + # use d-vector + spk_ids = torch.stack(spk_ids).float() + else: + # use one-hot ids + spk_ids = torch.LongTensor(spk_ids) + # Pad features into chunk + ppg_lengths = [x.shape[0] for x in ppgs] + mel_lengths = [x.shape[0] for x in mels] + max_ppg_len = max(ppg_lengths) + max_mel_len = max(mel_lengths) + if max_mel_len % self.n_frames_per_step != 0: + max_mel_len += (self.n_frames_per_step - max_mel_len % self.n_frames_per_step) + ppg_dim = ppgs[0].shape[1] + mel_dim = mels[0].shape[1] + ppgs_padded = torch.FloatTensor(batch_size, max_ppg_len, ppg_dim).zero_() + mels_padded = torch.FloatTensor(batch_size, max_mel_len, mel_dim).zero_() + lf0_uvs_padded = torch.FloatTensor(batch_size, self.f02ppg_length_ratio * max_ppg_len, 2).zero_() + stop_tokens = torch.FloatTensor(batch_size, max_mel_len).zero_() + for i in range(batch_size): + cur_ppg_len = ppgs[i].shape[0] + cur_mel_len = mels[i].shape[0] + ppgs_padded[i, :cur_ppg_len, :] = ppgs[i] + lf0_uvs_padded[i, :self.f02ppg_length_ratio*cur_ppg_len, :] = lf0_uvs[i] + mels_padded[i, :cur_mel_len, :] = mels[i] + stop_tokens[i, cur_ppg_len-self.n_frames_per_step:] = 1 + if len(batch[0]) == 5: + ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \ + torch.LongTensor(mel_lengths), spk_ids, stop_tokens) + if self.give_uttids: + return ret_tup + (fids, ) + else: + return ret_tup + else: + ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \ + torch.LongTensor(mel_lengths), stop_tokens) + if self.give_uttids: + return ret_tup + (fids, ) + else: + return ret_tup diff --git a/utils/f0_utils.py b/utils/f0_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc25a882e866a05cfb9afc86397f6c82561a498 --- /dev/null +++ b/utils/f0_utils.py @@ -0,0 +1,124 @@ +import logging +import numpy as np +import pyworld +from scipy.interpolate import interp1d +from scipy.signal import firwin, get_window, lfilter + +def compute_mean_std(lf0): + nonzero_indices = np.nonzero(lf0) + mean = np.mean(lf0[nonzero_indices]) + std = np.std(lf0[nonzero_indices]) + return mean, std + + +def compute_f0(wav, sr=16000, frame_period=10.0): + """Compute f0 from wav using pyworld 
harvest algorithm.""" + wav = wav.astype(np.float64) + f0, _ = pyworld.harvest( + wav, sr, frame_period=frame_period, f0_floor=80.0, f0_ceil=600.0) + return f0.astype(np.float32) + +def f02lf0(f0): + lf0 = f0.copy() + nonzero_indices = np.nonzero(f0) + lf0[nonzero_indices] = np.log(f0[nonzero_indices]) + return lf0 + +def get_converted_lf0uv( + wav, + lf0_mean_trg, + lf0_std_trg, + convert=True, +): + f0_src = compute_f0(wav) + if not convert: + uv, cont_lf0 = get_cont_lf0(f0_src) + lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1) + return lf0_uv + + lf0_src = f02lf0(f0_src) + lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src) + + lf0_vc = lf0_src.copy() + lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg + f0_vc = lf0_vc.copy() + f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0]) + + uv, cont_lf0_vc = get_cont_lf0(f0_vc) + lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1) + return lf0_uv + +def low_pass_filter(x, fs, cutoff=70, padding=True): + """FUNCTION TO APPLY LOW PASS FILTER + + Args: + x (ndarray): Waveform sequence + fs (int): Sampling frequency + cutoff (float): Cutoff frequency of low pass filter + + Return: + (ndarray): Low pass filtered waveform sequence + """ + + nyquist = fs // 2 + norm_cutoff = cutoff / nyquist + + # low cut filter + numtaps = 255 + fil = firwin(numtaps, norm_cutoff) + x_pad = np.pad(x, (numtaps, numtaps), 'edge') + lpf_x = lfilter(fil, 1, x_pad) + lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2] + + return lpf_x + + +def convert_continuos_f0(f0): + """CONVERT F0 TO CONTINUOUS F0 + + Args: + f0 (ndarray): original f0 sequence with the shape (T) + + Return: + (ndarray): continuous f0 with the shape (T) + """ + # get uv information as binary + uv = np.float32(f0 != 0) + + # get start and end of f0 + if (f0 == 0).all(): + logging.warn("all of the f0 values are 0.") + return uv, f0 + start_f0 = f0[f0 != 0][0] + end_f0 = f0[f0 != 0][-1] + + # padding start and end of f0 sequence + start_idx = np.where(f0 == start_f0)[0][0] + end_idx = np.where(f0 == end_f0)[0][-1] + f0[:start_idx] = start_f0 + f0[end_idx:] = end_f0 + + # get non-zero frame index + nz_frames = np.where(f0 != 0)[0] + + # perform linear interpolation + f = interp1d(nz_frames, f0[nz_frames]) + cont_f0 = f(np.arange(0, f0.shape[0])) + + return uv, cont_f0 + + +def get_cont_lf0(f0, frame_period=10.0, lpf=False): + uv, cont_f0 = convert_continuos_f0(f0) + if lpf: + cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20) + cont_lf0_lpf = cont_f0_lpf.copy() + nonzero_indices = np.nonzero(cont_lf0_lpf) + cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices]) + # cont_lf0_lpf = np.log(cont_f0_lpf) + return uv, cont_lf0_lpf + else: + nonzero_indices = np.nonzero(cont_f0) + cont_lf0 = cont_f0.copy() + cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0]) + return uv, cont_lf0 diff --git a/utils/load_yaml.py b/utils/load_yaml.py new file mode 100644 index 0000000000000000000000000000000000000000..5792ff471dc63bacc8c27a7bcc2d4bd6f1e35da8 --- /dev/null +++ b/utils/load_yaml.py @@ -0,0 +1,58 @@ +import yaml + + +def load_hparams(filename): + stream = open(filename, 'r') + docs = yaml.safe_load_all(stream) + hparams_dict = dict() + for doc in docs: + for k, v in doc.items(): + hparams_dict[k] = v + return hparams_dict + +def merge_dict(user, default): + if isinstance(user, dict) and isinstance(default, dict): + for k, v in default.items(): + if k not 
in user: + user[k] = v + else: + user[k] = merge_dict(user[k], v) + return user + +class Dotdict(dict): + """ + a dictionary that supports dot notation + as well as dictionary access notation + usage: d = DotDict() or d = DotDict({'val1':'first'}) + set attributes: d.val2 = 'second' or d['val2'] = 'second' + get attributes: d.val2 or d['val2'] + """ + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __init__(self, dct=None): + dct = dict() if not dct else dct + for key, value in dct.items(): + if hasattr(value, 'keys'): + value = Dotdict(value) + self[key] = value + +class HpsYaml(Dotdict): + def __init__(self, yaml_file): + super(Dotdict, self).__init__() + hps = load_hparams(yaml_file) + hp_dict = Dotdict(hps) + for k, v in hp_dict.items(): + setattr(self, k, v) + + __getattr__ = Dotdict.__getitem__ + __setattr__ = Dotdict.__setitem__ + __delattr__ = Dotdict.__delitem__ + + + + + + + diff --git a/utils/logmmse.py b/utils/logmmse.py new file mode 100644 index 0000000000000000000000000000000000000000..58cc4502fa5ba0670678c3edaf5ba1587b8b58ea --- /dev/null +++ b/utils/logmmse.py @@ -0,0 +1,247 @@ +# The MIT License (MIT) +# +# Copyright (c) 2015 braindead +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# +# This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I +# simply modified the interface to meet my needs. + + +import numpy as np +import math +from scipy.special import expn +from collections import namedtuple + +NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2") + + +def profile_noise(noise, sampling_rate, window_size=0): + """ + Creates a profile of the noise in a given waveform. + + :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints. + :param sampling_rate: the sampling rate of the audio + :param window_size: the size of the window the logmmse algorithm operates on. A default value + will be picked if left as 0. 
+ :return: a NoiseProfile object + """ + noise, dtype = to_float(noise) + noise += np.finfo(np.float64).eps + + if window_size == 0: + window_size = int(math.floor(0.02 * sampling_rate)) + + if window_size % 2 == 1: + window_size = window_size + 1 + + perc = 50 + len1 = int(math.floor(window_size * perc / 100)) + len2 = int(window_size - len1) + + win = np.hanning(window_size) + win = win * len2 / np.sum(win) + n_fft = 2 * window_size + + noise_mean = np.zeros(n_fft) + n_frames = len(noise) // window_size + for j in range(0, window_size * n_frames, window_size): + noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0)) + noise_mu2 = (noise_mean / n_frames) ** 2 + + return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2) + + +def denoise(wav, noise_profile: NoiseProfile, eta=0.15): + """ + Cleans the noise from a speech waveform given a noise profile. The waveform must have the + same sampling rate as the one used to create the noise profile. + + :param wav: a speech waveform as a numpy array of floats or ints. + :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of + the same) waveform. + :param eta: voice threshold for noise update. While the voice activation detection value is + below this threshold, the noise profile will be continuously updated throughout the audio. + Set to 0 to disable updating the noise profile. + :return: the clean wav as a numpy array of floats or ints of the same length. + """ + wav, dtype = to_float(wav) + wav += np.finfo(np.float64).eps + p = noise_profile + + nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2)) + x_final = np.zeros(nframes * p.len2) + + aa = 0.98 + mu = 0.98 + ksi_min = 10 ** (-25 / 10) + + x_old = np.zeros(p.len1) + xk_prev = np.zeros(p.len1) + noise_mu2 = p.noise_mu2 + for k in range(0, nframes * p.len2, p.len2): + insign = p.win * wav[k:k + p.window_size] + + spec = np.fft.fft(insign, p.n_fft, axis=0) + sig = np.absolute(spec) + sig2 = sig ** 2 + + gammak = np.minimum(sig2 / noise_mu2, 40) + + if xk_prev.all() == 0: + ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) + else: + ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) + ksi = np.maximum(ksi_min, ksi) + + log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi) + vad_decision = np.sum(log_sigma_k) / p.window_size + if vad_decision < eta: + noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 + + a = ksi / (1 + ksi) + vk = a * gammak + ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) + hw = a * np.exp(ei_vk) + sig = sig * hw + xk_prev = sig ** 2 + xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0) + xi_w = np.real(xi_w) + + x_final[k:k + p.len2] = x_old + xi_w[0:p.len1] + x_old = xi_w[p.len1:p.window_size] + + output = from_float(x_final, dtype) + output = np.pad(output, (0, len(wav) - len(output)), mode="constant") + return output + + +## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that +## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of +## webrctvad +# def vad(wav, sampling_rate, eta=0.15, window_size=0): +# """ +# TODO: fix doc +# Creates a profile of the noise in a given waveform. +# +# :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints. +# :param sampling_rate: the sampling rate of the audio +# :param window_size: the size of the window the logmmse algorithm operates on. A default value +# will be picked if left as 0. 
+# :param eta: voice threshold for noise update. While the voice activation detection value is +# below this threshold, the noise profile will be continuously updated throughout the audio. +# Set to 0 to disable updating the noise profile. +# """ +# wav, dtype = to_float(wav) +# wav += np.finfo(np.float64).eps +# +# if window_size == 0: +# window_size = int(math.floor(0.02 * sampling_rate)) +# +# if window_size % 2 == 1: +# window_size = window_size + 1 +# +# perc = 50 +# len1 = int(math.floor(window_size * perc / 100)) +# len2 = int(window_size - len1) +# +# win = np.hanning(window_size) +# win = win * len2 / np.sum(win) +# n_fft = 2 * window_size +# +# wav_mean = np.zeros(n_fft) +# n_frames = len(wav) // window_size +# for j in range(0, window_size * n_frames, window_size): +# wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0)) +# noise_mu2 = (wav_mean / n_frames) ** 2 +# +# wav, dtype = to_float(wav) +# wav += np.finfo(np.float64).eps +# +# nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2)) +# vad = np.zeros(nframes * len2, dtype=np.bool) +# +# aa = 0.98 +# mu = 0.98 +# ksi_min = 10 ** (-25 / 10) +# +# xk_prev = np.zeros(len1) +# noise_mu2 = noise_mu2 +# for k in range(0, nframes * len2, len2): +# insign = win * wav[k:k + window_size] +# +# spec = np.fft.fft(insign, n_fft, axis=0) +# sig = np.absolute(spec) +# sig2 = sig ** 2 +# +# gammak = np.minimum(sig2 / noise_mu2, 40) +# +# if xk_prev.all() == 0: +# ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) +# else: +# ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) +# ksi = np.maximum(ksi_min, ksi) +# +# log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi) +# vad_decision = np.sum(log_sigma_k) / window_size +# if vad_decision < eta: +# noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 +# print(vad_decision) +# +# a = ksi / (1 + ksi) +# vk = a * gammak +# ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) +# hw = a * np.exp(ei_vk) +# sig = sig * hw +# xk_prev = sig ** 2 +# +# vad[k:k + len2] = vad_decision >= eta +# +# vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant") +# return vad + + +def to_float(_input): + if _input.dtype == np.float64: + return _input, _input.dtype + elif _input.dtype == np.float32: + return _input.astype(np.float64), _input.dtype + elif _input.dtype == np.uint8: + return (_input - 128) / 128., _input.dtype + elif _input.dtype == np.int16: + return _input / 32768., _input.dtype + elif _input.dtype == np.int32: + return _input / 2147483648., _input.dtype + raise ValueError('Unsupported wave file format') + + +def from_float(_input, dtype): + if dtype == np.float64: + return _input, np.float64 + elif dtype == np.float32: + return _input.astype(np.float32) + elif dtype == np.uint8: + return ((_input * 128) + 128).astype(np.uint8) + elif dtype == np.int16: + return (_input * 32768).astype(np.int16) + elif dtype == np.int32: + print(_input) + return (_input * 2147483648).astype(np.int32) + raise ValueError('Unsupported wave file format') diff --git a/utils/modelutils.py b/utils/modelutils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2efc7a82950e5dc035060b1c5c1839056a7699 --- /dev/null +++ b/utils/modelutils.py @@ -0,0 +1,16 @@ +from pathlib import Path + +def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path: Path): + # This function tests the model paths and makes sure at least one is valid. 
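+ # Note: each check accepts a directory as well as a file, so pointing any of these paths at an existing models directory (rather than a specific checkpoint file) is enough to pass.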
+ if encoder_path.is_file() or encoder_path.is_dir(): + return + if synthesizer_path.is_file() or synthesizer_path.is_dir(): + return + if vocoder_path.is_file() or vocoder_path.is_dir(): + return + + # If none of the paths exist, remind the user to download models if needed + print("********************************************************************************") + print("Error: Model files not found. Please download the models") + print("********************************************************************************\n") + quit(-1) diff --git a/utils/profiler.py b/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..17175b9e1b0eb17fdc015199e5194a5c1afb8a28 --- /dev/null +++ b/utils/profiler.py @@ -0,0 +1,45 @@ +from time import perf_counter as timer +from collections import OrderedDict +import numpy as np + + +class Profiler: + def __init__(self, summarize_every=5, disabled=False): + self.last_tick = timer() + self.logs = OrderedDict() + self.summarize_every = summarize_every + self.disabled = disabled + + def tick(self, name): + if self.disabled: + return + + # Log the time needed to execute that function + if not name in self.logs: + self.logs[name] = [] + if len(self.logs[name]) >= self.summarize_every: + self.summarize() + self.purge_logs() + self.logs[name].append(timer() - self.last_tick) + + self.reset_timer() + + def purge_logs(self): + for name in self.logs: + self.logs[name].clear() + + def reset_timer(self): + self.last_tick = timer() + + def summarize(self): + n = max(map(len, self.logs.values())) + assert n == self.summarize_every + print("\nAverage execution time over %d steps:" % n) + + name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()] + pad = max(map(len, name_msgs)) + for name_msg, deltas in zip(name_msgs, self.logs.values()): + print(" %s mean: %4.0fms std: %4.0fms" % + (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000)) + print("", flush=True) + \ No newline at end of file diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..34bcffd6c0975377a54ae1ce89002be1dae8151d --- /dev/null +++ b/utils/util.py @@ -0,0 +1,50 @@ +import matplotlib +matplotlib.use('Agg') +import time + +class Timer(): + ''' Timer for recording training time distribution. 
''' + def __init__(self): + self.prev_t = time.time() + self.clear() + + def set(self): + self.prev_t = time.time() + + def cnt(self, mode): + self.time_table[mode] += time.time()-self.prev_t + self.set() + if mode == 'bw': + self.click += 1 + + def show(self): + total_time = sum(self.time_table.values()) + self.time_table['avg'] = total_time/self.click + self.time_table['rd'] = 100*self.time_table['rd']/total_time + self.time_table['fw'] = 100*self.time_table['fw']/total_time + self.time_table['bw'] = 100*self.time_table['bw']/total_time + msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format( + **self.time_table) + self.clear() + return msg + + def clear(self): + self.time_table = {'rd': 0, 'fw': 0, 'bw': 0} + self.click = 0 + +# Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168 + +def human_format(num): + magnitude = 0 + while num >= 1000: + magnitude += 1 + num /= 1000.0 + # add more suffixes if you need them + return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude]) + + +# provide easy access of attribute from dict, such abc.key +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/vocoder/LICENSE.txt b/vocoder/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d6716174d0d0058f3fc3ae6a8e595119605acbf --- /dev/null +++ b/vocoder/LICENSE.txt @@ -0,0 +1,22 @@ +MIT License + +Original work Copyright (c) 2019 fatchord (https://github.com/fatchord) +Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
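A minimal usage sketch for the Timer helper defined in utils/util.py above; the loop and the sleep() stand-ins are illustrative assumptions, only the Timer API (cnt() with the 'rd'/'fw'/'bw' modes and show()) comes from the file itself:

    from time import sleep
    from utils.util import Timer  # assumes the repository root is on PYTHONPATH

    timer = Timer()
    for step in range(3):
        sleep(0.01); timer.cnt('rd')   # account elapsed time to data reading
        sleep(0.02); timer.cnt('fw')   # account elapsed time to the forward pass
        sleep(0.03); timer.cnt('bw')   # account elapsed time to the backward pass (increments the step counter)
    print(timer.show())                # e.g. "0.060 sec/step (rd 16.7% | fw 33.3% | bw 50.0%)"

show() also clears the accumulated table, so it is intended to be called once per reporting interval.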
diff --git a/vocoder/display.py b/vocoder/display.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7dd30bc5e4009a8b62a4805596f937f01befb5 --- /dev/null +++ b/vocoder/display.py @@ -0,0 +1,128 @@ +import matplotlib.pyplot as plt +import time +import numpy as np +import sys + + +def progbar(i, n, size=16): + done = (i * size) // n + bar = '' + for i in range(size): + bar += '█' if i <= done else '░' + return bar + + +def stream(message) : + try: + sys.stdout.write("\r{%s}" % message) + except: + #Remove non-ASCII characters from message + message = ''.join(i for i in message if ord(i)<128) + sys.stdout.write("\r{%s}" % message) + + +def simple_table(item_tuples) : + + border_pattern = '+---------------------------------------' + whitespace = ' ' + + headings, cells, = [], [] + + for item in item_tuples : + + heading, cell = str(item[0]), str(item[1]) + + pad_head = True if len(heading) < len(cell) else False + + pad = abs(len(heading) - len(cell)) + pad = whitespace[:pad] + + pad_left = pad[:len(pad)//2] + pad_right = pad[len(pad)//2:] + + if pad_head : + heading = pad_left + heading + pad_right + else : + cell = pad_left + cell + pad_right + + headings += [heading] + cells += [cell] + + border, head, body = '', '', '' + + for i in range(len(item_tuples)) : + + temp_head = f'| {headings[i]} ' + temp_body = f'| {cells[i]} ' + + border += border_pattern[:len(temp_head)] + head += temp_head + body += temp_body + + if i == len(item_tuples) - 1 : + head += '|' + body += '|' + border += '+' + + print(border) + print(head) + print(border) + print(body) + print(border) + print(' ') + + +def time_since(started) : + elapsed = time.time() - started + m = int(elapsed // 60) + s = int(elapsed % 60) + if m >= 60 : + h = int(m // 60) + m = m % 60 + return f'{h}h {m}m {s}s' + else : + return f'{m}m {s}s' + + +def save_attention(attn, path) : + fig = plt.figure(figsize=(12, 6)) + plt.imshow(attn.T, interpolation='nearest', aspect='auto') + fig.savefig(f'{path}.png', bbox_inches='tight') + plt.close(fig) + + +def save_and_trace_attention(attn, path, sw, step): + fig = plt.figure(figsize=(12, 6)) + plt.imshow(attn.T, interpolation='nearest', aspect='auto') + fig.savefig(f'{path}.png', bbox_inches='tight') + sw.add_figure('attention', fig, step) + plt.close(fig) + + +def save_spectrogram(M, path, length=None) : + M = np.flip(M, axis=0) + if length : M = M[:, :length] + fig = plt.figure(figsize=(12, 6)) + plt.imshow(M, interpolation='nearest', aspect='auto') + fig.savefig(f'{path}.png', bbox_inches='tight') + plt.close(fig) + + +def plot(array) : + fig = plt.figure(figsize=(30, 5)) + ax = fig.add_subplot(111) + ax.xaxis.label.set_color('grey') + ax.yaxis.label.set_color('grey') + ax.xaxis.label.set_fontsize(23) + ax.yaxis.label.set_fontsize(23) + ax.tick_params(axis='x', colors='grey', labelsize=23) + ax.tick_params(axis='y', colors='grey', labelsize=23) + plt.plot(array) + + +def plot_spec(M) : + M = np.flip(M, axis=0) + plt.figure(figsize=(18,4)) + plt.imshow(M, interpolation='nearest', aspect='auto') + plt.show() + diff --git a/vocoder/distribution.py b/vocoder/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..d3119a5ba1e77bc25a92d2664f83d366f12399c0 --- /dev/null +++ b/vocoder/distribution.py @@ -0,0 +1,132 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +def log_sum_exp(x): + """ numerically stable log_sum_exp implementation that prevents overflow """ + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, 
dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, + log_scale_min=None, reduce=True): + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + y_hat = y_hat.permute(0,2,1) + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) + + # tf equivalent + """ + log_probs = tf.where(x < -0.999, log_cdf_plus, + tf.where(x > 0.999, log_one_minus_cdf_min, + tf.where(cdf_delta > 1e-5, + tf.log(tf.maximum(cdf_delta, 1e-12)), + log_pdf_mid - np.log(127.5)))) + """ + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * \ + torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ + (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1. - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + else: + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. 
+ """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(- torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum( + y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) + + x = torch.clamp(torch.clamp(x, min=-1.), max=1.) + + return x + + +def to_one_hot(tensor, n, fill_with=1.): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() + if tensor.is_cuda: + one_hot = one_hot.cuda() + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot diff --git a/vocoder/fregan/.gitignore b/vocoder/fregan/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b6e47617de110dea7ca47e087ff1347cc2646eda --- /dev/null +++ b/vocoder/fregan/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/vocoder/fregan/LICENSE b/vocoder/fregan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..eb916f2926421087ebfe54c0eaa97da03428852f --- /dev/null +++ b/vocoder/fregan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Rishikesh (ऋषिकेश) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vocoder/fregan/config.json b/vocoder/fregan/config.json new file mode 100644 index 0000000000000000000000000000000000000000..22187b360a7a8e60d77a0a852b17b1d099e79baa --- /dev/null +++ b/vocoder/fregan/config.json @@ -0,0 +1,42 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "disc_start_step":0, + + + "upsample_rates": [5,5,2,2,2], + "upsample_kernel_sizes": [10,10,4,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1, 3, 5, 7], [1,3,5,7], [1,3,5,7]], + + "segment_size": 6400, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 200, + "win_size": 800, + + "sampling_rate": 16000, + + "fmin": 0, + "fmax": 7600, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } + + + +} \ No newline at end of file diff --git a/vocoder/fregan/discriminator.py b/vocoder/fregan/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..5f94092634db21102b977c0347e756993edbc2bc --- /dev/null +++ b/vocoder/fregan/discriminator.py @@ -0,0 +1,303 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, spectral_norm +from vocoder.fregan.utils import get_padding +from vocoder.fregan.stft_loss import stft +from vocoder.fregan.dwt import DWT_1D +LRELU_SLOPE = 0.1 + + + +class SpecDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False): + 
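+ # fft_size / shift_size / win_length parameterize the STFT that forward() applies to the raw waveform before the 2-D convolutions; use_spectral_norm swaps weight_norm for spectral_norm on every conv layer.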
super(SpecDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.discriminators = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))), + ]) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + + fmap = [] + with torch.no_grad(): + y = y.squeeze(1) + y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device())) + y = y.unsqueeze(1) + for i, d in enumerate(self.discriminators): + y = d(y) + y = F.leaky_relu(y, LRELU_SLOPE) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap + +class MultiResSpecDiscriminator(torch.nn.Module): + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window"): + + super(MultiResSpecDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window), + SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window), + SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.dwt1d = DWT_1D() + self.dwt_conv1 = norm_f(Conv1d(2, 1, 1)) + self.dwt_proj1 = norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))) + self.dwt_conv2 = norm_f(Conv1d(4, 1, 1)) + self.dwt_proj2 = norm_f(Conv2d(1, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))) + self.dwt_conv3 = norm_f(Conv1d(8, 1, 1)) + self.dwt_proj3 = norm_f(Conv2d(1, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))) + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # DWT 1 + x_d1_high1, x_d1_low1 = self.dwt1d(x) + x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1)) + # 1d to 2d + b, c, t = x_d1.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x_d1 = F.pad(x_d1, (0, n_pad), "reflect") + t = t + n_pad + x_d1 = 
x_d1.view(b, c, t // self.period, self.period) + + x_d1 = self.dwt_proj1(x_d1) + + # DWT 2 + x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1) + x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1) + x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1)) + # 1d to 2d + b, c, t = x_d2.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x_d2 = F.pad(x_d2, (0, n_pad), "reflect") + t = t + n_pad + x_d2 = x_d2.view(b, c, t // self.period, self.period) + + x_d2 = self.dwt_proj2(x_d2) + + # DWT 3 + + x_d3_high1, x_d3_low1 = self.dwt1d(x_d2_high1) + x_d3_high2, x_d3_low2 = self.dwt1d(x_d2_low1) + x_d3_high3, x_d3_low3 = self.dwt1d(x_d2_high2) + x_d3_high4, x_d3_low4 = self.dwt1d(x_d2_low2) + x_d3 = self.dwt_conv3( + torch.cat([x_d3_high1, x_d3_low1, x_d3_high2, x_d3_low2, x_d3_high3, x_d3_low3, x_d3_high4, x_d3_low4], + dim=1)) + # 1d to 2d + b, c, t = x_d3.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x_d3 = F.pad(x_d3, (0, n_pad), "reflect") + t = t + n_pad + x_d3 = x_d3.view(b, c, t // self.period, self.period) + + x_d3 = self.dwt_proj3(x_d3) + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + i = 0 + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + + fmap.append(x) + if i == 0: + x = torch.cat([x, x_d1], dim=2) + elif i == 1: + x = torch.cat([x, x_d2], dim=2) + elif i == 2: + x = torch.cat([x, x_d3], dim=2) + else: + x = x + i = i + 1 + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class ResWiseMultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(ResWiseMultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.dwt1d = DWT_1D() + self.dwt_conv1 = norm_f(Conv1d(2, 128, 15, 1, padding=7)) + self.dwt_conv2 = norm_f(Conv1d(4, 128, 41, 2, padding=20)) + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + # DWT 1 + x_d1_high1, x_d1_low1 = self.dwt1d(x) + x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1)) + + # DWT 2 + x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1) + x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1) + x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1)) + + i = 0 + for 
l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + if i == 0: + x = torch.cat([x, x_d1], dim=2) + if i == 1: + x = torch.cat([x, x_d2], dim=2) + i = i + 1 + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class ResWiseMultiScaleDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(ResWiseMultiScaleDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.dwt1d = DWT_1D() + self.dwt_conv1 = norm_f(Conv1d(2, 1, 1)) + self.dwt_conv2 = norm_f(Conv1d(4, 1, 1)) + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + # DWT 1 + y_hi, y_lo = self.dwt1d(y) + y_1 = self.dwt_conv1(torch.cat([y_hi, y_lo], dim=1)) + x_d1_high1, x_d1_low1 = self.dwt1d(y_hat) + y_hat_1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1)) + + # DWT 2 + x_d2_high1, x_d2_low1 = self.dwt1d(y_hi) + x_d2_high2, x_d2_low2 = self.dwt1d(y_lo) + y_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1)) + + x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1) + x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1) + y_hat_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1)) + + for i, d in enumerate(self.discriminators): + + if i == 1: + y = y_1 + y_hat = y_hat_1 + if i == 2: + y = y_2 + y_hat = y_hat_2 + + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs \ No newline at end of file diff --git a/vocoder/fregan/dwt.py b/vocoder/fregan/dwt.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5d995e1a6a8757b21f46dd1a6e74befaee9816 --- /dev/null +++ b/vocoder/fregan/dwt.py @@ -0,0 +1,76 @@ +# Copyright (c) 2019, Adobe Inc. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike +# 4.0 International Public License. To view a copy of this license, visit +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 
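+# Usage sketch (illustrative only, assuming the default haar / stride-2 settings defined below): +# DWT_1D takes a (batch, channels, time) tensor and returns the low- and high-frequency +# sub-band signals, each at half the input rate, e.g. +#     low, high = DWT_1D()(torch.randn(4, 1, 16000))   # each of shape (4, 1, 8000)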
+ +# DWT code borrow from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py + + +import pywt +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['DWT_1D'] +Pad_Mode = ['constant', 'reflect', 'replicate', 'circular'] + + +class DWT_1D(nn.Module): + def __init__(self, pad_type='reflect', wavename='haar', + stride=2, in_channels=1, out_channels=None, groups=None, + kernel_size=None, trainable=False): + + super(DWT_1D, self).__init__() + self.trainable = trainable + self.kernel_size = kernel_size + if not self.trainable: + assert self.kernel_size == None + self.in_channels = in_channels + self.out_channels = self.in_channels if out_channels == None else out_channels + self.groups = self.in_channels if groups == None else groups + assert isinstance(self.groups, int) and self.in_channels % self.groups == 0 + self.stride = stride + assert self.stride == 2 + self.wavename = wavename + self.pad_type = pad_type + assert self.pad_type in Pad_Mode + self.get_filters() + self.initialization() + + def get_filters(self): + wavelet = pywt.Wavelet(self.wavename) + band_low = torch.tensor(wavelet.rec_lo) + band_high = torch.tensor(wavelet.rec_hi) + length_band = band_low.size()[0] + self.kernel_size = length_band if self.kernel_size == None else self.kernel_size + assert self.kernel_size >= length_band + a = (self.kernel_size - length_band) // 2 + b = - (self.kernel_size - length_band - a) + b = None if b == 0 else b + self.filt_low = torch.zeros(self.kernel_size) + self.filt_high = torch.zeros(self.kernel_size) + self.filt_low[a:b] = band_low + self.filt_high[a:b] = band_high + + def initialization(self): + self.filter_low = self.filt_low[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1)) + self.filter_high = self.filt_high[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1)) + if torch.cuda.is_available(): + self.filter_low = self.filter_low.cuda() + self.filter_high = self.filter_high.cuda() + if self.trainable: + self.filter_low = nn.Parameter(self.filter_low) + self.filter_high = nn.Parameter(self.filter_high) + if self.kernel_size % 2 == 0: + self.pad_sizes = [self.kernel_size // 2 - 1, self.kernel_size // 2 - 1] + else: + self.pad_sizes = [self.kernel_size // 2, self.kernel_size // 2] + + def forward(self, input): + assert isinstance(input, torch.Tensor) + assert len(input.size()) == 3 + assert input.size()[1] == self.in_channels + input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type) + return F.conv1d(input, self.filter_low.to(input.device), stride=self.stride, groups=self.groups), \ + F.conv1d(input, self.filter_high.to(input.device), stride=self.stride, groups=self.groups) diff --git a/vocoder/fregan/generator.py b/vocoder/fregan/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dd3a867c058c1201cd4ab65e6e2f2147aeb05d --- /dev/null +++ b/vocoder/fregan/generator.py @@ -0,0 +1,210 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from vocoder.fregan.utils import init_weights, get_padding + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5, 7)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, 
dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[3], + padding=get_padding(kernel_size, dilation[3]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class FreGAN(torch.nn.Module): + def __init__(self, h, top_k=4): + super(FreGAN, self).__init__() + self.h = h + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.upsample_rates = h.upsample_rates + self.up_kernels = h.upsample_kernel_sizes + self.cond_level = self.num_upsamples - top_k + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + self.cond_up = nn.ModuleList() + self.res_output = nn.ModuleList() + upsample_ = 1 + kr = 80 + + for i, (u, k) in enumerate(zip(self.upsample_rates, self.up_kernels)): +# self.ups.append(weight_norm( + # ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), + # k, u, padding=(k - u) // 2))) + self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i), + h.upsample_initial_channel//(2**(i+1)), + k, u, padding=(u//2 + u%2), output_padding=u%2))) + + if i > (self.num_upsamples - top_k): + self.res_output.append( + nn.Sequential( + nn.Upsample(scale_factor=u, mode='nearest'), + weight_norm(nn.Conv1d(h.upsample_initial_channel // (2 ** i), + h.upsample_initial_channel // (2 ** (i + 1)), 1)) + ) + ) + if i >= (self.num_upsamples - top_k): + self.cond_up.append( + weight_norm( + ConvTranspose1d(kr, h.upsample_initial_channel // (2 ** i), + self.up_kernels[i - 1], self.upsample_rates[i - 1], + 
padding=(self.upsample_rates[i-1]//2+self.upsample_rates[i-1]%2), output_padding=self.upsample_rates[i-1]%2)) + ) + kr = h.upsample_initial_channel // (2 ** i) + + upsample_ *= u + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.cond_up.apply(init_weights) + self.res_output.apply(init_weights) + + def forward(self, x): + mel = x + x = self.conv_pre(x) + output = None + for i in range(self.num_upsamples): + if i >= self.cond_level: + mel = self.cond_up[i - self.cond_level](mel) + x += mel + if i > self.cond_level: + if output is None: + output = self.res_output[i - self.cond_level - 1](x) + else: + output = self.res_output[i - self.cond_level - 1](output) + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + if output is not None: + output = output + x + + x = F.leaky_relu(output) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + for l in self.cond_up: + remove_weight_norm(l) + for l in self.res_output: + remove_weight_norm(l[1]) + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +''' + to run this, fix + from . import ResStack + into + from res_stack import ResStack +''' +if __name__ == '__main__': + ''' + torch.Size([3, 80, 10]) + torch.Size([3, 1, 2000]) + 4527362 + ''' + with open('config.json') as f: + data = f.read() + from utils import AttrDict + import json + json_config = json.loads(data) + h = AttrDict(json_config) + model = FreGAN(h) + + c = torch.randn(3, 80, 10) # (B, channels, T). 
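+ # Note: the generator upsamples the 10 mel frames by prod(h.upsample_rates). The assert + # below (2560 samples) therefore assumes a config with prod(upsample_rates) == 256 + # (e.g. the original 22.05 kHz FreGAN recipe); with a 16 kHz config such as + # upsample_rates = [5, 5, 4, 2] the output would instead be (3, 1, 2000), the shape + # quoted in the docstring above.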
+ print(c.shape) + + y = model(c) # (B, 1, T ** prod(upsample_scales) + print(y.shape) + assert y.shape == torch.Size([3, 1, 2560]) # For normal melgan torch.Size([3, 1, 2560]) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) \ No newline at end of file diff --git a/vocoder/fregan/inference.py b/vocoder/fregan/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..780a613376a7c411e75bd6d7a468a3eb1e893a57 --- /dev/null +++ b/vocoder/fregan/inference.py @@ -0,0 +1,74 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import json +import torch +from utils.util import AttrDict +from vocoder.fregan.generator import FreGAN + +generator = None # type: FreGAN +output_sample_rate = None +_device = None + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def load_model(weights_fpath, config_fpath=None, verbose=True): + global generator, _device, output_sample_rate + + if verbose: + print("Building fregan") + + if config_fpath == None: + model_config_fpaths = list(weights_fpath.parent.rglob("*.json")) + if len(model_config_fpaths) > 0: + config_fpath = model_config_fpaths[0] + else: + config_fpath = "./vocoder/fregan/config.json" + with open(config_fpath) as f: + data = f.read() + json_config = json.loads(data) + h = AttrDict(json_config) + output_sample_rate = h.sampling_rate + torch.manual_seed(h.seed) + + if torch.cuda.is_available(): + # _model = _model.cuda() + _device = torch.device('cuda') + else: + _device = torch.device('cpu') + + generator = FreGAN(h).to(_device) + state_dict_g = load_checkpoint( + weights_fpath, _device + ) + generator.load_state_dict(state_dict_g['generator']) + generator.eval() + generator.remove_weight_norm() + + +def is_loaded(): + return generator is not None + + +def infer_waveform(mel, progress_callback=None): + + if generator is None: + raise Exception("Please load fre-gan in memory before using it") + + mel = torch.FloatTensor(mel).to(_device) + mel = mel.unsqueeze(0) + + with torch.no_grad(): + y_g_hat = generator(mel) + audio = y_g_hat.squeeze() + audio = audio.cpu().numpy() + + return audio, output_sample_rate + diff --git a/vocoder/fregan/loss.py b/vocoder/fregan/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e37dc64e29446ecdd9dce03290f4e0eba58fb3d7 --- /dev/null +++ b/vocoder/fregan/loss.py @@ -0,0 +1,35 @@ +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses \ No newline at end of file diff --git a/vocoder/fregan/meldataset.py b/vocoder/fregan/meldataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..53b2c94e21d9ad3e2a33a6f4b1207a57e0016651 --- /dev/null +++ b/vocoder/fregan/meldataset.py @@ -0,0 +1,176 @@ +import math +import os +import random +import torch +import torch.utils.data +import numpy as np +from librosa.util import normalize +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + #with open(a.input_training_file, 'r', encoding='utf-8') as fi: + # training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + # for x in fi.read().split('\n') if len(x) > 0] + + #with open(a.input_validation_file, 'r', encoding='utf-8') as fi: + # validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + # for x in fi.read().split('\n') if len(x) > 0] + files = os.listdir(a.input_wavs_dir) + random.shuffle(files) + files = [os.path.join(a.input_wavs_dir, f) for f in files] + training_files = files[: -int(len(files) * 0.05)] + validation_files = files[-int(len(files) * 0.05):] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__(self, training_files, segment_size, n_fft, num_mels, + hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, + device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + 
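+ # n_cache_reuse controls how many consecutive __getitem__ calls reuse the waveform + # cached in self.cached_wav before it is reloaded from disk (train.py passes + # n_cache_reuse=0, i.e. every item is loaded fresh).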
self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + #audio, sampling_rate = load_wav(filename) + #audio = audio / MAX_WAV_VALUE + audio = np.load(filename) + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + #if sampling_rate != self.sampling_rate: + # raise ValueError("{} SR doesn't match target {} SR".format( + # sampling_rate, self.sampling_rate)) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start:audio_start+self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') + + mel = mel_spectrogram(audio, self.n_fft, self.num_mels, + self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, + center=False) + else: + mel_path = os.path.join(self.base_mels_path, "mel" + "-" + filename.split("/")[-1].split("-")[-1]) + mel = np.load(mel_path).T + #mel = np.load( + # os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy')) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start:mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant') + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') + + mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels, + self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss, + center=False) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) \ No newline at end of file diff --git a/vocoder/fregan/modules.py b/vocoder/fregan/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8160e4a9cd14627f0bc7a3f28c36ea6483e3a5 --- /dev/null +++ b/vocoder/fregan/modules.py @@ -0,0 +1,201 @@ +import torch +import torch.nn.functional as F + +class KernelPredictor(torch.nn.Module): + ''' Kernel predictor for the location-variable convolutions + ''' + + def __init__(self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1} + ): + ''' + Args: + cond_channels (int): number of channels of the conditioning sequence, + conv_in_channels (int): number of channels of the input sequence, + conv_out_channels (int): number of channels of the output sequence, + conv_layers (int): number of convolution layers whose kernels and biases are predicted, + kpnet_*: hyper-parameters of the kernel prediction network (hidden channels, conv size, dropout, non-linearity). + ''' + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = 
conv_kernel_size + self.conv_layers = conv_layers + + l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers + l_b = conv_out_channels * conv_layers + + padding = (kpnet_conv_size - 1) // 2 + self.input_conv = torch.nn.Sequential( + torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_conv = torch.nn.Sequential( + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, + padding=padding, bias=True) + self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, + bias=True) + + def forward(self, c): + ''' + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + Returns: + ''' + batch, cond_channels, cond_length = c.shape + + c = self.input_conv(c) + c = c + self.residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + + kernels = k.contiguous().view(batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length) + bias = b.contiguous().view(batch, + self.conv_layers, + self.conv_out_channels, + cond_length) + return kernels, bias + + +class LVCBlock(torch.nn.Module): + ''' the location-variable convolutions + ''' + + def __init__(self, + in_channels, + cond_channels, + upsample_ratio, + conv_layers=4, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0 + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = conv_layers + self.conv_kernel_size = conv_kernel_size + self.convs = torch.nn.ModuleList() + + self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels, + kernel_size=upsample_ratio*2, stride=upsample_ratio, + padding=upsample_ratio // 2 + upsample_ratio % 2, + output_padding=upsample_ratio % 2) + + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=conv_layers, + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + 
kpnet_dropout=kpnet_dropout + ) + + + for i in range(conv_layers): + padding = (3 ** i) * int((conv_kernel_size - 1) / 2) + conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i) + + self.convs.append(conv) + + + def forward(self, x, c): + ''' forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + ''' + batch, in_channels, in_length = x.shape + + + kernels, bias = self.kernel_predictor(c) + + x = F.leaky_relu(x, 0.2) + x = self.upsample(x) + + for i in range(self.conv_layers): + y = F.leaky_relu(x, 0.2) + y = self.convs[i](y) + y = F.leaky_relu(y, 0.2) + + k = kernels[:, i, :, :, :, :] + b = bias[:, i, :, :] + y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length) + x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :]) + return x + + def location_variable_convolution(self, x, kernel, bias, dilation, hop_size): + ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + ''' + batch, in_channels, in_length = x.shape + batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape + + + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), 'constant', 0) + x = x.unfold(3, dilation, + dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum('bildsk,biokl->bolsd', x, kernel) + o = o + bias.unsqueeze(-1).unsqueeze(-1) + o = o.contiguous().view(batch, out_channels, -1) + return o diff --git a/vocoder/fregan/stft_loss.py b/vocoder/fregan/stft_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e47447455341e5725d6f82ded66dc08b5d2b1cc5 --- /dev/null +++ b/vocoder/fregan/stft_loss.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). 
+ fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (Tensor): Window tensor. + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + """ + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) + + +class SpectralConvergengeLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + """Initialize spectral convergence loss module.""" + super(SpectralConvergengeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Spectral convergence loss value. + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self): + """Initialize log STFT magnitude loss module.""" + super(LogSTFTMagnitudeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Log STFT magnitude loss value. + """ + return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.spectral_convergenge_loss = SpectralConvergengeLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window.to(x.device)) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(x.device)) + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window"): + """Initialize Multi resolution STFT loss module. + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window)] + + def forward(self, x, y): + """Calculate forward propagation. 
+ Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss \ No newline at end of file diff --git a/vocoder/fregan/train.py b/vocoder/fregan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..de1fac9a2f09ae030139add645092d63f1485594 --- /dev/null +++ b/vocoder/fregan/train.py @@ -0,0 +1,246 @@ +import warnings + +warnings.simplefilter(action='ignore', category=FutureWarning) +import itertools +import os +import time +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DistributedSampler, DataLoader +from torch.distributed import init_process_group +from torch.nn.parallel import DistributedDataParallel +from vocoder.fregan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist +from vocoder.fregan.generator import FreGAN +from vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator, ResWiseMultiScaleDiscriminator +from vocoder.fregan.loss import feature_loss, generator_loss, discriminator_loss +from vocoder.fregan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint + + +torch.backends.cudnn.benchmark = True + + +def train(rank, a, h): + + a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_fregan') + a.checkpoint_path.mkdir(exist_ok=True) + a.training_epochs = 3100 + a.stdout_interval = 5 + a.checkpoint_interval = a.backup_every + a.summary_interval = 5000 + a.validation_interval = 1000 + a.fine_tuning = True + + a.input_wavs_dir = a.syn_dir.joinpath("audio") + a.input_mels_dir = a.syn_dir.joinpath("mels") + + if h.num_gpus > 1: + init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], + world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) + + torch.cuda.manual_seed(h.seed) + device = torch.device('cuda:{:d}'.format(rank)) + + generator = FreGAN(h).to(device) + mpd = ResWiseMultiPeriodDiscriminator().to(device) + msd = ResWiseMultiScaleDiscriminator().to(device) + + if rank == 0: + print(generator) + os.makedirs(a.checkpoint_path, exist_ok=True) + print("checkpoints directory : ", a.checkpoint_path) + + if os.path.isdir(a.checkpoint_path): + cp_g = scan_checkpoint(a.checkpoint_path, 'g_fregan_') + cp_do = scan_checkpoint(a.checkpoint_path, 'do_fregan_') + + steps = 0 + if cp_g is None or cp_do is None: + state_dict_do = None + last_epoch = -1 + else: + state_dict_g = load_checkpoint(cp_g, device) + state_dict_do = load_checkpoint(cp_do, device) + generator.load_state_dict(state_dict_g['generator']) + mpd.load_state_dict(state_dict_do['mpd']) + msd.load_state_dict(state_dict_do['msd']) + steps = state_dict_do['steps'] + 1 + last_epoch = state_dict_do['epoch'] + + if h.num_gpus > 1: + generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) + mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) + msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) + + optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), + h.learning_rate, 
betas=[h.adam_b1, h.adam_b2]) + + if state_dict_do is not None: + optim_g.load_state_dict(state_dict_do['optim_g']) + optim_d.load_state_dict(state_dict_do['optim_d']) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) + + training_filelist, validation_filelist = get_dataset_filelist(a) + + trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, + shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, + fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) + + train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None + + train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, + sampler=train_sampler, + batch_size=h.batch_size, + pin_memory=True, + drop_last=True) + + if rank == 0: + validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, + fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir) + validation_loader = DataLoader(validset, num_workers=1, shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True) + + sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) + + generator.train() + mpd.train() + msd.train() + for epoch in range(max(0, last_epoch), a.training_epochs): + if rank == 0: + start = time.time() + print("Epoch: {}".format(epoch + 1)) + + if h.num_gpus > 1: + train_sampler.set_epoch(epoch) + + for i, batch in enumerate(train_loader): + if rank == 0: + start_b = time.time() + x, y, _, y_mel = batch + x = torch.autograd.Variable(x.to(device, non_blocking=True)) + y = torch.autograd.Variable(y.to(device, non_blocking=True)) + y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) + y = y.unsqueeze(1) + y_g_hat = generator(x) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, + h.win_size, + h.fmin, h.fmax_for_loss) + + if steps > h.disc_start_step: + optim_d.zero_grad() + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) + + # MSD + y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) + + loss_disc_all = loss_disc_s + loss_disc_f + + loss_disc_all.backward() + optim_d.step() + + # Generator + optim_g.zero_grad() + + + # L1 Mel-Spectrogram Loss + loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 + + # sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1)) + # loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss + + if steps > h.disc_start_step: + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel + else: + loss_gen_all = loss_mel + + loss_gen_all.backward() + optim_g.step() + + if rank == 0: + # STDOUT logging + 
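+ # (mel_error computed in this branch is also reused by the TensorBoard summary below; + # with stdout_interval=5 and summary_interval=5000 every summary step is also a + # stdout step, so the logged value is current.)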
if steps % a.stdout_interval == 0: + with torch.no_grad(): + mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() + + print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. + format(steps, loss_gen_all, mel_error, time.time() - start_b)) + + # checkpointing + if steps % a.checkpoint_interval == 0 and steps != 0: + checkpoint_path = "{}/g_fregan_{:08d}.pt".format(a.checkpoint_path, steps) + save_checkpoint(checkpoint_path, + {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) + checkpoint_path = "{}/do_fregan_{:08d}.pt".format(a.checkpoint_path, steps) + save_checkpoint(checkpoint_path, + {'mpd': (mpd.module if h.num_gpus > 1 + else mpd).state_dict(), + 'msd': (msd.module if h.num_gpus > 1 + else msd).state_dict(), + 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, + 'epoch': epoch}) + + # Tensorboard summary logging + if steps % a.summary_interval == 0: + sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) + sw.add_scalar("training/mel_spec_error", mel_error, steps) + + # Validation + if steps % a.validation_interval == 0: # and steps != 0: + generator.eval() + torch.cuda.empty_cache() + val_err_tot = 0 + with torch.no_grad(): + for j, batch in enumerate(validation_loader): + x, y, _, y_mel = batch + y_g_hat = generator(x.to(device)) + y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, + h.hop_size, h.win_size, + h.fmin, h.fmax_for_loss) + #val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() + + if j <= 4: + if steps == 0: + sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) + sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) + + sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) + y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, + h.sampling_rate, h.hop_size, h.win_size, + h.fmin, h.fmax) + sw.add_figure('generated/y_hat_spec_{}'.format(j), + plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) + + val_err = val_err_tot / (j + 1) + sw.add_scalar("validation/mel_spec_error", val_err, steps) + + generator.train() + + steps += 1 + + scheduler_g.step() + scheduler_d.step() + + if rank == 0: + print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) + + diff --git a/vocoder/fregan/utils.py b/vocoder/fregan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45161b1379b91859efe5c002bece6e484546e059 --- /dev/null +++ b/vocoder/fregan/utils.py @@ -0,0 +1,65 @@ +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt +import shutil + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def 
get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????.pt') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] \ No newline at end of file diff --git a/vocoder/hifigan/config_16k_.json b/vocoder/hifigan/config_16k_.json new file mode 100644 index 0000000000000000000000000000000000000000..7ea7c3c7ae64b18f616c8ba40e6854af607bf063 --- /dev/null +++ b/vocoder/hifigan/config_16k_.json @@ -0,0 +1,38 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "disc_start_step":0, + + "upsample_rates": [5,5,4,2], + "upsample_kernel_sizes": [10,10,8,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "segment_size": 6400, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 200, + "win_size": 800, + + "sampling_rate": 16000, + + "fmin": 0, + "fmax": 7600, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/vocoder/hifigan/env.py b/vocoder/hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0d306d518d0d86a40d7ee992fbad6f04fe875f --- /dev/null +++ b/vocoder/hifigan/env.py @@ -0,0 +1,8 @@ +import os +import shutil + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vocoder/hifigan/inference.py b/vocoder/hifigan/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8caf3485226d259cb2179780d09fbf71fc2d356f --- /dev/null +++ b/vocoder/hifigan/inference.py @@ -0,0 +1,74 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import json +import torch +from utils.util import AttrDict +from vocoder.hifigan.models import Generator + +generator = None # type: Generator +output_sample_rate = None +_device = None + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def load_model(weights_fpath, config_fpath=None, verbose=True): + global generator, _device, output_sample_rate + + if verbose: + print("Building hifigan") + + if config_fpath == None: + model_config_fpaths = list(weights_fpath.parent.rglob("*.json")) + if len(model_config_fpaths) > 0: + config_fpath = model_config_fpaths[0] + else: + config_fpath = "./vocoder/hifigan/config_16k_.json" + with open(config_fpath) as f: + data = f.read() + json_config = json.loads(data) + h = AttrDict(json_config) + output_sample_rate = h.sampling_rate + torch.manual_seed(h.seed) + + if torch.cuda.is_available(): + # _model = _model.cuda() + _device = torch.device('cuda') + 
else: + _device = torch.device('cpu') + + generator = Generator(h).to(_device) + state_dict_g = load_checkpoint( + weights_fpath, _device + ) + generator.load_state_dict(state_dict_g['generator']) + generator.eval() + generator.remove_weight_norm() + + +def is_loaded(): + return generator is not None + + +def infer_waveform(mel, progress_callback=None): + + if generator is None: + raise Exception("Please load hifi-gan in memory before using it") + + mel = torch.FloatTensor(mel).to(_device) + mel = mel.unsqueeze(0) + + with torch.no_grad(): + y_g_hat = generator(mel) + audio = y_g_hat.squeeze() + audio = audio.cpu().numpy() + + return audio, output_sample_rate + diff --git a/vocoder/hifigan/meldataset.py b/vocoder/hifigan/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0682b0f6c03319a4fd5d16f67d8aa843a0216e --- /dev/null +++ b/vocoder/hifigan/meldataset.py @@ -0,0 +1,178 @@ +import math +import os +import random +import torch +import torch.utils.data +import numpy as np +from librosa.util import normalize +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + # with open(a.input_training_file, 'r', encoding='utf-8') as fi: + # training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + # for x in fi.read().split('\n') if len(x) > 0] + + # with open(a.input_validation_file, 'r', encoding='utf-8') as fi: + # validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + # for x in fi.read().split('\n') if len(x) > 0] + + files = os.listdir(a.input_wavs_dir) + random.shuffle(files) + files = [os.path.join(a.input_wavs_dir, f) for f in files] + training_files = files[: -int(len(files)*0.05)] + 
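+ # hold out the last 5% of the shuffled files for validation; note int(len(files)*0.05) + # must be at least 1 (i.e. 20+ files), otherwise the training slice above is empty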
validation_files = files[-int(len(files)*0.05): ] + + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__(self, training_files, segment_size, n_fft, num_mels, + hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, + device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + # audio, sampling_rate = load_wav(filename) + # audio = audio / MAX_WAV_VALUE + audio = np.load(filename) + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + # if sampling_rate != self.sampling_rate: + # raise ValueError("{} SR doesn't match target {} SR".format( + # sampling_rate, self.sampling_rate)) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start:audio_start+self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') + + mel = mel_spectrogram(audio, self.n_fft, self.num_mels, + self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, + center=False) + else: + mel_path = os.path.join(self.base_mels_path, "mel" + "-" + filename.split("/")[-1].split("-")[-1]) + mel = np.load(mel_path).T + # mel = np.load( + # os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy')) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start:mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant') + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') + + mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels, + self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss, + center=False) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/vocoder/hifigan/models.py b/vocoder/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c352e19f0c8aab5b3c24e861f5b1c06c17c5e750 --- /dev/null +++ b/vocoder/hifigan/models.py @@ -0,0 +1,320 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, 
ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from vocoder.hifigan.utils import init_weights, get_padding + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + +class InterpolationBlock(torch.nn.Module): + def __init__(self, scale_factor, mode='nearest', align_corners=None, downsample=False): + super(InterpolationBlock, self).__init__() + self.downsample = downsample + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + outputs = torch.nn.functional.interpolate( + x, + size=x.shape[-1] * self.scale_factor \ + if not self.downsample else x.shape[-1] // self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=False + ) + return outputs + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() +# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): +# # self.ups.append(weight_norm( +# # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), +# # k, u, padding=(k-u)//2))) + if h.sampling_rate == 24000: + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + 
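# At 24 kHz each upsampling stage is nearest-neighbour interpolation followed
# by a stride-1 Conv1d (InterpolationBlock + Conv1d), whereas the default
# branch below uses a strided ConvTranspose1d; both halve the channel count
# at every stage.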
self.ups.append( + torch.nn.Sequential( + InterpolationBlock(u), + weight_norm(torch.nn.Conv1d( + h.upsample_initial_channel//(2**i), + h.upsample_initial_channel//(2**(i+1)), + k, padding=(k-1)//2, + )) + ) + ) + else: + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i), + h.upsample_initial_channel//(2**(i+1)), + k, u, padding=(u//2 + u%2), output_padding=u%2))) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + if self.h.sampling_rate == 24000: + remove_weight_norm(l[-1]) + else: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + 
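# The period-based fold that DiscriminatorP applies above can be shown in
# isolation: the waveform is reflect-padded to a multiple of the period and
# viewed as 2-D, so each column of the result holds samples that lie `period`
# steps apart. The sizes here are arbitrary and only illustrate the reshape.
import torch
import torch.nn.functional as F

period = 3
x = torch.randn(1, 1, 16000)                           # (batch, channels, samples)
n_pad = (period - x.shape[-1] % period) % period       # pad so T divides evenly
x = F.pad(x, (0, n_pad), "reflect")
x_2d = x.view(1, 1, x.shape[-1] // period, period)     # (B, C, T // period, period)
print(x_2d.shape)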
norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + diff --git a/vocoder/hifigan/train.py b/vocoder/hifigan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9c2f2cc69afec4762bf3b354f5a07982f70d38 --- /dev/null +++ b/vocoder/hifigan/train.py @@ -0,0 +1,253 @@ +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +import itertools +import os +import time +import argparse +import json +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DistributedSampler, DataLoader +import torch.multiprocessing as mp +from torch.distributed import init_process_group +from torch.nn.parallel import DistributedDataParallel +from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist +from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\ + discriminator_loss +from vocoder.hifigan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint + +torch.backends.cudnn.benchmark = True + + +def train(rank, a, h): + + a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan') + a.checkpoint_path.mkdir(exist_ok=True) + a.training_epochs = 3100 + a.stdout_interval = 5 + a.checkpoint_interval = 
a.backup_every + a.summary_interval = 5000 + a.validation_interval = 1000 + a.fine_tuning = True + + a.input_wavs_dir = a.syn_dir.joinpath("audio") + a.input_mels_dir = a.syn_dir.joinpath("mels") + + if h.num_gpus > 1: + init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], + world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) + + torch.cuda.manual_seed(h.seed) + device = torch.device('cuda:{:d}'.format(rank)) + + generator = Generator(h).to(device) + mpd = MultiPeriodDiscriminator().to(device) + msd = MultiScaleDiscriminator().to(device) + + if rank == 0: + print(generator) + os.makedirs(a.checkpoint_path, exist_ok=True) + print("checkpoints directory : ", a.checkpoint_path) + + if os.path.isdir(a.checkpoint_path): + cp_g = scan_checkpoint(a.checkpoint_path, 'g_hifigan_') + cp_do = scan_checkpoint(a.checkpoint_path, 'do_hifigan_') + + steps = 0 + if cp_g is None or cp_do is None: + state_dict_do = None + last_epoch = -1 + else: + state_dict_g = load_checkpoint(cp_g, device) + state_dict_do = load_checkpoint(cp_do, device) + generator.load_state_dict(state_dict_g['generator']) + mpd.load_state_dict(state_dict_do['mpd']) + msd.load_state_dict(state_dict_do['msd']) + steps = state_dict_do['steps'] + 1 + last_epoch = state_dict_do['epoch'] + + if h.num_gpus > 1: + generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) + mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) + msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) + + optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), + h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + + if state_dict_do is not None: + optim_g.load_state_dict(state_dict_do['optim_g']) + optim_d.load_state_dict(state_dict_do['optim_d']) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) + + training_filelist, validation_filelist = get_dataset_filelist(a) + + # print(training_filelist) + # exit() + + trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, + shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, + fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) + + train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None + + train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, + sampler=train_sampler, + batch_size=h.batch_size, + pin_memory=True, + drop_last=True) + + if rank == 0: + validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, + fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir) + validation_loader = DataLoader(validset, num_workers=1, shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True) + + sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) + + generator.train() + mpd.train() + msd.train() + for epoch in range(max(0, last_epoch), a.training_epochs): + if rank == 0: + start = time.time() + print("Epoch: {}".format(epoch+1)) + + if h.num_gpus > 1: + train_sampler.set_epoch(epoch) + + for i, 
batch in enumerate(train_loader): + if rank == 0: + start_b = time.time() + x, y, _, y_mel = batch + x = torch.autograd.Variable(x.to(device, non_blocking=True)) + y = torch.autograd.Variable(y.to(device, non_blocking=True)) + y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) + y = y.unsqueeze(1) + + y_g_hat = generator(x) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, + h.fmin, h.fmax_for_loss) + if steps > h.disc_start_step: + optim_d.zero_grad() + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) + + # MSD + y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) + + loss_disc_all = loss_disc_s + loss_disc_f + + loss_disc_all.backward() + optim_d.step() + + # Generator + optim_g.zero_grad() + + # L1 Mel-Spectrogram Loss + loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 + + if steps > h.disc_start_step: + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + else: + loss_gen_all = loss_mel + + loss_gen_all.backward() + optim_g.step() + + if rank == 0: + # STDOUT logging + if steps % a.stdout_interval == 0: + with torch.no_grad(): + mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() + + print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. 
+ format(steps, loss_gen_all, mel_error, time.time() - start_b)) + + # checkpointing + if steps % a.checkpoint_interval == 0 and steps != 0: + checkpoint_path = "{}/g_hifigan_{:08d}.pt".format(a.checkpoint_path, steps) + save_checkpoint(checkpoint_path, + {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) + checkpoint_path = "{}/do_hifigan_{:08d}.pt".format(a.checkpoint_path, steps) + save_checkpoint(checkpoint_path, + {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + 'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(), + 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, + 'epoch': epoch}) + + # Tensorboard summary logging + if steps % a.summary_interval == 0: + sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) + sw.add_scalar("training/mel_spec_error", mel_error, steps) + + + # save temperate hifigan model + if steps % a.save_every == 0: + checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path) + save_checkpoint(checkpoint_path, + {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) + checkpoint_path = "{}/do_hifigan.pt".format(a.checkpoint_path) + save_checkpoint(checkpoint_path, + {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + 'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(), + 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, + 'epoch': epoch}) + + # Validation + if steps % a.validation_interval == 0: # and steps != 0: + generator.eval() + torch.cuda.empty_cache() + val_err_tot = 0 + with torch.no_grad(): + for j, batch in enumerate(validation_loader): + x, y, _, y_mel = batch + y_g_hat = generator(x.to(device)) + y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, + h.hop_size, h.win_size, + h.fmin, h.fmax_for_loss) +# val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() + + if j <= 4: + if steps == 0: + sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) + sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) + + sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) + y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, + h.sampling_rate, h.hop_size, h.win_size, + h.fmin, h.fmax) + sw.add_figure('generated/y_hat_spec_{}'.format(j), + plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) + + val_err = val_err_tot / (j+1) + sw.add_scalar("validation/mel_spec_error", val_err, steps) + + generator.train() + + steps += 1 + + scheduler_g.step() + scheduler_d.step() + + if rank == 0: + print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) diff --git a/vocoder/hifigan/utils.py b/vocoder/hifigan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e67cbcda0744201d8342b212160808b7c934ea64 --- /dev/null +++ b/vocoder/hifigan/utils.py @@ -0,0 +1,58 @@ +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + 
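# Conv layers get their weights drawn from N(mean, std); the generator and its
# upsampling stack apply this through .apply(init_weights) in models.py.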
m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????.pt') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vocoder/saved_models/pretrained/g_hifigan.pt b/vocoder/saved_models/pretrained/g_hifigan.pt new file mode 100644 index 0000000000000000000000000000000000000000..9cda7efd2ca27024816da61f1ddc8001cd0ea3b2 --- /dev/null +++ b/vocoder/saved_models/pretrained/g_hifigan.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c5b29830f9b42c481c108cb0b89d56f380928d4d46e1d30d65c92340ddc694e +size 51985448 diff --git a/vocoder/saved_models/pretrained/pretrained.pt b/vocoder/saved_models/pretrained/pretrained.pt new file mode 100644 index 0000000000000000000000000000000000000000..101053e9ab3558bca8be1f33a7525d51da3d405a --- /dev/null +++ b/vocoder/saved_models/pretrained/pretrained.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169 +size 53845290 diff --git a/vocoder/vocoder_dataset.py b/vocoder/vocoder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3aedb09290cfc8200363a0cc277eba671720736f --- /dev/null +++ b/vocoder/vocoder_dataset.py @@ -0,0 +1,84 @@ +from torch.utils.data import Dataset +from pathlib import Path +from vocoder.wavernn import audio +import vocoder.wavernn.hparams as hp +import numpy as np +import torch + + +class VocoderDataset(Dataset): + def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path): + print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir)) + + with metadata_fpath.open("r") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + + gta_fnames = [x[1] for x in metadata if int(x[4])] + gta_fpaths = [mel_dir.joinpath(fname) for fname in gta_fnames] + wav_fnames = [x[0] for x in metadata if int(x[4])] + wav_fpaths = [wav_dir.joinpath(fname) for fname in wav_fnames] + self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths)) + + print("Found %d samples" % len(self.samples_fpaths)) + + def __getitem__(self, index): + mel_path, wav_path = self.samples_fpaths[index] + + # Load the mel spectrogram and adjust its range to [-1, 1] + mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value + + # Load the wav + wav = np.load(wav_path) + if hp.apply_preemphasis: + wav = audio.pre_emphasis(wav) + wav = np.clip(wav, -1, 1) + + # Fix for missing padding # TODO: settle on whether this is any useful + r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav) + wav = np.pad(wav, (0, r_pad), mode='constant') + assert len(wav) >= mel.shape[1] * hp.hop_length + wav = wav[:mel.shape[1] * hp.hop_length] + assert len(wav) % hp.hop_length == 0 + + # Quantize the wav + if hp.voc_mode == 'RAW': + if hp.mu_law: + quant = audio.encode_mu_law(wav, mu=2 ** hp.bits) + else: 
+ quant = audio.float_2_label(wav, bits=hp.bits) + elif hp.voc_mode == 'MOL': + quant = audio.float_2_label(wav, bits=16) + + return mel.astype(np.float32), quant.astype(np.int64) + + def __len__(self): + return len(self.samples_fpaths) + + +def collate_vocoder(batch): + mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad + max_offsets = [x[0].shape[-1] -2 - (mel_win + 2 * hp.voc_pad) for x in batch] + mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] + sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets] + + mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)] + + labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)] + + mels = np.stack(mels).astype(np.float32) + labels = np.stack(labels).astype(np.int64) + + mels = torch.tensor(mels) + labels = torch.tensor(labels).long() + + x = labels[:, :hp.voc_seq_len] + y = labels[:, 1:] + + bits = 16 if hp.voc_mode == 'MOL' else hp.bits + + x = audio.label_2_float(x.float(), bits) + + if hp.voc_mode == 'MOL' : + y = audio.label_2_float(y.float(), bits) + + return x, y, mels \ No newline at end of file diff --git a/vocoder/wavernn/audio.py b/vocoder/wavernn/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..bec976840ca22ff39644d7dd51f70094910a8458 --- /dev/null +++ b/vocoder/wavernn/audio.py @@ -0,0 +1,108 @@ +import math +import numpy as np +import librosa +import vocoder.wavernn.hparams as hp +from scipy.signal import lfilter +import soundfile as sf + + +def label_2_float(x, bits) : + return 2 * x / (2**bits - 1.) - 1. + + +def float_2_label(x, bits) : + assert abs(x).max() <= 1.0 + x = (x + 1.) * (2**bits - 1) / 2 + return x.clip(0, 2**bits - 1) + + +def load_wav(path) : + return librosa.load(str(path), sr=hp.sample_rate)[0] + + +def save_wav(x, path) : + sf.write(path, x.astype(np.float32), hp.sample_rate) + + +def split_signal(x) : + unsigned = x + 2**15 + coarse = unsigned // 256 + fine = unsigned % 256 + return coarse, fine + + +def combine_signal(coarse, fine) : + return coarse * 256 + fine - 2**15 + + +def encode_16bits(x) : + return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) + + +mel_basis = None + + +def linear_to_mel(spectrogram): + global mel_basis + if mel_basis is None: + mel_basis = build_mel_basis() + return np.dot(mel_basis, spectrogram) + + +def build_mel_basis(): + return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin) + + +def normalize(S): + return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1) + + +def denormalize(S): + return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db + + +def amp_to_db(x): + return 20 * np.log10(np.maximum(1e-5, x)) + + +def db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def spectrogram(y): + D = stft(y) + S = amp_to_db(np.abs(D)) - hp.ref_level_db + return normalize(S) + + +def melspectrogram(y): + D = stft(y) + S = amp_to_db(linear_to_mel(np.abs(D))) + return normalize(S) + + +def stft(y): + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length) + + +def pre_emphasis(x): + return lfilter([1, -hp.preemphasis], [1], x) + + +def de_emphasis(x): + return lfilter([1], [1, -hp.preemphasis], x) + + +def encode_mu_law(x, mu) : + mu = mu - 1 + fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) + return np.floor((fx + 1) / 2 * mu + 0.5) + + +def decode_mu_law(y, mu, from_labels=True) : + if from_labels: + y = label_2_float(y, math.log2(mu)) 
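# A round trip through the companding pair defined here should reproduce the
# input up to quantization error. This standalone check assumes the package
# imports in vocoder.wavernn.audio resolve (they pull in the synthesizer
# hparams); the signal is synthetic.
import numpy as np
from vocoder.wavernn import audio

x = np.sin(np.linspace(0, 2 * np.pi, 1000)) * 0.8      # synthetic signal in [-1, 1]
labels = audio.encode_mu_law(x, mu=2 ** 9)             # 9-bit labels (hp.bits = 9)
x_hat = audio.decode_mu_law(labels, mu=2 ** 9, from_labels=True)
print(float(np.max(np.abs(x - x_hat))))                # small reconstruction error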
+ mu = mu - 1 + x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1) + return x + diff --git a/vocoder/wavernn/gen_wavernn.py b/vocoder/wavernn/gen_wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..abda3eb3f9b47ed74398d35a16ec0ffa8a5ff3e6 --- /dev/null +++ b/vocoder/wavernn/gen_wavernn.py @@ -0,0 +1,31 @@ +from vocoder.wavernn.models.fatchord_version import WaveRNN +from vocoder.wavernn.audio import * + + +def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path): + k = model.get_step() // 1000 + + for i, (m, x) in enumerate(test_set, 1): + if i > samples: + break + + print('\n| Generating: %i/%i' % (i, samples)) + + x = x[0].numpy() + + bits = 16 if hp.voc_mode == 'MOL' else hp.bits + + if hp.mu_law and hp.voc_mode != 'MOL' : + x = decode_mu_law(x, 2**bits, from_labels=True) + else : + x = label_2_float(x, bits) + + save_wav(x, save_path.joinpath("%dk_steps_%d_target.wav" % (k, i))) + + batch_str = "gen_batched_target%d_overlap%d" % (target, overlap) if batched else \ + "gen_not_batched" + save_str = save_path.joinpath("%dk_steps_%d_%s.wav" % (k, i, batch_str)) + + wav = model.generate(m, batched, target, overlap, hp.mu_law) + save_wav(wav, save_str) + diff --git a/vocoder/wavernn/hparams.py b/vocoder/wavernn/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..c1de9f7dcc2926735b80a28ed1226ff1b5824753 --- /dev/null +++ b/vocoder/wavernn/hparams.py @@ -0,0 +1,44 @@ +from synthesizer.hparams import hparams as _syn_hp + + +# Audio settings------------------------------------------------------------------------ +# Match the values of the synthesizer +sample_rate = _syn_hp.sample_rate +n_fft = _syn_hp.n_fft +num_mels = _syn_hp.num_mels +hop_length = _syn_hp.hop_size +win_length = _syn_hp.win_size +fmin = _syn_hp.fmin +min_level_db = _syn_hp.min_level_db +ref_level_db = _syn_hp.ref_level_db +mel_max_abs_value = _syn_hp.max_abs_value +preemphasis = _syn_hp.preemphasis +apply_preemphasis = _syn_hp.preemphasize + +bits = 9 # bit depth of signal +mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode + # below + + +# WAVERNN / VOCODER -------------------------------------------------------------------------------- +voc_mode = 'RAW' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from +# mixture of logistics) +voc_upsample_factors = (5, 5, 8) # NB - this needs to correctly factorise hop_length +voc_rnn_dims = 512 +voc_fc_dims = 512 +voc_compute_dims = 128 +voc_res_out_dims = 128 +voc_res_blocks = 10 + +# Training +voc_batch_size = 100 +voc_lr = 1e-4 +voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint +voc_pad = 2 # this will pad the input so that the resnet can 'see' wider + # than input length +voc_seq_len = hop_length * 5 # must be a multiple of hop_length + +# Generating / Synthesizing +voc_gen_batched = True # very fast (realtime+) single utterance batched generation +voc_target = 8000 # target number of samples to be generated in each batch entry +voc_overlap = 400 # number of samples for crossfading between batches diff --git a/vocoder/wavernn/inference.py b/vocoder/wavernn/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..40cd3054b54b1111a213740b35bc8c50c76930cf --- /dev/null +++ b/vocoder/wavernn/inference.py @@ -0,0 +1,64 @@ +from vocoder.wavernn.models.fatchord_version import WaveRNN +from vocoder.wavernn import hparams as hp +import torch + + +_model = None # type: WaveRNN + +def load_model(weights_fpath, verbose=True): 
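# Builds a WaveRNN from the hparams above, moves it to the GPU when one is
# available, restores the 'model_state' weights from the given checkpoint and
# leaves the module-level _model/_device globals set for infer_waveform() below.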
+ global _model, _device + + if verbose: + print("Building Wave-RNN") + _model = WaveRNN( + rnn_dims=hp.voc_rnn_dims, + fc_dims=hp.voc_fc_dims, + bits=hp.bits, + pad=hp.voc_pad, + upsample_factors=hp.voc_upsample_factors, + feat_dims=hp.num_mels, + compute_dims=hp.voc_compute_dims, + res_out_dims=hp.voc_res_out_dims, + res_blocks=hp.voc_res_blocks, + hop_length=hp.hop_length, + sample_rate=hp.sample_rate, + mode=hp.voc_mode + ) + + if torch.cuda.is_available(): + _model = _model.cuda() + _device = torch.device('cuda') + else: + _device = torch.device('cpu') + + if verbose: + print("Loading model weights at %s" % weights_fpath) + checkpoint = torch.load(weights_fpath, _device) + _model.load_state_dict(checkpoint['model_state']) + _model.eval() + + +def is_loaded(): + return _model is not None + + +def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800, + progress_callback=None): + """ + Infers the waveform of a mel spectrogram output by the synthesizer (the format must match + that of the synthesizer!) + + :param normalize: + :param batched: + :param target: + :param overlap: + :return: + """ + if _model is None: + raise Exception("Please load Wave-RNN in memory before using it") + + if normalize: + mel = mel / hp.mel_max_abs_value + mel = torch.from_numpy(mel[None, ...]) + wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback) + return wav, hp.sample_rate diff --git a/vocoder/wavernn/models/deepmind_version.py b/vocoder/wavernn/models/deepmind_version.py new file mode 100644 index 0000000000000000000000000000000000000000..17b33b271ec40cfc78db9e96bd54f44dd90ec844 --- /dev/null +++ b/vocoder/wavernn/models/deepmind_version.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.display import * +from utils.dsp import * + + +class WaveRNN(nn.Module) : + def __init__(self, hidden_size=896, quantisation=256) : + super(WaveRNN, self).__init__() + + self.hidden_size = hidden_size + self.split_size = hidden_size // 2 + + # The main matmul + self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + + # Output fc layers + self.O1 = nn.Linear(self.split_size, self.split_size) + self.O2 = nn.Linear(self.split_size, quantisation) + self.O3 = nn.Linear(self.split_size, self.split_size) + self.O4 = nn.Linear(self.split_size, quantisation) + + # Input fc layers + self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False) + self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False) + + # biases for the gates + self.bias_u = nn.Parameter(torch.zeros(self.hidden_size)) + self.bias_r = nn.Parameter(torch.zeros(self.hidden_size)) + self.bias_e = nn.Parameter(torch.zeros(self.hidden_size)) + + # display num params + self.num_params() + + + def forward(self, prev_y, prev_hidden, current_coarse) : + + # Main matmul - the projection is split 3 ways + R_hidden = self.R(prev_hidden) + R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1) + + # Project the prev input + coarse_input_proj = self.I_coarse(prev_y) + I_coarse_u, I_coarse_r, I_coarse_e = \ + torch.split(coarse_input_proj, self.split_size, dim=1) + + # Project the prev input and current coarse sample + fine_input = torch.cat([prev_y, current_coarse], dim=1) + fine_input_proj = self.I_fine(fine_input) + I_fine_u, I_fine_r, I_fine_e = \ + torch.split(fine_input_proj, self.split_size, dim=1) + + # concatenate for the gates + I_u = torch.cat([I_coarse_u, I_fine_u], dim=1) + I_r = torch.cat([I_coarse_r, I_fine_r], dim=1) + I_e = 
torch.cat([I_coarse_e, I_fine_e], dim=1) + + # Compute all gates for coarse and fine + u = F.sigmoid(R_u + I_u + self.bias_u) + r = F.sigmoid(R_r + I_r + self.bias_r) + e = torch.tanh(r * R_e + I_e + self.bias_e) + hidden = u * prev_hidden + (1. - u) * e + + # Split the hidden state + hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1) + + # Compute outputs + out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) + out_fine = self.O4(F.relu(self.O3(hidden_fine))) + + return out_coarse, out_fine, hidden + + + def generate(self, seq_len): + with torch.no_grad(): + # First split up the biases for the gates + b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size) + b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size) + b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size) + + # Lists for the two output seqs + c_outputs, f_outputs = [], [] + + # Some initial inputs + out_coarse = torch.LongTensor([0]).cuda() + out_fine = torch.LongTensor([0]).cuda() + + # We'll meed a hidden state + hidden = self.init_hidden() + + # Need a clock for display + start = time.time() + + # Loop for generation + for i in range(seq_len) : + + # Split into two hidden states + hidden_coarse, hidden_fine = \ + torch.split(hidden, self.split_size, dim=1) + + # Scale and concat previous predictions + out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1. + out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1. + prev_outputs = torch.cat([out_coarse, out_fine], dim=1) + + # Project input + coarse_input_proj = self.I_coarse(prev_outputs) + I_coarse_u, I_coarse_r, I_coarse_e = \ + torch.split(coarse_input_proj, self.split_size, dim=1) + + # Project hidden state and split 6 ways + R_hidden = self.R(hidden) + R_coarse_u , R_fine_u, \ + R_coarse_r, R_fine_r, \ + R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1) + + # Compute the coarse gates + u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u) + r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r) + e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e) + hidden_coarse = u * hidden_coarse + (1. - u) * e + + # Compute the coarse output + out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) + posterior = F.softmax(out_coarse, dim=1) + distrib = torch.distributions.Categorical(posterior) + out_coarse = distrib.sample() + c_outputs.append(out_coarse) + + # Project the [prev outputs and predicted coarse sample] + coarse_pred = out_coarse.float() / 127.5 - 1. + fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1) + fine_input_proj = self.I_fine(fine_input) + I_fine_u, I_fine_r, I_fine_e = \ + torch.split(fine_input_proj, self.split_size, dim=1) + + # Compute the fine gates + u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u) + r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r) + e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e) + hidden_fine = u * hidden_fine + (1. 
- u) * e + + # Compute the fine output + out_fine = self.O4(F.relu(self.O3(hidden_fine))) + posterior = F.softmax(out_fine, dim=1) + distrib = torch.distributions.Categorical(posterior) + out_fine = distrib.sample() + f_outputs.append(out_fine) + + # Put the hidden state back together + hidden = torch.cat([hidden_coarse, hidden_fine], dim=1) + + # Display progress + speed = (i + 1) / (time.time() - start) + stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed)) + + coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy() + fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy() + output = combine_signal(coarse, fine) + + return output, coarse, fine + + def init_hidden(self, batch_size=1) : + return torch.zeros(batch_size, self.hidden_size).cuda() + + def num_params(self) : + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + print('Trainable Parameters: %.3f million' % parameters) \ No newline at end of file diff --git a/vocoder/wavernn/models/fatchord_version.py b/vocoder/wavernn/models/fatchord_version.py new file mode 100644 index 0000000000000000000000000000000000000000..6413a921651971b4859ed7de8b3a676cd6595d6b --- /dev/null +++ b/vocoder/wavernn/models/fatchord_version.py @@ -0,0 +1,434 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from vocoder.distribution import sample_from_discretized_mix_logistic +from vocoder.display import * +from vocoder.wavernn.audio import * + + +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Module): + def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + k_size = pad * 2 + 1 + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for i in range(res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__(self, feat_dims, upsample_scales, compute_dims, + res_blocks, res_out_dims, pad): + super().__init__() + total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * total_scale + self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + 
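# Filling every tap with 1/k_size[1] makes this a box (moving-average) filter,
# so before training each upsampling stage acts as a smoothed
# nearest-neighbour interpolation of the stretched mel frames.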
conv.weight.data.fill_(1. / k_size[1]) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + m = m.unsqueeze(1) + for f in self.up_layers: m = f(m) + m = m.squeeze(1)[:, :, self.indent:-self.indent] + return m.transpose(1, 2), aux.transpose(1, 2) + + +class WaveRNN(nn.Module): + def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors, + feat_dims, compute_dims, res_out_dims, res_blocks, + hop_length, sample_rate, mode='RAW'): + super().__init__() + self.mode = mode + self.pad = pad + if self.mode == 'RAW' : + self.n_classes = 2 ** bits + elif self.mode == 'MOL' : + self.n_classes = 30 + else : + RuntimeError("Unknown model mode value - ", self.mode) + + self.rnn_dims = rnn_dims + self.aux_dims = res_out_dims // 4 + self.hop_length = hop_length + self.sample_rate = sample_rate + + self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad) + self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) + self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) + self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) + self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) + self.fc3 = nn.Linear(fc_dims, self.n_classes) + + self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False) + self.num_params() + + def forward(self, x, mels): + self.step += 1 + bsize = x.size(0) + if torch.cuda.is_available(): + h1 = torch.zeros(1, bsize, self.rnn_dims).cuda() + h2 = torch.zeros(1, bsize, self.rnn_dims).cuda() + else: + h1 = torch.zeros(1, bsize, self.rnn_dims).cpu() + h2 = torch.zeros(1, bsize, self.rnn_dims).cpu() + mels, aux = self.upsample(mels) + + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0]:aux_idx[1]] + a2 = aux[:, :, aux_idx[1]:aux_idx[2]] + a3 = aux[:, :, aux_idx[2]:aux_idx[3]] + a4 = aux[:, :, aux_idx[3]:aux_idx[4]] + + x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + x = self.I(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def generate(self, mels, batched, target, overlap, mu_law, progress_callback=None): + mu_law = mu_law if self.mode == 'RAW' else False + progress_callback = progress_callback or self.gen_display + + self.eval() + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + if torch.cuda.is_available(): + mels = mels.cuda() + else: + mels = mels.cpu() + wave_len = (mels.size(-1) - 1) * self.hop_length + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both') + mels, aux = self.upsample(mels.transpose(1, 2)) + + if batched: + mels = self.fold_with_overlap(mels, target, overlap) + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + if torch.cuda.is_available(): + h1 = torch.zeros(b_size, self.rnn_dims).cuda() + h2 = torch.zeros(b_size, self.rnn_dims).cuda() + x = torch.zeros(b_size, 1).cuda() + else: + h1 = torch.zeros(b_size, self.rnn_dims).cpu() + h2 = torch.zeros(b_size, self.rnn_dims).cpu() + x = torch.zeros(b_size, 1).cpu() + + d = self.aux_dims + aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)] + 
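# Sample-by-sample autoregressive loop: each step conditions on the previous
# sample x, the current upsampled mel frame and the four auxiliary slices,
# runs both GRU cells, then draws the next sample from a categorical over
# 2**bits classes ('RAW') or a discretized mixture of logistics ('MOL').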
+ for i in range(seq_len): + + m_t = mels[:, i, :] + + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = torch.cat([x, m_t, a1_t], dim=1) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.mode == 'MOL': + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + if torch.cuda.is_available(): + # x = torch.FloatTensor([[sample]]).cuda() + x = sample.transpose(0, 1).cuda() + else: + x = sample.transpose(0, 1) + + elif self.mode == 'RAW' : + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError("Unknown model mode value - ", self.mode) + + if i % 100 == 0: + gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 + progress_callback(i, seq_len, b_size, gen_rate) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu().numpy() + output = output.astype(np.float64) + + if batched: + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + if mu_law: + output = decode_mu_law(output, self.n_classes, False) + if hp.apply_preemphasis: + output = de_emphasis(output) + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.hop_length) + output = output[:wave_len] + output[-20 * self.hop_length:] *= fade_out + + self.train() + + return output + + + def gen_display(self, i, seq_len, b_size, gen_rate): + pbar = progbar(i, seq_len) + msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | ' + stream(msg) + + def get_gru_cell(self, gru): + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + def pad_tensor(self, x, pad, side='both'): + # NB - this is just a quick method i need right now + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == 'both' else t + pad + if torch.cuda.is_available(): + padded = torch.zeros(b, total, c).cuda() + else: + padded = torch.zeros(b, total, c).cpu() + if side == 'before' or side == 'both': + padded[:, pad:pad + t, :] = x + elif side == 'after': + padded[:, :t, :] = x + return padded + + def fold_with_overlap(self, x, target, overlap): + + ''' Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + + Args: + x (tensor) : Upsampled conditioning features. + shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + + Details: + x = [[h1, h2, ... 
hn]] + + Where each h is a vector of conditioning features + + Eg: target=2, overlap=1 with x.size(1)=10 + + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + ''' + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side='after') + + if torch.cuda.is_available(): + folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda() + else: + folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu() + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + def xfade_and_unfold(self, y, target, overlap): + + ''' Applies a crossfade and unfolds into a 1d array. + + Args: + y (ndarry) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + + Return: + (ndarry) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + + Details: + y = [[seq1], + [seq2], + [seq3]] + + Apply a gain envelope at both ends of the sequences + + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + + Stagger and add up the groups of samples: + + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] + + ''' + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([fade_out, silence]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded + + def get_step(self) : + return self.step.data.item() + + def checkpoint(self, model_dir, optimizer) : + k_steps = self.get_step() // 1000 + self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer) + + def log(self, path, msg) : + with open(path, 'a') as f: + print(msg, file=f) + + def load(self, path, optimizer) : + checkpoint = torch.load(path) + if "optimizer_state" in checkpoint: + self.load_state_dict(checkpoint["model_state"]) + optimizer.load_state_dict(checkpoint["optimizer_state"]) + else: + # Backwards compatibility + self.load_state_dict(checkpoint) + + def save(self, path, optimizer) : + torch.save({ + "model_state": self.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, path) + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out : + print('Trainable Parameters: 
%.3fM' % parameters) diff --git a/vocoder/wavernn/train.py b/vocoder/wavernn/train.py new file mode 100644 index 0000000000000000000000000000000000000000..44e0929ac67d778b5cc78b669b42fb89e17acf9e --- /dev/null +++ b/vocoder/wavernn/train.py @@ -0,0 +1,127 @@ +from vocoder.wavernn.models.fatchord_version import WaveRNN +from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder +from vocoder.distribution import discretized_mix_logistic_loss +from vocoder.display import stream, simple_table +from vocoder.wavernn.gen_wavernn import gen_testset +from torch.utils.data import DataLoader +from pathlib import Path +from torch import optim +import torch.nn.functional as F +import vocoder.wavernn.hparams as hp +import numpy as np +import time +import torch + + +def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool, + save_every: int, backup_every: int, force_restart: bool): + # Check to make sure the hop length is correctly factorised + assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length + + # Instantiate the model + print("Initializing the model...") + model = WaveRNN( + rnn_dims=hp.voc_rnn_dims, + fc_dims=hp.voc_fc_dims, + bits=hp.bits, + pad=hp.voc_pad, + upsample_factors=hp.voc_upsample_factors, + feat_dims=hp.num_mels, + compute_dims=hp.voc_compute_dims, + res_out_dims=hp.voc_res_out_dims, + res_blocks=hp.voc_res_blocks, + hop_length=hp.hop_length, + sample_rate=hp.sample_rate, + mode=hp.voc_mode + ) + + if torch.cuda.is_available(): + model = model.cuda() + device = torch.device('cuda') + else: + device = torch.device('cpu') + + # Initialize the optimizer + optimizer = optim.Adam(model.parameters()) + for p in optimizer.param_groups: + p["lr"] = hp.voc_lr + loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss + + # Load the weights + model_dir = models_dir.joinpath(run_id) + model_dir.mkdir(exist_ok=True) + weights_fpath = model_dir.joinpath(run_id + ".pt") + if force_restart or not weights_fpath.exists(): + print("\nStarting the training of WaveRNN from scratch\n") + model.save(weights_fpath, optimizer) + else: + print("\nLoading weights at %s" % weights_fpath) + model.load(weights_fpath, optimizer) + print("WaveRNN weights loaded from step %d" % model.step) + + # Initialize the dataset + metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \ + voc_dir.joinpath("synthesized.txt") + mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta") + wav_dir = syn_dir.joinpath("audio") + dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir) + test_loader = DataLoader(dataset, + batch_size=1, + shuffle=True, + pin_memory=True) + + # Begin the training + simple_table([('Batch size', hp.voc_batch_size), + ('LR', hp.voc_lr), + ('Sequence Len', hp.voc_seq_len)]) + + for epoch in range(1, 350): + data_loader = DataLoader(dataset, + collate_fn=collate_vocoder, + batch_size=hp.voc_batch_size, + num_workers=2, + shuffle=True, + pin_memory=True) + start = time.time() + running_loss = 0. 
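# Inner loop over batches: the network predicts the next-sample distribution,
# the loss is cross-entropy in 'RAW' mode or the discretized
# mixture-of-logistics NLL in 'MOL' mode, and the model is backed up /
# saved every backup_every / save_every steps.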
+ + for i, (x, y, m) in enumerate(data_loader, 1): + if torch.cuda.is_available(): + x, m, y = x.cuda(), m.cuda(), y.cuda() + + # Forward pass + y_hat = model(x, m) + if model.mode == 'RAW': + y_hat = y_hat.transpose(1, 2).unsqueeze(-1) + elif model.mode == 'MOL': + y = y.float() + y = y.unsqueeze(-1) + + # Backward pass + loss = loss_func(y_hat, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + speed = i / (time.time() - start) + avg_loss = running_loss / i + + step = model.get_step() + k = step // 1000 + + if backup_every != 0 and step % backup_every == 0 : + model.checkpoint(model_dir, optimizer) + + if save_every != 0 and step % save_every == 0 : + model.save(weights_fpath, optimizer) + + msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \ + f"Loss: {avg_loss:.4f} | {speed:.1f} " \ + f"steps/s | Step: {k}k | " + stream(msg) + + + gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched, + hp.voc_target, hp.voc_overlap, model_dir) + print("") diff --git a/vocoder_preprocess.py b/vocoder_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..95f9e5a0f80edc566cfb31bb77736a1c58573d47 --- /dev/null +++ b/vocoder_preprocess.py @@ -0,0 +1,59 @@ +from synthesizer.synthesize import run_synthesis +from synthesizer.hparams import hparams +from utils.argutils import print_args +import argparse +import os + + +if __name__ == "__main__": + class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): + pass + + parser = argparse.ArgumentParser( + description="Creates ground-truth aligned (GTA) spectrograms from the vocoder.", + formatter_class=MyFormatter + ) + parser.add_argument("datasets_root", type=str, help=\ + "Path to the directory containing your SV2TTS directory. If you specify both --in_dir and " + "--out_dir, this argument won't be used.") + parser.add_argument("-m", "--model_dir", type=str, + default="synthesizer/saved_models/mandarin/", help=\ + "Path to the pretrained model directory.") + parser.add_argument("-i", "--in_dir", type=str, default=argparse.SUPPRESS, help= \ + "Path to the synthesizer directory that contains the mel spectrograms, the wavs and the " + "embeds. Defaults to /SV2TTS/synthesizer/.") + parser.add_argument("-o", "--out_dir", type=str, default=argparse.SUPPRESS, help= \ + "Path to the output vocoder directory that will contain the ground truth aligned mel " + "spectrograms. Defaults to /SV2TTS/vocoder/.") + parser.add_argument("--hparams", default="", + help="Hyperparameter overrides as a comma-separated list of name=value " + "pairs") + parser.add_argument("--no_trim", action="store_true", help=\ + "Preprocess audio without trimming silences (not recommended).") + parser.add_argument("--cpu", action="store_true", help=\ + "If True, processing is done on CPU, even when a GPU is available.") + args = parser.parse_args() + print_args(args, parser) + modified_hp = hparams.parse(args.hparams) + + if not hasattr(args, "in_dir"): + args.in_dir = os.path.join(args.datasets_root, "SV2TTS", "synthesizer") + if not hasattr(args, "out_dir"): + args.out_dir = os.path.join(args.datasets_root, "SV2TTS", "vocoder") + + if args.cpu: + # Hide GPUs from Pytorch to force CPU processing + os.environ["CUDA_VISIBLE_DEVICES"] = "" + + # Verify webrtcvad is available + if not args.no_trim: + try: + import webrtcvad + except: + raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables " + "noise removal and is recommended. 
Please install and try again. If installation fails, " + "use --no_trim to disable this error message.") + del args.no_trim + + run_synthesis(args.in_dir, args.out_dir, args.model_dir, modified_hp) + diff --git a/vocoder_train.py b/vocoder_train.py new file mode 100644 index 0000000000000000000000000000000000000000..f618ee00d8f774ecf821b9714932acc7e99aa5d5 --- /dev/null +++ b/vocoder_train.py @@ -0,0 +1,92 @@ +from utils.argutils import print_args +from vocoder.wavernn.train import train +from vocoder.hifigan.train import train as train_hifigan +from vocoder.fregan.train import train as train_fregan +from utils.util import AttrDict +from pathlib import Path +import argparse +import json +import torch +import torch.multiprocessing as mp + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Trains the vocoder from the synthesizer audios and the GTA synthesized mels, " + "or ground truth mels.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("run_id", type=str, help= \ + "Name for this model instance. If a model state from the same run ID was previously " + "saved, the training will restart from there. Pass -f to overwrite saved states and " + "restart from scratch.") + parser.add_argument("datasets_root", type=str, help= \ + "Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir " + "will take priority over this argument.") + parser.add_argument("vocoder_type", type=str, default="wavernn", help= \ + "Vocoder type to train. Defaults to wavernn. " + "Supported values: wavernn, hifigan, fregan.") + parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \ + "Path to the synthesizer directory that contains the ground truth mel spectrograms, " + "the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.") + parser.add_argument("--voc_dir", type=str, default=argparse.SUPPRESS, help= \ + "Path to the vocoder directory that contains the GTA synthesized mel spectrograms. " + "Defaults to <datasets_root>/SV2TTS/vocoder/. Unused if --ground_truth is passed.") + parser.add_argument("-m", "--models_dir", type=str, default="vocoder/saved_models/", help=\ + "Path to the directory that will contain the saved model weights, as well as backups " + "of those weights and wavs generated during training.") + parser.add_argument("-g", "--ground_truth", action="store_true", help= \ + "Train on ground truth spectrograms (<datasets_root>/SV2TTS/synthesizer/mels).") + parser.add_argument("-s", "--save_every", type=int, default=1000, help= \ + "Number of steps between updates of the model on the disk. Set to 0 to never save the " + "model.") + parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \ + "Number of steps between backups of the model.
Set to 0 to never make backups of the " + "model.") + parser.add_argument("-f", "--force_restart", action="store_true", help= \ + "Do not load any saved model and restart from scratch.") + parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json") + args = parser.parse_args() + + if not hasattr(args, "syn_dir"): + args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer") + args.syn_dir = Path(args.syn_dir) + if not hasattr(args, "voc_dir"): + args.voc_dir = Path(args.datasets_root, "SV2TTS", "vocoder") + args.voc_dir = Path(args.voc_dir) + del args.datasets_root + args.models_dir = Path(args.models_dir) + args.models_dir.mkdir(exist_ok=True) + + print_args(args, parser) + + # Process the arguments + if args.vocoder_type == "wavernn": + # Run the training wavernn + delattr(args, 'vocoder_type') + delattr(args, 'config') + train(**vars(args)) + elif args.vocoder_type == "hifigan": + with open(args.config) as f: + json_config = json.load(f) + h = AttrDict(json_config) + if h.num_gpus > 1: + h.num_gpus = torch.cuda.device_count() + h.batch_size = int(h.batch_size / h.num_gpus) + print('Batch size per GPU :', h.batch_size) + mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,)) + else: + train_hifigan(0, args, h) + elif args.vocoder_type == "fregan": + with open('vocoder/fregan/config.json') as f: + json_config = json.load(f) + h = AttrDict(json_config) + if h.num_gpus > 1: + h.num_gpus = torch.cuda.device_count() + h.batch_size = int(h.batch_size / h.num_gpus) + print('Batch size per GPU :', h.batch_size) + mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,)) + else: + train_fregan(0, args, h) + + \ No newline at end of file diff --git a/web.py b/web.py new file mode 100644 index 0000000000000000000000000000000000000000..d232530ec912f9c985cdd5b67a49f0fc53b4d947 --- /dev/null +++ b/web.py @@ -0,0 +1,21 @@ +import os +import sys +import typer + +cli = typer.Typer() + +@cli.command() +def launch_ui(port: int = typer.Option(8080, "--port", "-p")) -> None: + """Start a graphical UI server for the opyrator. + + The UI is auto-generated from the input- and output-schema of the given function. 
+ """ + # Add the current working directory to the sys path + # This is required to resolve the opyrator path + sys.path.append(os.getcwd()) + + from mkgui.base.ui.streamlit_ui import launch_ui + launch_ui(port) + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/web/DOCKERFILE b/web/DOCKERFILE new file mode 100644 index 0000000000000000000000000000000000000000..64e8c532db299b2f634fd0478c7f89baa817d999 --- /dev/null +++ b/web/DOCKERFILE @@ -0,0 +1,10 @@ + +FROM python:3.7 + +RUN pip install gevent uwsgi flask + +COPY app.py /app.py + +EXPOSE 3000 + +ENTRYPOINT ["uwsgi", "--http", ":3000", "--master", "--module", "app:app"] \ No newline at end of file diff --git a/web/__init__.py b/web/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b71aa5ded086d398c53ce60604c368a1f703c4f --- /dev/null +++ b/web/__init__.py @@ -0,0 +1,135 @@ +from web.api import api_blueprint +from pathlib import Path +from gevent import pywsgi as wsgi +from flask import Flask, Response, request, render_template +from synthesizer.inference import Synthesizer +from encoder import inference as encoder +from vocoder.hifigan import inference as gan_vocoder +from vocoder.wavernn import inference as rnn_vocoder +import numpy as np +import re +from scipy.io.wavfile import write +import librosa +import io +import base64 +from flask_cors import CORS +from flask_wtf import CSRFProtect +import webbrowser + +def webApp(): + # Init and load config + app = Flask(__name__, instance_relative_config=True) + app.config.from_object("web.config.default") + app.config['RESTPLUS_MASK_SWAGGER'] = False + app.register_blueprint(api_blueprint) + + # CORS(app) #允许跨域,注释掉此行则禁止跨域请求 + csrf = CSRFProtect(app) + csrf.init_app(app) + + syn_models_dirt = "synthesizer/saved_models" + synthesizers = list(Path(syn_models_dirt).glob("**/*.pt")) + synthesizers_cache = {} + encoder.load_model(Path("encoder/saved_models/pretrained.pt")) + rnn_vocoder.load_model(Path("vocoder/saved_models/pretrained/pretrained.pt")) + gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt")) + + def pcm2float(sig, dtype='float32'): + """Convert PCM signal to floating point with a range from -1 to 1. + Use dtype='float32' for single precision. + Parameters + ---------- + sig : array_like + Input array, must have integral type. + dtype : data type, optional + Desired (floating point) data type. + Returns + ------- + numpy.ndarray + Normalized floating point data. 
+ See Also + -------- + float2pcm, dtype + """ + sig = np.asarray(sig) + if sig.dtype.kind not in 'iu': + raise TypeError("'sig' must be an array of integers") + dtype = np.dtype(dtype) + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(sig.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + return (sig.astype(dtype) - offset) / abs_max + + # Cache for synthesizer + @csrf.exempt + @app.route("/api/synthesize", methods=["POST"]) + def synthesize(): + # TODO: accept a JSON payload to support more platforms + # Load the synthesizer + if "synt_path" in request.form: + synt_path = request.form["synt_path"] + else: + synt_path = synthesizers[0] + print("No synthesizer specified; falling back to the first one found.") + if synthesizers_cache.get(synt_path) is None: + current_synt = Synthesizer(Path(synt_path)) + synthesizers_cache[synt_path] = current_synt + else: + current_synt = synthesizers_cache[synt_path] + print("Using synthesizer model: " + str(synt_path)) + # Load the input wav + if "upfile_b64" in request.form: + wav_base64 = request.form["upfile_b64"] + wav = base64.b64decode(bytes(wav_base64, 'utf-8')) + wav = pcm2float(np.frombuffer(wav, dtype=np.int16), dtype=np.float32) + sample_rate = Synthesizer.sample_rate + else: + wav, sample_rate = librosa.load(request.files['file']) + write("temp.wav", sample_rate, wav)  # Dump to disk to check that the decoded wav is correct + + encoder_wav = encoder.preprocess_wav(wav, sample_rate) + embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + # Load the input text + texts = filter(None, request.form["text"].split("\n")) + punctuation = '!,。、,'  # split the text into short sentences on these punctuation marks + processed_texts = [] + for text in texts: + for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'): + if processed_text: + processed_texts.append(processed_text.strip()) + texts = processed_texts + + # Synthesize and vocode + embeds = [embed] * len(texts) + specs = current_synt.synthesize_spectrograms(texts, embeds) + spec = np.concatenate(specs, axis=1) + sample_rate = Synthesizer.sample_rate + if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN": + wav, sample_rate = rnn_vocoder.infer_waveform(spec) + else: + wav, sample_rate = gan_vocoder.infer_waveform(spec) + + # Return the rendered wav + out = io.BytesIO() + write(out, sample_rate, wav.astype(np.float32)) + return Response(out, mimetype="audio/wav") + + @app.route('/', methods=['GET']) + def index(): + return render_template("index.html") + + host = app.config.get("HOST") + port = app.config.get("PORT") + web_address = 'http://{}:{}'.format(host, port) + print("Web server: " + web_address) + webbrowser.open(web_address) + server = wsgi.WSGIServer((host, port), app) + server.serve_forever() + return app + +if __name__ == "__main__": + webApp() diff --git a/web/api/__init__.py b/web/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c8726d6b4456830e947b7165cf77ff1879361f --- /dev/null +++ b/web/api/__init__.py @@ -0,0 +1,16 @@ +from flask import Blueprint +from flask_restx import Api +from .audio import api as audio +from .synthesizer import api as synthesizer + +api_blueprint = Blueprint('api', __name__, url_prefix='/api') + +api = Api( + app=api_blueprint, + title='Mocking Bird', + version='1.0', + description='Web API for the MockingBird voice cloning toolbox' +) + +api.add_namespace(audio) +api.add_namespace(synthesizer) \ No newline at end of file diff --git a/web/api/audio.py b/web/api/audio.py new file mode 100644 index
0000000000000000000000000000000000000000..b30e5dd9ad3a249c2a6e73d9f42372f0ed098b5a --- /dev/null +++ b/web/api/audio.py @@ -0,0 +1,43 @@ +import os +from pathlib import Path +from flask_restx import Namespace, Resource, fields +from flask import Response, current_app + +api = Namespace('audios', description='Audio-related operations') + +audio = api.model('Audio', { + 'name': fields.String(required=True, description='The audio name'), +}) + +def generate(wav_path): + # Stream the wav file in 1 KB chunks + with open(wav_path, "rb") as fwav: + data = fwav.read(1024) + while data: + yield data + data = fwav.read(1024) + +@api.route('/') +class AudioList(Resource): + @api.doc('list_audios') + @api.marshal_list_with(audio) + def get(self): + '''List all audio samples''' + audio_samples = [] + AUDIO_SAMPLES_DIR = current_app.config.get("AUDIO_SAMPLES_DIR") + if os.path.isdir(AUDIO_SAMPLES_DIR): + audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav")) + return list(a.name for a in audio_samples) + +@api.route('/<name>') +@api.param('name', 'The name of the audio file') +@api.response(404, 'audio not found') +class Audio(Resource): + @api.doc('get_audio') + @api.marshal_with(audio) + def get(self, name): + '''Fetch an audio file given its name''' + AUDIO_SAMPLES_DIR = current_app.config.get("AUDIO_SAMPLES_DIR") + if Path(AUDIO_SAMPLES_DIR + name).exists(): + return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav") + api.abort(404) + \ No newline at end of file diff --git a/web/api/synthesizer.py b/web/api/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23963b3593c444f625214e0778d3a23f14e34e63 --- /dev/null +++ b/web/api/synthesizer.py @@ -0,0 +1,23 @@ +from pathlib import Path +from flask_restx import Namespace, Resource, fields + +api = Namespace('synthesizers', description='Synthesizer-related operations') + +synthesizer = api.model('Synthesizer', { + 'name': fields.String(required=True, description='The synthesizer name'), + 'path': fields.String(required=True, description='The synthesizer path'), +}) + +synthesizers_cache = {} +syn_models_dirt = "synthesizer/saved_models" +synthesizers = list(Path(syn_models_dirt).glob("**/*.pt")) +print("Loaded synthesizer models: " + str(len(synthesizers))) + +@api.route('/') +class SynthesizerList(Resource): + @api.doc('list_synthesizers') + @api.marshal_list_with(synthesizer) + def get(self): + '''List all synthesizers''' + return list({"name": e.name, "path": str(e)} for e in synthesizers) + diff --git a/web/config/__init__.py b/web/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web/config/default.py b/web/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..02149ab5c730bd56f5e4d557c44860aa4dff2902 --- /dev/null +++ b/web/config/default.py @@ -0,0 +1,8 @@ +AUDIO_SAMPLES_DIR = 'samples\\' +DEVICE = '0' +HOST = 'localhost' +PORT = 8080 +MAX_CONTENT_PATH = 1024 * 1024 * 4  # uploaded audio must not exceed 4 MB +SECRET_KEY = "mockingbird_key" +WTF_CSRF_SECRET_KEY = "mockingbird_key" +TEMPLATES_AUTO_RELOAD = True \ No newline at end of file
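Taken together, the files above form a small HTTP API: the config fixes the server to localhost:8080, web/api/ exposes read-only listings, and /api/synthesize (defined in web/__init__.py) accepts the form fields synt_path, upfile_b64 or file, text, and vocoder. A minimal client sketch, not part of this changeset; the endpoint and field names come from the code above, while the file names, the example text, and the client library are assumptions:

    import base64
    import requests  # assumed client-side dependency, not pinned by this diff

    BASE = "http://localhost:8080"  # HOST/PORT from web/config/default.py

    # Read-only listings exposed by the flask_restx namespaces
    print(requests.get(BASE + "/api/synthesizers/").json())  # [{"name": ..., "path": ...}, ...]
    print(requests.get(BASE + "/api/audios/").json())        # ["sample1.wav", ...]

    # Voice cloning: send the reference audio as base64-encoded 16-bit PCM plus the target text
    with open("reference_s16le.pcm", "rb") as f:  # hypothetical file
        pcm_b64 = base64.b64encode(f.read()).decode("utf-8")

    resp = requests.post(BASE + "/api/synthesize", data={
        "upfile_b64": pcm_b64,          # decoded with pcm2float() on the server
        "text": "欢迎使用拟声鸟工具箱",  # Chinese input text ("Welcome to the MockingBird toolbox")
        "vocoder": "HifiGan",           # any value other than "WaveRNN" selects the HiFi-GAN vocoder
    })
    with open("result.wav", "wb") as f:
        f.write(resp.content)           # the route responds with audio/wav bytes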
diff --git a/web/static/img/bird-sm.png b/web/static/img/bird-sm.png new file mode 100644 index 0000000000000000000000000000000000000000..d94ab4392906ba614847f08da8c2b49725115272 Binary files /dev/null and b/web/static/img/bird-sm.png differ diff --git a/web/static/img/bird.png b/web/static/img/bird.png new file mode 100644 index 0000000000000000000000000000000000000000..fc02771a4dc5bbcb0239c0bc560cb0fa01509da9 Binary files /dev/null and b/web/static/img/bird.png differ diff --git a/web/static/img/mockingbird.png b/web/static/img/mockingbird.png new file mode 100644 index 0000000000000000000000000000000000000000..9feaf869f5459e0154a35764a9242f40a219a1a0 Binary files /dev/null and b/web/static/img/mockingbird.png differ diff --git a/web/static/js/eruda.min.js b/web/static/js/eruda.min.js new file mode 100644 index 0000000000000000000000000000000000000000..0609b9e8f15d39918a3818abaf979cdb7238b3d5 --- /dev/null +++ b/web/static/js/eruda.min.js @@ -0,0 +1,2 @@
+/*! eruda v1.5.4 https://eruda.liriliri.io/ */
+[eruda 1.5.4 minified bundle: one very long minified JavaScript line, omitted here because it arrived truncated and garbled]
[The rest of the diff adds the browser UI page of the toolbox (title: 拟声鸟工具箱, "MockingBird Toolbox"); its HTML markup did not survive extraction. The recoverable labels are: "1. 请输入中文" (1. Enter Chinese text), "2. 请直接录音,点击停止结束" (2. Record directly; click stop to finish), "或上传音频" (or upload an audio file), "3. 选择Synthesizer模型" (3. Choose a Synthesizer model), "4. 选择Vocoder模型" (4. Choose a Vocoder model).]
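For readers wiring these scripts together: vocoder_preprocess.py writes the GTA mels, and vocoder_train.py dispatches to one of the three trainers. A rough programmatic equivalent of its wavernn branch, with placeholder paths and the same defaults the CLI uses (nothing below is part of the diff):

    from pathlib import Path
    from vocoder.wavernn.train import train  # entry point added in vocoder/wavernn/train.py

    train(run_id="my_run",
          syn_dir=Path("datasets/SV2TTS/synthesizer"),  # ground-truth mels, wavs and embeds
          voc_dir=Path("datasets/SV2TTS/vocoder"),      # GTA mels written by vocoder_preprocess.py
          models_dir=Path("vocoder/saved_models"),      # assumed to exist, as in vocoder_train.py
          ground_truth=False,   # False: train on the GTA mels (synthesized.txt / mels_gta)
          save_every=1000,      # same defaults as the -s / -b flags of vocoder_train.py
          backup_every=25000,
          force_restart=False)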