diff --git a/.ipynb_checkpoints/requirements-checkpoint.txt b/.ipynb_checkpoints/requirements-checkpoint.txt
new file mode 100644
index 0000000000000000000000000000000000000000..976247a9d4f771489046246c2f07767b41a84721
--- /dev/null
+++ b/.ipynb_checkpoints/requirements-checkpoint.txt
@@ -0,0 +1,17 @@
+diffusers
+einops
+fastdtw
+librosa
+matplotlib
+music21
+numpy
+pandas
+pretty_midi
+pysptk
+pyworld
+scipy
+soundfile
+tgt
+torch
+torchaudio
+tqdm
diff --git a/.ipynb_checkpoints/score_based_apc-checkpoint.py b/.ipynb_checkpoints/score_based_apc-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9259d641045d7bfdb6d5aede00f6cebcc28f477
--- /dev/null
+++ b/.ipynb_checkpoints/score_based_apc-checkpoint.py
@@ -0,0 +1,159 @@
+import os.path
+
+import numpy as np
+import pandas as pd
+import torch
+import yaml
+import librosa
+import soundfile as sf
+from tqdm import tqdm
+
+from diffusers import DDIMScheduler
+from pitch_controller.models.unet import UNetPitcher
+from pitch_controller.utils import minmax_norm_diff, reverse_minmax_norm_diff
+from pitch_controller.modules.BigVGAN.inference import load_model
+from utils import get_mel, get_world_mel, get_f0, f0_to_coarse, show_plot, get_matched_f0, log_f0
+from pitch_predictor.models.transformer import PitchFormer
+import pretty_midi
+
+
+def prepare_midi_wav(wav_id, midi_id, sr=24000):
+ midi = pretty_midi.PrettyMIDI(midi_id)
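+    # piano roll at pretty_midi's default 100 frames per second; pad the tail and binarize velocities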
+ roll = midi.get_piano_roll()
+ roll = np.pad(roll, ((0, 0), (0, 1000)), constant_values=0)
+ roll[roll > 0] = 100
+
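+    # silence the frame right before each onset so consecutive notes of the same pitch stay separated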
+ onset = midi.get_onsets()
+ before_onset = list(np.round(onset * 100 - 1).astype(int))
+ roll[:, before_onset] = 0
+
+ wav, sr = librosa.load(wav_id, sr=sr)
+
+ start = 0
+ end = round(100 * len(wav) / sr) / 100
+ # save audio
+ wav_seg = wav[round(start * sr):round(end * sr)]
+ cur_roll = roll[:, round(100 * start):round(100 * end)]
+ return wav_seg, cur_roll
+
+
+def align_mapping(content, target_len):
+ # align content with mel
+ src_len = content.shape[-1]
+ target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
+ temp = torch.arange(src_len+1) * target_len / src_len
+
+ for i in range(target_len):
+ cur_idx = torch.argmin(torch.abs(temp-i))
+ target[:, i] = content[:, cur_idx]
+ return target
+
+
+def midi_to_hz(midi):
+ idx = torch.zeros(midi.shape[-1])
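+    # for each frame, convert the lowest active MIDI note to Hz (leave 0 for silent frames)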
+ for frame in range(midi.shape[-1]):
+ midi_frame = midi[:, frame]
+ non_zero = midi_frame.nonzero()
+ if len(non_zero) != 0:
+ hz = librosa.midi_to_hz(non_zero[0])
+ idx[frame] = torch.tensor(hz)
+ return idx
+
+
+@torch.no_grad()
+def score_pitcher(source, pitch_ref, model, hifigan, pitcher, steps=50, shift_semi=0, mask_with_source=False):
+ wav, midi = prepare_midi_wav(source, pitch_ref, sr=sr)
+
+ source_mel = get_world_mel(None, sr=sr, wav=wav)
+
+ midi = torch.tensor(midi, dtype=torch.float32)
+    midi = align_mapping(midi, source_mel.shape[-1])
+ midi = midi_to_hz(midi)
+
+ f0_ori = np.nan_to_num(get_f0(source))
+
+ source_mel = torch.from_numpy(source_mel).float().unsqueeze(0).to(device)
+ f0_ori = torch.from_numpy(f0_ori).float().unsqueeze(0).to(device)
+ midi = midi.unsqueeze(0).to(device)
+
+ f0_pred = pitcher(sp=source_mel, midi=midi)
+ if mask_with_source:
+ # mask unvoiced frames based on original pitch estimation
+ f0_pred[f0_ori == 0] = 0
+ f0_pred = f0_pred.cpu().numpy()[0]
+ # limit range
+ f0_pred[f0_pred < librosa.note_to_hz('C2')] = 0
+ f0_pred[f0_pred > librosa.note_to_hz('C6')] = librosa.note_to_hz('C6')
+
+ f0_pred = f0_pred * (2 ** (shift_semi / 12))
+
+ f0_pred = log_f0(f0_pred, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
+ f0_pred = torch.from_numpy(f0_pred).float().unsqueeze(0).to(device)
+
+ noise_scheduler = DDIMScheduler(num_train_timesteps=1000)
+ generator = torch.Generator(device=device).manual_seed(2024)
+
+ noise_scheduler.set_timesteps(steps)
+ noise = torch.randn(source_mel.shape, generator=generator, device=device)
+ pred = noise
+ source_x = minmax_norm_diff(source_mel, vmax=max_mel, vmin=min_mel)
+
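+    # DDIM sampling: start from Gaussian noise and iteratively denoise, conditioned on the
+    # pitch-flattened source mel and the predicted f0 contour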
+ for t in tqdm(noise_scheduler.timesteps):
+ pred = noise_scheduler.scale_model_input(pred, t)
+ model_output = model(x=pred, mean=source_x, f0=f0_pred, t=t, ref=None, embed=None)
+ pred = noise_scheduler.step(model_output=model_output,
+ timestep=t,
+ sample=pred,
+ eta=1, generator=generator).prev_sample
+
+ pred = reverse_minmax_norm_diff(pred, vmax=max_mel, vmin=min_mel)
+
+ pred_audio = hifigan(pred)
+ pred_audio = pred_audio.cpu().squeeze().clamp(-1, 1)
+
+ return pred_audio
+
+
+if __name__ == '__main__':
+ min_mel = np.log(1e-5)
+ max_mel = 2.5
+ sr = 24000
+
+ use_gpu = torch.cuda.is_available()
+ device = 'cuda' if use_gpu else 'cpu'
+
+ # load diffusion model
+ config = yaml.load(open('pitch_controller/config/DiffWorld_24k.yaml'), Loader=yaml.FullLoader)
+ mel_cfg = config['logmel']
+ ddpm_cfg = config['ddpm']
+ unet_cfg = config['unet']
+ model = UNetPitcher(**unet_cfg)
+ unet_path = 'ckpts/world_fixed_40.pt'
+
+    state_dict = torch.load(unet_path, map_location='cpu')
+ for key in list(state_dict.keys()):
+ state_dict[key.replace('_orig_mod.', '')] = state_dict.pop(key)
+ model.load_state_dict(state_dict)
+ if use_gpu:
+ model.cuda()
+ model.eval()
+
+ # load vocoder
+ hifi_path = 'ckpts/bigvgan_24khz_100band/g_05000000.pt'
+ hifigan, cfg = load_model(hifi_path, device=device)
+ hifigan.eval()
+
+ # load pitch predictor
+ pitcher = PitchFormer(100, 512).to(device)
+    ckpt = torch.load('ckpts/ckpt_transformer_pitch/transformer_pitch_360.pt', map_location=device)
+ pitcher.load_state_dict(ckpt)
+ pitcher.eval()
+
+ pred_audio = score_pitcher('examples/score_vocal.wav', 'examples/score_midi.midi', model, hifigan, pitcher, steps=50)
+ sf.write('output_score.wav', pred_audio, samplerate=sr)
+
+
+
+
diff --git a/.ipynb_checkpoints/template_based_apc-checkpoint.py b/.ipynb_checkpoints/template_based_apc-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ecdf7a6d92d1e763656e003e4718ea2e5853d3d
--- /dev/null
+++ b/.ipynb_checkpoints/template_based_apc-checkpoint.py
@@ -0,0 +1,89 @@
+import os.path
+
+import numpy as np
+import pandas as pd
+import torch
+import yaml
+import librosa
+import soundfile as sf
+from tqdm import tqdm
+
+from diffusers import DDIMScheduler
+from pitch_controller.models.unet import UNetPitcher
+from pitch_controller.utils import minmax_norm_diff, reverse_minmax_norm_diff
+from pitch_controller.modules.BigVGAN.inference import load_model
+from utils import get_mel, get_world_mel, get_f0, f0_to_coarse, show_plot, get_matched_f0, log_f0
+
+
+@torch.no_grad()
+def template_pitcher(source, pitch_ref, model, hifigan, steps=50, shift_semi=0):
+
+ source_mel = get_world_mel(source, sr=sr)
+
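+    # f0 contour taken from the reference and time-matched to the source (see utils.get_matched_f0)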
+ f0_ref = get_matched_f0(source, pitch_ref, 'world')
+ f0_ref = f0_ref * 2 ** (shift_semi / 12)
+
+ f0_ref = log_f0(f0_ref, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
+
+ source_mel = torch.from_numpy(source_mel).float().unsqueeze(0).to(device)
+ f0_ref = torch.from_numpy(f0_ref).float().unsqueeze(0).to(device)
+
+ noise_scheduler = DDIMScheduler(num_train_timesteps=1000)
+ generator = torch.Generator(device=device).manual_seed(2024)
+
+ noise_scheduler.set_timesteps(steps)
+ noise = torch.randn(source_mel.shape, generator=generator, device=device)
+ pred = noise
+ source_x = minmax_norm_diff(source_mel, vmax=max_mel, vmin=min_mel)
+
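+    # DDIM sampling conditioned on the source mel and the reference f0 contour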
+ for t in tqdm(noise_scheduler.timesteps):
+ pred = noise_scheduler.scale_model_input(pred, t)
+ model_output = model(x=pred, mean=source_x, f0=f0_ref, t=t, ref=None, embed=None)
+ pred = noise_scheduler.step(model_output=model_output,
+ timestep=t,
+ sample=pred,
+ eta=1, generator=generator).prev_sample
+
+ pred = reverse_minmax_norm_diff(pred, vmax=max_mel, vmin=min_mel)
+
+ pred_audio = hifigan(pred)
+ pred_audio = pred_audio.cpu().squeeze().clamp(-1, 1)
+
+ return pred_audio
+
+
+if __name__ == '__main__':
+ min_mel = np.log(1e-5)
+ max_mel = 2.5
+ sr = 24000
+
+ use_gpu = torch.cuda.is_available()
+ device = 'cuda' if use_gpu else 'cpu'
+
+ # load diffusion model
+ config = yaml.load(open('pitch_controller/config/DiffWorld_24k.yaml'), Loader=yaml.FullLoader)
+ mel_cfg = config['logmel']
+ ddpm_cfg = config['ddpm']
+ unet_cfg = config['unet']
+ model = UNetPitcher(**unet_cfg)
+ unet_path = 'ckpts/world_fixed_40.pt'
+
+    state_dict = torch.load(unet_path, map_location='cpu')
+ for key in list(state_dict.keys()):
+ state_dict[key.replace('_orig_mod.', '')] = state_dict.pop(key)
+ model.load_state_dict(state_dict)
+ if use_gpu:
+ model.cuda()
+ model.eval()
+
+ # load vocoder
+ hifi_path = 'ckpts/bigvgan_24khz_100band/g_05000000.pt'
+ hifigan, cfg = load_model(hifi_path, device=device)
+ hifigan.eval()
+
+ pred_audio = template_pitcher('examples/off-key.wav', 'examples/reference.wav', model, hifigan, steps=50, shift_semi=0)
+ sf.write('output_template.wav', pred_audio, samplerate=sr)
+
+
diff --git a/README.md b/README.md
index 7be5fc7f47d5db027d120b8024982df93db95b74..c9f28873837de10d64c45aed99bfda06c18d340d 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,86 @@
----
-license: mit
----
+
+
+# Diff-Pitcher (PyTorch)
+
+Official PyTorch Implementation of [Diff-Pitcher: Diffusion-based Singing Voice Pitch Correction](https://engineering.jhu.edu/lcap/data/uploads/pdfs/waspaa2023_hai.pdf)
+
+--------------------
+
+Thank you all for your interest in this research project. I am currently optimizing the model's performance and computational efficiency. I plan to release a user-friendly version, either a GUI or a VST plugin, in the first half of this year, and will update the open-source license accordingly.
+
+If you are familiar with PyTorch, you can follow [Code Examples](#examples) to use Diff-Pitcher.
+
+--------------------
+
+Diff-Pitcher
+
+- [Demo Page](#demo)
+- [Todo List](#todo)
+- [Code Examples](#examples)
+- [References](#references)
+- [Acknowledgement](#acknowledgement)
+
+## Demo
+
+🎵 Listen to [examples](https://jhu-lcap.github.io/Diff-Pitcher/)
+
+## Todo
+- [x] Update codes and demo
+- [x] Support 🤗 [Diffusers](https://github.com/huggingface/diffusers)
+- [x] Upload checkpoints
+- [x] Pipeline tutorial
+- [ ] Merge to [Your-Stable-Audio](https://github.com/haidog-yaqub/Your-Stable-Audio)
+- [ ] Audio Plugin Support
+## Examples
+- Download checkpoints: 🎒[ckpts](https://github.com/haidog-yaqub/DiffPitcher/tree/main/ckpts)
+- Prepare environment: [requirements.txt](requirements.txt)
+- Feel free to try (example commands below):
+ - template-based automatic pitch correction: [template_based_apc.py](template_based_apc.py)
+ - score-based automatic pitch correction: [score_based_apc.py](score_based_apc.py)
+
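+For example, once the checkpoints are placed under `ckpts/` (the paths the scripts expect), each script processes the bundled files in `examples/` and writes a 24 kHz result to the repository root:
+
+```sh
+pip install -r requirements.txt
+python template_based_apc.py   # examples/off-key.wav + examples/reference.wav -> output_template.wav
+python score_based_apc.py      # examples/score_vocal.wav + examples/score_midi.midi -> output_score.wav
+```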
+
+## References
+
+If you find the code useful for your research, please consider citing:
+
+```bibtex
+@inproceedings{hai2023diff,
+ title={Diff-Pitcher: Diffusion-Based Singing Voice Pitch Correction},
+ author={Hai, Jiarui and Elhilali, Mounya},
+ booktitle={2023 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
+ pages={1--5},
+ year={2023},
+ organization={IEEE}
+}
+```
+
+This repo is inspired by:
+
+```bibtex
+@article{popov2021diffusion,
+ title={Diffusion-based voice conversion with fast maximum likelihood sampling scheme},
+ author={Popov, Vadim and Vovk, Ivan and Gogoryan, Vladimir and Sadekova, Tasnima and Kudinov, Mikhail and Wei, Jiansheng},
+ journal={arXiv preprint arXiv:2109.13821},
+ year={2021}
+}
+```
+```bibtex
+@inproceedings{liu2022diffsinger,
+ title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
+ author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Zhao, Zhou},
+ booktitle={Proceedings of the AAAI conference on artificial intelligence},
+ volume={36},
+ number={10},
+ pages={11020--11028},
+ year={2022}
+}
+```
+
+## Acknowledgement
+
+[LCAP @ Johns Hopkins University](https://engineering.jhu.edu/lcap/)
+
+We borrow code from the following repos:
+
+ - `Diffusion Schedulers` are based on 🤗 [Diffusers](https://github.com/huggingface/diffusers)
+ - `2D UNet` is based on [DiffVC](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC)
diff --git a/examples/off-key.wav b/examples/off-key.wav
new file mode 100644
index 0000000000000000000000000000000000000000..9f4d731509f33531784f06a90ea7ecfb9f8e58a0
Binary files /dev/null and b/examples/off-key.wav differ
diff --git a/examples/reference.wav b/examples/reference.wav
new file mode 100644
index 0000000000000000000000000000000000000000..5b88dd80696b392e278ce005cba1a7a535ecd768
Binary files /dev/null and b/examples/reference.wav differ
diff --git a/examples/score_midi.midi b/examples/score_midi.midi
new file mode 100644
index 0000000000000000000000000000000000000000..aa32cd3eb5f4e5a55ea95ad04fc61922de25adef
Binary files /dev/null and b/examples/score_midi.midi differ
diff --git a/examples/score_midi.npy b/examples/score_midi.npy
new file mode 100644
index 0000000000000000000000000000000000000000..fe94b503a4dfa5370e8c30f6b5a1dfcd1b95128c
--- /dev/null
+++ b/examples/score_midi.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7baacba4afb8813d057e420cd63853657401403b6c798f6cb7f06673e7dcea5a
+size 559232
diff --git a/examples/score_vocal.wav b/examples/score_vocal.wav
new file mode 100644
index 0000000000000000000000000000000000000000..93979f26d61d29819bacd3ad91dd5ec2a4f3d5ec
Binary files /dev/null and b/examples/score_vocal.wav differ
diff --git a/output_score.wav b/output_score.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1fdfe65dfcd230f209309782d636d5d1c9c8f3e0
Binary files /dev/null and b/output_score.wav differ
diff --git a/output_template.wav b/output_template.wav
new file mode 100644
index 0000000000000000000000000000000000000000..c3dbd20ce4ca06cba47d44e21478eb5f76ade960
Binary files /dev/null and b/output_template.wav differ
diff --git a/pitch_controller/README.md b/pitch_controller/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3cc28a7375b89a65bb54f54db6c1b0a3393457f9
--- /dev/null
+++ b/pitch_controller/README.md
@@ -0,0 +1 @@
+# Diffusion-based Pitch Controller
diff --git a/pitch_controller/__pycache__/utils.cpython-310.pyc b/pitch_controller/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9986b33f07576240fb93bbbb15f247cb6c1097e9
Binary files /dev/null and b/pitch_controller/__pycache__/utils.cpython-310.pyc differ
diff --git a/pitch_controller/config/DiffWorld_24k.yaml b/pitch_controller/config/DiffWorld_24k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bddae2a73780500bfb9783180a0823648dbe7631
--- /dev/null
+++ b/pitch_controller/config/DiffWorld_24k.yaml
@@ -0,0 +1,24 @@
+version: 1.0
+
+logmel:
+ n_mels: 100
+ sampling_rate: 24000
+ n_fft: 1024
+ hop_size: 256
+ max: 2.5
+ min: -12
+
+unet:
+ dim_base: 256
+ use_embed: False
+ dim_embed: None
+ use_ref_t: False
+ dim_cond: 128
+ dim_mults: [1, 2, 4]
+
+ddpm:
+ num_train_steps: 1000
+ inference_steps: 100
+ eta: 0.8
+
+
diff --git a/pitch_controller/data/example/f0/p225_001.wav.npy b/pitch_controller/data/example/f0/p225_001.wav.npy
new file mode 100644
index 0000000000000000000000000000000000000000..df726e935d388f9bff0757baad41a51b6928e7a8
--- /dev/null
+++ b/pitch_controller/data/example/f0/p225_001.wav.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8df28ae08ef686e7c7e523fdde25b62fbd05725cdacc043cde407a898182272f
+size 1672
diff --git a/pitch_controller/data/example/mel/p225_001.wav.npy b/pitch_controller/data/example/mel/p225_001.wav.npy
new file mode 100644
index 0000000000000000000000000000000000000000..5a9db81767bb94382793a71c6fd8d51e2585c6e4
--- /dev/null
+++ b/pitch_controller/data/example/mel/p225_001.wav.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bf3c0e6956f57acdd82f5d91f6390ce148d89066faedbdd6f6ac8c48d1d2c76
+size 77328
diff --git a/pitch_controller/data/example/wav/p225_001.wav b/pitch_controller/data/example/wav/p225_001.wav
new file mode 100644
index 0000000000000000000000000000000000000000..bdc14bce583af4fb51e3382ec852e30e3dbf62f6
Binary files /dev/null and b/pitch_controller/data/example/wav/p225_001.wav differ
diff --git a/pitch_controller/data/example/world/p225_001.wav.npy b/pitch_controller/data/example/world/p225_001.wav.npy
new file mode 100644
index 0000000000000000000000000000000000000000..19bb6564a4c130e2373f00189e4f5601ddcd0030
--- /dev/null
+++ b/pitch_controller/data/example/world/p225_001.wav.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e00d5eb7fa9df26321df3f3df06e2ff44c3b3732cc5179ef135e41ffeb3a3b82
+size 77328
diff --git a/pitch_controller/data/prepare_f0.py b/pitch_controller/data/prepare_f0.py
new file mode 100644
index 0000000000000000000000000000000000000000..664f276509e810688c330c014806fea8a96f2e6f
--- /dev/null
+++ b/pitch_controller/data/prepare_f0.py
@@ -0,0 +1,66 @@
+# import amfm_decompy.basic_tools as basic
+# import amfm_decompy.pYAAPT as pYAAPT
+from multiprocessing import Process
+import os
+import numpy as np
+import pandas as pd
+import librosa
+from librosa.core import load
+from tqdm import tqdm
+
+
+def get_f0(wav_path):
+ wav, _ = load(wav_path, sr=24000)
+ wav = wav[:(wav.shape[0] // 256) * 256]
+ wav = np.pad(wav, 384, mode='reflect')
+ f0, _, _ = librosa.pyin(wav, frame_length=1024, hop_length=256, center=False,
+ fmin=librosa.note_to_hz('C2'),
+ fmax=librosa.note_to_hz('C6'))
+ return np.nan_to_num(f0)
+
+
+def chunks(arr, m):
+ result = [[] for i in range(m)]
+ for i in range(len(arr)):
+ result[i%m].append(arr[i])
+ return result
+
+
+def extract_f0(subset):
+ meta = pd.read_csv('../raw_data/meta_fix.csv')
+ meta = meta[meta['subset'] == 'train']
+ # meta = meta[meta['folder'] == 'VCTK-Corpus/vocal/']
+
+ for i in tqdm(subset):
+ line = meta.iloc[i]
+ audio_dir = '../raw_data/' + line['folder'] + line['subfolder']
+ f = line['file_name']
+
+ f0_dir = audio_dir.replace('vocal', 'f0').replace('raw_data/', '24k_data_f0/')
+
+ try:
+ np.load(os.path.join(f0_dir, f+'.npy'))
+ except:
+ print(line)
+ f0 = get_f0(os.path.join(audio_dir, f))
+ if os.path.exists(f0_dir) is False:
+ os.makedirs(f0_dir, exist_ok=True)
+ np.save(os.path.join(f0_dir, f + '.npy'), f0)
+
+ # if os.path.exists(os.path.join(f0_dir, f+'.npy')) is False:
+ # f0 = get_yaapt_f0(os.path.join(audio_dir, f))
+
+
+if __name__ == '__main__':
+ cores = 8
+ meta = pd.read_csv('../raw_data/meta_fix.csv')
+ meta = meta[meta['subset']=='train']
+ # meta = meta[meta['folder'] == 'VCTK-Corpus/vocal/']
+
+ idx_list = [i for i in range(len(meta))]
+
+ subsets = chunks(idx_list, cores)
+
+ for subset in subsets:
+ t = Process(target=extract_f0, args=(subset,))
+ t.start()
diff --git a/pitch_controller/data/prepare_mel.py b/pitch_controller/data/prepare_mel.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b441b64071040eed73fc4b9fe0688f9edb419dc
--- /dev/null
+++ b/pitch_controller/data/prepare_mel.py
@@ -0,0 +1,72 @@
+import os
+import numpy as np
+
+import librosa
+from librosa.core import load
+from librosa.filters import mel as librosa_mel_fn
+mel_basis = librosa_mel_fn(sr=24000, n_fft=1024, n_mels=100, fmin=0, fmax=12000)
+
+from tqdm import tqdm
+import pandas as pd
+
+from multiprocessing import Process
+
+
+# def get_f0(wav_path):
+# wav, _ = load(wav_path, sr=22050)
+# wav = wav[:(wav.shape[0] // 256) * 256]
+# wav = np.pad(wav, 384, mode='reflect')
+# f0, _, _ = librosa.pyin(wav, frame_length=1024, hop_length=256, center=False,
+# fmin=librosa.note_to_hz('C2'),
+# fmax=librosa.note_to_hz('C6'))
+# return np.nan_to_num(f0)
+
+def get_mel(wav_path):
+ wav, _ = load(wav_path, sr=24000)
+ wav = wav[:(wav.shape[0] // 256)*256]
+ wav = np.pad(wav, 384, mode='reflect')
+ stft = librosa.core.stft(wav, n_fft=1024, hop_length=256, win_length=1024, window='hann', center=False)
+ stftm = np.sqrt(np.real(stft) ** 2 + np.imag(stft) ** 2 + (1e-9))
+ mel_spectrogram = np.matmul(mel_basis, stftm)
+ log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-5, a_max=None))
+ return log_mel_spectrogram
+
+
+def chunks(arr, m):
+ result = [[] for i in range(m)]
+ for i in range(len(arr)):
+ result[i%m].append(arr[i])
+ return result
+
+
+def extract_mel(subset):
+ meta = pd.read_csv('../raw_data/meta_fix.csv')
+ meta = meta[meta['folder'] == 'eval/vocal/']
+
+ for i in tqdm(subset):
+ line = meta.iloc[i]
+ audio_dir = '../raw_data/' + line['folder'] + line['subfolder']
+ f = line['file_name']
+
+ mel_dir = audio_dir.replace('vocal', 'mel').replace('raw_data/', '24k_data/')
+
+ if os.path.exists(os.path.join(mel_dir, f+'.npy')) is False:
+ mel = get_mel(os.path.join(audio_dir, f))
+ if os.path.exists(mel_dir) is False:
+ os.makedirs(mel_dir)
+ np.save(os.path.join(mel_dir, f+'.npy'), mel)
+
+
+if __name__ == '__main__':
+ cores = 8
+
+ meta = pd.read_csv('../raw_data/meta_fix.csv')
+ meta = meta[meta['folder'] == 'eval/vocal/']
+
+ idx_list = [i for i in range(len(meta))]
+
+ subsets = chunks(idx_list, cores)
+
+ for subset in subsets:
+ t = Process(target=extract_mel, args=(subset,))
+ t.start()
diff --git a/pitch_controller/data/prepare_world.py b/pitch_controller/data/prepare_world.py
new file mode 100644
index 0000000000000000000000000000000000000000..651f84a9c654b89b7bb365720fd16a2ba366067e
--- /dev/null
+++ b/pitch_controller/data/prepare_world.py
@@ -0,0 +1,85 @@
+from multiprocessing import Process
+import os
+import numpy as np
+
+import librosa
+from librosa.core import load
+from librosa.filters import mel as librosa_mel_fn
+mel_basis = librosa_mel_fn(sr=24000, n_fft=1024, n_mels=100, fmin=0, fmax=12000)
+
+from tqdm import tqdm
+import pandas as pd
+import pyworld as pw
+
+
+def get_world_mel(wav_path, sr=24000):
+ wav, _ = librosa.load(wav_path, sr=sr)
+ wav = (wav * 32767).astype(np.int16)
+ wav = (wav / 32767).astype(np.float64)
+ # wav = wav.astype(np.float64)
+ wav = wav[:(wav.shape[0] // 256) * 256]
+
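+    # WORLD analysis (dio/stonemask f0, cheaptrick envelope, d4c aperiodicity), then re-synthesis
+    # with f0 * 0 to strip the original pitch while keeping timbre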
+ _f0, t = pw.dio(wav, sr, frame_period=256/sr*1000)
+ f0 = pw.stonemask(wav, _f0, t, sr)
+ sp = pw.cheaptrick(wav, f0, t, sr)
+ ap = pw.d4c(wav, f0, t, sr)
+ wav_hat = pw.synthesize(f0 * 0, sp, ap, sr, frame_period=256/sr*1000)
+
+ # pyworld output does not pad left
+ wav_hat = wav_hat[:len(wav)]
+ # wav_hat = wav_hat[256//2: len(wav)+256//2]
+ assert len(wav_hat) == len(wav)
+ wav = wav_hat.astype(np.float32)
+ wav = np.pad(wav, 384, mode='reflect')
+ stft = librosa.core.stft(wav, n_fft=1024, hop_length=256, win_length=1024, window='hann', center=False)
+ stftm = np.sqrt(np.real(stft) ** 2 + np.imag(stft) ** 2 + (1e-9))
+ mel_spectrogram = np.matmul(mel_basis, stftm)
+ log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-5, a_max=None))
+
+ return log_mel_spectrogram, f0
+
+
+def chunks(arr, m):
+ result = [[] for i in range(m)]
+ for i in range(len(arr)):
+ result[i%m].append(arr[i])
+ return result
+
+
+def extract_pw(subset, save_f0=False):
+ meta = pd.read_csv('../raw_data/meta_fix.csv')
+ meta = meta[meta['subset'] == 'train']
+
+ for i in tqdm(subset):
+ line = meta.iloc[i]
+ audio_dir = '../raw_data/' + line['folder'] + line['subfolder']
+ f = line['file_name']
+
+ mel_dir = audio_dir.replace('vocal', 'world').replace('raw_data/', '24k_data/')
+ f0_dir = audio_dir.replace('vocal', 'f0').replace('raw_data/', '24k_f0/')
+
+ if os.path.exists(os.path.join(mel_dir, f+'.npy')) is False:
+            mel, f0 = get_world_mel(os.path.join(audio_dir, f))
+
+ if os.path.exists(mel_dir) is False:
+ os.makedirs(mel_dir)
+ np.save(os.path.join(mel_dir, f+'.npy'), mel)
+
+ if save_f0 is True:
+ if os.path.exists(f0_dir) is False:
+ os.makedirs(f0_dir)
+ np.save(os.path.join(f0_dir, f + '.npy'), f0)
+
+
+if __name__ == '__main__':
+ cores = 8
+ meta = pd.read_csv('../raw_data/meta_fix.csv')
+ meta = meta[meta['subset'] == 'train']
+
+ idx_list = [i for i in range(len(meta))]
+
+ subsets = chunks(idx_list, cores)
+
+ for subset in subsets:
+ t = Process(target=extract_pw, args=(subset,))
+ t.start()
\ No newline at end of file
diff --git a/pitch_controller/dataset/__init__.py b/pitch_controller/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02ee5fcb455bf5d6bccdeb860cb28a3065d356a0
--- /dev/null
+++ b/pitch_controller/dataset/__init__.py
@@ -0,0 +1 @@
+from .diff_lpc import VCDecLPCDataset, VCDecLPCBatchCollate, VCDecLPCTest
\ No newline at end of file
diff --git a/pitch_controller/dataset/__pycache__/__init__.cpython-310.pyc b/pitch_controller/dataset/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dd0168f76ea87c34f3effbd02c94fc1b3889228
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/__init__.cpython-310.pyc differ
diff --git a/pitch_controller/dataset/__pycache__/__init__.cpython-39.pyc b/pitch_controller/dataset/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc8a0a636619ddb9e002e29ae0bfab48a9d3160d
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/__init__.cpython-39.pyc differ
diff --git a/pitch_controller/dataset/__pycache__/content_enc.cpython-310.pyc b/pitch_controller/dataset/__pycache__/content_enc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec3ae2e3219fc24868acce4d96470302d5b0c8a1
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/content_enc.cpython-310.pyc differ
diff --git a/pitch_controller/dataset/__pycache__/content_enc.cpython-39.pyc b/pitch_controller/dataset/__pycache__/content_enc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6e47d6c54c00baee291bbbbfda90872a5d7a756
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/content_enc.cpython-39.pyc differ
diff --git a/pitch_controller/dataset/__pycache__/diff.cpython-310.pyc b/pitch_controller/dataset/__pycache__/diff.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cbf3f34ac9208da3b873d56088129f16cfda236
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/diff.cpython-310.pyc differ
diff --git a/pitch_controller/dataset/__pycache__/diff.cpython-39.pyc b/pitch_controller/dataset/__pycache__/diff.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b891cf87ea4f8ef38e1a969c2797b426fc90fbdd
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/diff.cpython-39.pyc differ
diff --git a/pitch_controller/dataset/__pycache__/diff_lpc.cpython-310.pyc b/pitch_controller/dataset/__pycache__/diff_lpc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e61812d6f0a77d6b1ac9e89f97d1e01636e03ff6
Binary files /dev/null and b/pitch_controller/dataset/__pycache__/diff_lpc.cpython-310.pyc differ
diff --git a/pitch_controller/dataset/diff_lpc.py b/pitch_controller/dataset/diff_lpc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3fd77f18b339d463709c5ff38ae0721576ea951
--- /dev/null
+++ b/pitch_controller/dataset/diff_lpc.py
@@ -0,0 +1,271 @@
+import os
+import random
+import numpy as np
+import torch
+import tgt
+import pandas as pd
+
+from torch.utils.data import Dataset
+import librosa
+
+
+def f0_to_coarse(f0, hparams):
+ f0_bin = hparams['f0_bin']
+ f0_max = hparams['f0_max']
+ f0_min = hparams['f0_min']
+ is_torch = isinstance(f0, torch.Tensor)
+ # to mel scale
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+
+ unvoiced = (f0_mel == 0)
+
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+
+ f0_mel[unvoiced] = 0
+
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 0, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
+
+
+def log_f0(f0, hparams):
+ f0_bin = hparams['f0_bin']
+ f0_max = hparams['f0_max']
+ f0_min = hparams['f0_min']
+
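+    # map voiced frames to a semitone (log2) scale relative to f0_min, then quantize into f0_bin bins (0 = unvoiced)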
+ f0_mel = np.zeros_like(f0)
+ f0_mel[f0 != 0] = 12*np.log2(f0[f0 != 0]/f0_min) + 1
+ f0_mel_min = 12*np.log2(f0_min/f0_min) + 1
+ f0_mel_max = 12*np.log2(f0_max/f0_min) + 1
+
+ unvoiced = (f0_mel == 0)
+
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+
+ f0_mel[unvoiced] = 0
+
+ f0_coarse = np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= (f0_bin-1) and f0_coarse.min() >= 0, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
+
+
+# training "average voice" encoder
+class VCDecLPCDataset(Dataset):
+ def __init__(self, data_dir, subset, content_dir='lpc_mel_512', extract_emb=False,
+ f0_type='bins'):
+ self.path = data_dir
+ meta = pd.read_csv(data_dir + 'meta_fix.csv')
+ self.meta = meta[meta['subset'] == subset]
+ self.content_dir = content_dir
+ self.extract_emb = extract_emb
+ self.f0_type = f0_type
+
+ def get_vc_data(self, audio_path, mel_id):
+ mel_dir = audio_path.replace('vocal', 'mel')
+ embed_dir = audio_path.replace('vocal', 'embed')
+ pitch_dir = audio_path.replace('vocal', 'f0')
+ content_dir = audio_path.replace('vocal', self.content_dir)
+
+ mel = os.path.join(mel_dir, mel_id + '.npy')
+ embed = os.path.join(embed_dir, mel_id + '.npy')
+ pitch = os.path.join(pitch_dir, mel_id + '.npy')
+ content = os.path.join(content_dir, mel_id + '.npy')
+
+ mel = np.load(mel)
+ if self.extract_emb:
+ embed = np.load(embed)
+ else:
+ embed = np.zeros(1)
+
+ pitch = np.load(pitch)
+ content = np.load(content)
+
+ pitch = np.nan_to_num(pitch)
+ if self.f0_type == 'bins':
+ pitch = f0_to_coarse(pitch, {'f0_bin': 256,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C6')})
+ elif self.f0_type == 'log':
+ pitch = log_f0(pitch, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
+
+ mel = torch.from_numpy(mel).float()
+ embed = torch.from_numpy(embed).float()
+ pitch = torch.from_numpy(pitch).float()
+ content = torch.from_numpy(content).float()
+
+ return (mel, embed, pitch, content)
+
+ def __getitem__(self, index):
+ row = self.meta.iloc[index]
+ mel_id = row['file_name']
+ audio_path = self.path + row['folder'] + row['subfolder']
+ mel, embed, pitch, content = self.get_vc_data(audio_path, mel_id)
+ item = {'mel': mel, 'embed': embed, 'f0': pitch, 'content': content}
+ return item
+
+ def __len__(self):
+ return len(self.meta)
+
+
+class VCDecLPCBatchCollate(object):
+ def __init__(self, train_frames, eps=1e-5):
+ self.train_frames = train_frames
+ self.eps = eps
+
+ def __call__(self, batch):
+ train_frames = self.train_frames
+ eps = self.eps
+
+ B = len(batch)
+ embed = torch.stack([item['embed'] for item in batch], 0)
+
+ n_mels = batch[0]['mel'].shape[0]
+ content_dim = batch[0]['content'].shape[0]
+
+ # min value of log-mel spectrogram is np.log(eps) == padding zero in time domain
+ mels1 = torch.ones((B, n_mels, train_frames), dtype=torch.float32) * np.log(eps)
+ mels2 = torch.ones((B, n_mels, train_frames), dtype=torch.float32) * np.log(eps)
+
+ # ! need to deal with empty frames here
+ contents1 = torch.ones((B, content_dim, train_frames), dtype=torch.float32) * np.log(eps)
+
+ f0s1 = torch.zeros((B, train_frames), dtype=torch.float32)
+ max_starts = [max(item['mel'].shape[-1] - train_frames, 0)
+ for item in batch]
+
+ starts1 = [random.choice(range(m)) if m > 0 else 0 for m in max_starts]
+ starts2 = [random.choice(range(m)) if m > 0 else 0 for m in max_starts]
+ mel_lengths = []
+ for i, item in enumerate(batch):
+ mel = item['mel']
+ f0 = item['f0']
+ content = item['content']
+
+ if mel.shape[-1] < train_frames:
+ mel_length = mel.shape[-1]
+ else:
+ mel_length = train_frames
+
+ mels1[i, :, :mel_length] = mel[:, starts1[i]:starts1[i] + mel_length]
+ f0s1[i, :mel_length] = f0[starts1[i]:starts1[i] + mel_length]
+ contents1[i, :, :mel_length] = content[:, starts1[i]:starts1[i] + mel_length]
+
+ mels2[i, :, :mel_length] = mel[:, starts2[i]:starts2[i] + mel_length]
+ mel_lengths.append(mel_length)
+
+ mel_lengths = torch.LongTensor(mel_lengths)
+
+ return {'mel1': mels1, 'mel2': mels2, 'mel_lengths': mel_lengths,
+ 'embed': embed,
+ 'f0_1': f0s1,
+ 'content1': contents1}
+
+
+class VCDecLPCTest(Dataset):
+ def __init__(self, data_dir, subset='test', eps=1e-5, test_frames=256, content_dir='lpc_mel_512', extract_emb=False,
+ f0_type='bins'):
+ self.path = data_dir
+ meta = pd.read_csv(data_dir + 'meta_test.csv')
+ self.meta = meta[meta['subset'] == subset]
+ self.content_dir = content_dir
+ self.extract_emb = extract_emb
+ self.eps = eps
+ self.test_frames = test_frames
+ self.f0_type = f0_type
+
+ def get_vc_data(self, audio_path, mel_id, pitch_shift):
+ mel_dir = audio_path.replace('vocal', 'mel')
+ embed_dir = audio_path.replace('vocal', 'embed')
+ pitch_dir = audio_path.replace('vocal', 'f0')
+ content_dir = audio_path.replace('vocal', self.content_dir)
+
+ mel = os.path.join(mel_dir, mel_id + '.npy')
+ embed = os.path.join(embed_dir, mel_id + '.npy')
+ pitch = os.path.join(pitch_dir, mel_id + '.npy')
+ content = os.path.join(content_dir, mel_id + '.npy')
+
+ mel = np.load(mel)
+ if self.extract_emb:
+ embed = np.load(embed)
+ else:
+ embed = np.zeros(1)
+
+ pitch = np.load(pitch)
+ content = np.load(content)
+
+ pitch = np.nan_to_num(pitch)
+ pitch = pitch*pitch_shift
+
+ if self.f0_type == 'bins':
+ pitch = f0_to_coarse(pitch, {'f0_bin': 256,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C6')})
+ elif self.f0_type == 'log':
+ pitch = log_f0(pitch, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
+
+ mel = torch.from_numpy(mel).float()
+ embed = torch.from_numpy(embed).float()
+ pitch = torch.from_numpy(pitch).float()
+ content = torch.from_numpy(content).float()
+
+ return (mel, embed, pitch, content)
+
+ def __getitem__(self, index):
+ row = self.meta.iloc[index]
+
+ mel_id = row['content_file_name']
+ audio_path = self.path + row['content_folder'] + row['content_subfolder']
+ pitch_shift = row['pitch_shift']
+ mel1, _, f0, content = self.get_vc_data(audio_path, mel_id, pitch_shift)
+
+ mel_id = row['timbre_file_name']
+ audio_path = self.path + row['timbre_folder'] + row['timbre_subfolder']
+ mel2, embed, _, _ = self.get_vc_data(audio_path, mel_id, pitch_shift)
+
+ n_mels = mel1.shape[0]
+ content_dim = content.shape[0]
+
+ mels1 = torch.ones((n_mels, self.test_frames), dtype=torch.float32) * np.log(self.eps)
+ mels2 = torch.ones((n_mels, self.test_frames), dtype=torch.float32) * np.log(self.eps)
+ lpcs1 = torch.ones((content_dim, self.test_frames), dtype=torch.float32) * np.log(self.eps)
+
+ f0s1 = torch.zeros(self.test_frames, dtype=torch.float32)
+
+ if mel1.shape[-1] < self.test_frames:
+ mel_length = mel1.shape[-1]
+ else:
+ mel_length = self.test_frames
+ mels1[:, :mel_length] = mel1[:, :mel_length]
+ f0s1[:mel_length] = f0[:mel_length]
+ lpcs1[:, :mel_length] = content[:, :mel_length]
+
+ if mel2.shape[-1] < self.test_frames:
+ mel_length = mel2.shape[-1]
+ else:
+ mel_length = self.test_frames
+ mels2[:, :mel_length] = mel2[:, :mel_length]
+
+ return {'mel1': mels1, 'mel2': mels2, 'embed': embed, 'f0_1': f0s1, 'content1': lpcs1}
+
+ def __len__(self):
+ return len(self.meta)
+
+
+if __name__ == '__main__':
+ f0 = np.array([110.0, 220.0, librosa.note_to_hz('C2'), 0, librosa.note_to_hz('E3'), librosa.note_to_hz('C6')])
+    # C2 to C#6 spans 49 semitones, so voiced frames map onto values in [1, 50] before quantization into 345 bins
+ pitch = log_f0(f0, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
\ No newline at end of file
diff --git a/pitch_controller/dataset/diff_lpc_content.py b/pitch_controller/dataset/diff_lpc_content.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1263ee8cd332c7c64aad683c08d201a04b80883
--- /dev/null
+++ b/pitch_controller/dataset/diff_lpc_content.py
@@ -0,0 +1,231 @@
+import os
+import random
+import numpy as np
+import torch
+import tgt
+import pandas as pd
+
+from torch.utils.data import Dataset
+import librosa
+
+
+def f0_to_coarse(f0, hparams):
+ f0_bin = hparams['f0_bin']
+ f0_max = hparams['f0_max']
+ f0_min = hparams['f0_min']
+ is_torch = isinstance(f0, torch.Tensor)
+ # to mel scale
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+
+ unvoiced = (f0_mel == 0)
+
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+
+ f0_mel[unvoiced] = 0
+
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 0, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
+
+
+# training "average voice" encoder
+class VCDecLPCDataset(Dataset):
+ def __init__(self, data_dir, subset, content_dir='lpc_mel_512', extract_emb=False):
+ self.path = data_dir
+ meta = pd.read_csv(data_dir + 'meta_fix.csv')
+ self.meta = meta[meta['subset'] == subset]
+ self.content_dir = content_dir
+ self.extract_emb = extract_emb
+
+ def get_vc_data(self, audio_path, mel_id):
+ mel_dir = audio_path.replace('vocal', 'mel')
+ embed_dir = audio_path.replace('vocal', 'embed')
+ pitch_dir = audio_path.replace('vocal', 'f0')
+ content_dir = audio_path.replace('vocal', self.content_dir)
+
+ mel = os.path.join(mel_dir, mel_id + '.npy')
+ embed = os.path.join(embed_dir, mel_id + '.npy')
+ pitch = os.path.join(pitch_dir, mel_id + '.npy')
+ content = os.path.join(content_dir, mel_id + '.npy')
+
+ mel = np.load(mel)
+ if self.extract_emb:
+ embed = np.load(embed)
+ else:
+ embed = np.zeros(1)
+
+ pitch = np.load(pitch)
+ content = np.load(content)
+
+ pitch = np.nan_to_num(pitch)
+ pitch = f0_to_coarse(pitch, {'f0_bin': 256,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C6')})
+
+ mel = torch.from_numpy(mel).float()
+ embed = torch.from_numpy(embed).float()
+ pitch = torch.from_numpy(pitch).float()
+ content = torch.from_numpy(content).float()
+
+ return (mel, embed, pitch, content)
+
+ def __getitem__(self, index):
+ row = self.meta.iloc[index]
+ mel_id = row['file_name']
+ audio_path = self.path + row['folder'] + row['subfolder']
+ mel, embed, pitch, content = self.get_vc_data(audio_path, mel_id)
+ item = {'mel': mel, 'embed': embed, 'f0': pitch, 'content': content}
+ return item
+
+ def __len__(self):
+ return len(self.meta)
+
+
+class VCDecLPCBatchCollate(object):
+ def __init__(self, train_frames, eps=np.log(1e-5), content_eps=np.log(1e-12)):
+ self.train_frames = train_frames
+ self.eps = eps
+ self.content_eps = content_eps
+
+ def __call__(self, batch):
+ train_frames = self.train_frames
+ eps = self.eps
+ content_eps = self.content_eps
+
+ B = len(batch)
+ embed = torch.stack([item['embed'] for item in batch], 0)
+
+ n_mels = batch[0]['mel'].shape[0]
+ content_dim = batch[0]['content'].shape[0]
+
+ # min value of log-mel spectrogram is np.log(eps) == padding zero in time domain
+ mels1 = torch.ones((B, n_mels, train_frames), dtype=torch.float32) * eps
+ mels2 = torch.ones((B, n_mels, train_frames), dtype=torch.float32) * eps
+
+ # using a different eps
+ contents1 = torch.ones((B, content_dim, train_frames), dtype=torch.float32) * content_eps
+
+ f0s1 = torch.zeros((B, train_frames), dtype=torch.float32)
+ max_starts = [max(item['mel'].shape[-1] - train_frames, 0)
+ for item in batch]
+
+ starts1 = [random.choice(range(m)) if m > 0 else 0 for m in max_starts]
+ starts2 = [random.choice(range(m)) if m > 0 else 0 for m in max_starts]
+ mel_lengths = []
+ for i, item in enumerate(batch):
+ mel = item['mel']
+ f0 = item['f0']
+ content = item['content']
+
+ if mel.shape[-1] < train_frames:
+ mel_length = mel.shape[-1]
+ else:
+ mel_length = train_frames
+
+ mels1[i, :, :mel_length] = mel[:, starts1[i]:starts1[i] + mel_length]
+ f0s1[i, :mel_length] = f0[starts1[i]:starts1[i] + mel_length]
+ contents1[i, :, :mel_length] = content[:, starts1[i]:starts1[i] + mel_length]
+
+ mels2[i, :, :mel_length] = mel[:, starts2[i]:starts2[i] + mel_length]
+ mel_lengths.append(mel_length)
+
+ mel_lengths = torch.LongTensor(mel_lengths)
+
+ return {'mel1': mels1, 'mel2': mels2, 'mel_lengths': mel_lengths,
+ 'embed': embed,
+ 'f0_1': f0s1,
+ 'content1': contents1}
+
+
+class VCDecLPCTest(Dataset):
+ def __init__(self, data_dir, subset='test', eps=np.log(1e-5), content_eps=np.log(1e-12), test_frames=256, content_dir='lpc_mel_512', extract_emb=False):
+ self.path = data_dir
+ meta = pd.read_csv(data_dir + 'meta_test.csv')
+ self.meta = meta[meta['subset'] == subset]
+ self.content_dir = content_dir
+ self.extract_emb = extract_emb
+ self.eps = eps
+ self.content_eps = content_eps
+ self.test_frames = test_frames
+
+ def get_vc_data(self, audio_path, mel_id, pitch_shift):
+ mel_dir = audio_path.replace('vocal', 'mel')
+ embed_dir = audio_path.replace('vocal', 'embed')
+ pitch_dir = audio_path.replace('vocal', 'f0')
+ content_dir = audio_path.replace('vocal', self.content_dir)
+
+ mel = os.path.join(mel_dir, mel_id + '.npy')
+ embed = os.path.join(embed_dir, mel_id + '.npy')
+ pitch = os.path.join(pitch_dir, mel_id + '.npy')
+ content = os.path.join(content_dir, mel_id + '.npy')
+
+ mel = np.load(mel)
+ if self.extract_emb:
+ embed = np.load(embed)
+ else:
+ embed = np.zeros(1)
+
+ pitch = np.load(pitch)
+ content = np.load(content)
+
+ pitch = np.nan_to_num(pitch)
+ pitch = pitch*pitch_shift
+ pitch = f0_to_coarse(pitch, {'f0_bin': 256,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C6')})
+
+ mel = torch.from_numpy(mel).float()
+ embed = torch.from_numpy(embed).float()
+ pitch = torch.from_numpy(pitch).float()
+ content = torch.from_numpy(content).float()
+
+ return (mel, embed, pitch, content)
+
+ def __getitem__(self, index):
+ row = self.meta.iloc[index]
+
+ mel_id = row['content_file_name']
+ audio_path = self.path + row['content_folder'] + row['content_subfolder']
+ pitch_shift = row['pitch_shift']
+ mel1, _, f0, content = self.get_vc_data(audio_path, mel_id, pitch_shift)
+
+ mel_id = row['timbre_file_name']
+ audio_path = self.path + row['timbre_folder'] + row['timbre_subfolder']
+ mel2, embed, _, _ = self.get_vc_data(audio_path, mel_id, pitch_shift)
+
+ n_mels = mel1.shape[0]
+ content_dim = content.shape[0]
+
+ mels1 = torch.ones((n_mels, self.test_frames), dtype=torch.float32) * self.eps
+ mels2 = torch.ones((n_mels, self.test_frames), dtype=torch.float32) * self.eps
+ # content
+ lpcs1 = torch.ones((content_dim, self.test_frames), dtype=torch.float32) * self.content_eps
+
+ f0s1 = torch.zeros(self.test_frames, dtype=torch.float32)
+
+ if mel1.shape[-1] < self.test_frames:
+ mel_length = mel1.shape[-1]
+ else:
+ mel_length = self.test_frames
+ mels1[:, :mel_length] = mel1[:, :mel_length]
+ f0s1[:mel_length] = f0[:mel_length]
+ lpcs1[:, :mel_length] = content[:, :mel_length]
+
+ if mel2.shape[-1] < self.test_frames:
+ mel_length = mel2.shape[-1]
+ else:
+ mel_length = self.test_frames
+ mels2[:, :mel_length] = mel2[:, :mel_length]
+
+ return {'mel1': mels1, 'mel2': mels2, 'embed': embed, 'f0_1': f0s1, 'content1': lpcs1}
+
+ def __len__(self):
+ return len(self.meta)
+
+
+
diff --git a/pitch_controller/load_vocoder.py b/pitch_controller/load_vocoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d61fbc0a259861a69e7a4cad245a8f460b12707
--- /dev/null
+++ b/pitch_controller/load_vocoder.py
@@ -0,0 +1,51 @@
+# from nsf_hifigan.models import load_model
+from modules.BigVGAN.inference import load_model
+import librosa
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+import torchaudio.transforms as transforms
+
+import numpy as np
+import soundfile as sf
+
+
+class LogMelSpectrogram(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+        self.melspectrogram = transforms.MelSpectrogram(
+ sample_rate=22050,
+ n_fft=1024,
+ win_length=1024,
+ hop_length=256,
+ center=False,
+ power=1.0,
+ norm="slaney",
+ n_mels=80,
+ mel_scale="slaney",
+ f_max=8000,
+ f_min=0,
+ )
+
+ def forward(self, wav):
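+        # reflect-pad by (win_length - hop_length) // 2 per side so framing with center=False stays hop-aligned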
+ wav = F.pad(wav, ((1024 - 256) // 2, (1024 - 256) // 2), "reflect")
+        mel = self.melspectrogram(wav)
+ logmel = torch.log(torch.clamp(mel, min=1e-5))
+ return logmel
+
+
+hifigan, cfg = load_model('modules/BigVGAN/ckpt/bigvgan_22khz_80band/g_05000000', device='cuda')
+M = LogMelSpectrogram()
+
+source, sr = torchaudio.load("music.mp3")
+source = torchaudio.functional.resample(source, sr, 22050)
+source = source.unsqueeze(0)
+mel = M(source).squeeze(0)
+
+# f0, f0_bin = get_pitch("116_1_pred.wav")
+# f0 = torch.tensor(f0).unsqueeze(0)
+with torch.no_grad():
+ y_hat = hifigan(mel.cuda()).cpu().numpy().squeeze(1)
+
+sf.write('test.wav', y_hat[0], samplerate=22050)
\ No newline at end of file
diff --git a/pitch_controller/models/__pycache__/base.cpython-310.pyc b/pitch_controller/models/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9054e2273f7445bb67fbbc6dc8c4606ec04924d1
Binary files /dev/null and b/pitch_controller/models/__pycache__/base.cpython-310.pyc differ
diff --git a/pitch_controller/models/__pycache__/base.cpython-39.pyc b/pitch_controller/models/__pycache__/base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..461a7872ff2e1d23ec111d6f250c3c8cc4728166
Binary files /dev/null and b/pitch_controller/models/__pycache__/base.cpython-39.pyc differ
diff --git a/pitch_controller/models/__pycache__/modules.cpython-310.pyc b/pitch_controller/models/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df9d7ed57017530ef7b4c5aeff750f3722a24395
Binary files /dev/null and b/pitch_controller/models/__pycache__/modules.cpython-310.pyc differ
diff --git a/pitch_controller/models/__pycache__/modules.cpython-39.pyc b/pitch_controller/models/__pycache__/modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61d0a80a28ac7f5398bdbfa16099009012b4792e
Binary files /dev/null and b/pitch_controller/models/__pycache__/modules.cpython-39.pyc differ
diff --git a/pitch_controller/models/__pycache__/pitch.cpython-39.pyc b/pitch_controller/models/__pycache__/pitch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7a520d1d49ca153688c6cd64179734bffa9732a
Binary files /dev/null and b/pitch_controller/models/__pycache__/pitch.cpython-39.pyc differ
diff --git a/pitch_controller/models/__pycache__/unet.cpython-310.pyc b/pitch_controller/models/__pycache__/unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9a5bff3d0e6fb01be2a1b9bbe236ec5ca5a6e81
Binary files /dev/null and b/pitch_controller/models/__pycache__/unet.cpython-310.pyc differ
diff --git a/pitch_controller/models/__pycache__/unet.cpython-39.pyc b/pitch_controller/models/__pycache__/unet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66469b07f2add05927d1b7094014dc4ddca437e3
Binary files /dev/null and b/pitch_controller/models/__pycache__/unet.cpython-39.pyc differ
diff --git a/pitch_controller/models/__pycache__/update_unet.cpython-310.pyc b/pitch_controller/models/__pycache__/update_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19b214b3997b5cf6fe01e7dfcfef408d2ac7aab1
Binary files /dev/null and b/pitch_controller/models/__pycache__/update_unet.cpython-310.pyc differ
diff --git a/pitch_controller/models/__pycache__/utils.cpython-310.pyc b/pitch_controller/models/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3b953957b48c64d44a111d45f8db0d28d723b43
Binary files /dev/null and b/pitch_controller/models/__pycache__/utils.cpython-310.pyc differ
diff --git a/pitch_controller/models/__pycache__/utils.cpython-39.pyc b/pitch_controller/models/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aeb6b27427a6bf7de53487337de85b96921b4846
Binary files /dev/null and b/pitch_controller/models/__pycache__/utils.cpython-39.pyc differ
diff --git a/pitch_controller/models/base.py b/pitch_controller/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c7395ddeba674eea0cb59594b9b2c838ae78c55
--- /dev/null
+++ b/pitch_controller/models/base.py
@@ -0,0 +1,30 @@
+# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the MIT License.
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# MIT License for more details.
+
+import numpy as np
+import torch
+
+
+class BaseModule(torch.nn.Module):
+ def __init__(self):
+ super(BaseModule, self).__init__()
+
+ @property
+ def nparams(self):
+ num_params = 0
+ for name, param in self.named_parameters():
+ if param.requires_grad:
+ num_params += np.prod(param.detach().cpu().numpy().shape)
+ return num_params
+
+ def relocate_input(self, x: list):
+ device = next(self.parameters()).device
+ for i in range(len(x)):
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
+ x[i] = x[i].to(device)
+ return x
diff --git a/pitch_controller/models/modules.py b/pitch_controller/models/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..da76268c128bfc9acc587d4db138f44ef180d5cc
--- /dev/null
+++ b/pitch_controller/models/modules.py
@@ -0,0 +1,237 @@
+# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the MIT License.
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# MIT License for more details.
+
+import math
+import torch
+from einops import rearrange
+
+from .base import BaseModule
+
+
+class Mish(BaseModule):
+ def forward(self, x):
+ return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+class Upsample(BaseModule):
+ def __init__(self, dim):
+ super(Upsample, self).__init__()
+ self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Downsample(BaseModule):
+ def __init__(self, dim):
+ super(Downsample, self).__init__()
+ self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Rezero(BaseModule):
+ def __init__(self, fn):
+ super(Rezero, self).__init__()
+ self.fn = fn
+ self.g = torch.nn.Parameter(torch.zeros(1))
+
+ def forward(self, x):
+ return self.fn(x) * self.g
+
+
+class Block(BaseModule):
+ def __init__(self, dim, dim_out, groups=8):
+ super(Block, self).__init__()
+ self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
+ padding=1), torch.nn.GroupNorm(
+ groups, dim_out), Mish())
+
+ def forward(self, x):
+ output = self.block(x)
+ return output
+
+
+class ResnetBlock(BaseModule):
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+ super(ResnetBlock, self).__init__()
+ self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
+ dim_out))
+
+ self.block1 = Block(dim, dim_out, groups=groups)
+ self.block2 = Block(dim_out, dim_out, groups=groups)
+ if dim != dim_out:
+ self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
+ else:
+ self.res_conv = torch.nn.Identity()
+
+ def forward(self, x, time_emb):
+ h = self.block1(x)
+ h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
+ h = self.block2(h)
+ output = h + self.res_conv(x)
+ return output
+
+
+class LinearAttention(BaseModule):
+ def __init__(self, dim, heads=4, dim_head=32, q_norm=True):
+ super(LinearAttention, self).__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
+ self.q_norm = q_norm
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
+ heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ if self.q_norm:
+ q = q.softmax(dim=-2)
+
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
+ heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class Residual(BaseModule):
+ def __init__(self, fn):
+ super(Residual, self).__init__()
+ self.fn = fn
+
+ def forward(self, x, *args, **kwargs):
+ output = self.fn(x, *args, **kwargs) + x
+ return output
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class Timesteps(BaseModule):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class PitchPosEmb(BaseModule):
+ def __init__(self, dim, flip_sin_to_cos=False, downscale_freq_shift=0):
+ super(PitchPosEmb, self).__init__()
+ self.dim = dim
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, x):
+ # B * L
+ b, l = x.shape
+ x = rearrange(x, 'b l -> (b l)')
+ emb = get_timestep_embedding(
+ x,
+ self.dim,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ emb = rearrange(emb, '(b l) d -> b d l', b=b, l=l)
+ return emb
+
+
+class TimbreBlock(BaseModule):
+ def __init__(self, out_dim):
+ super(TimbreBlock, self).__init__()
+ base_dim = out_dim // 4
+
+ self.block11 = torch.nn.Sequential(torch.nn.Conv2d(1, 2 * base_dim,
+ 3, 1, 1),
+ torch.nn.InstanceNorm2d(2 * base_dim, affine=True),
+ torch.nn.GLU(dim=1))
+ self.block12 = torch.nn.Sequential(torch.nn.Conv2d(base_dim, 2 * base_dim,
+ 3, 1, 1),
+ torch.nn.InstanceNorm2d(2 * base_dim, affine=True),
+ torch.nn.GLU(dim=1))
+ self.block21 = torch.nn.Sequential(torch.nn.Conv2d(base_dim, 4 * base_dim,
+ 3, 1, 1),
+ torch.nn.InstanceNorm2d(4 * base_dim, affine=True),
+ torch.nn.GLU(dim=1))
+ self.block22 = torch.nn.Sequential(torch.nn.Conv2d(2 * base_dim, 4 * base_dim,
+ 3, 1, 1),
+ torch.nn.InstanceNorm2d(4 * base_dim, affine=True),
+ torch.nn.GLU(dim=1))
+ self.block31 = torch.nn.Sequential(torch.nn.Conv2d(2 * base_dim, 8 * base_dim,
+ 3, 1, 1),
+ torch.nn.InstanceNorm2d(8 * base_dim, affine=True),
+ torch.nn.GLU(dim=1))
+ self.block32 = torch.nn.Sequential(torch.nn.Conv2d(4 * base_dim, 8 * base_dim,
+ 3, 1, 1),
+ torch.nn.InstanceNorm2d(8 * base_dim, affine=True),
+ torch.nn.GLU(dim=1))
+ self.final_conv = torch.nn.Conv2d(4 * base_dim, out_dim, 1)
+
+ def forward(self, x):
+ y = self.block11(x)
+ y = self.block12(y)
+ y = self.block21(y)
+ y = self.block22(y)
+ y = self.block31(y)
+ y = self.block32(y)
+ y = self.final_conv(y)
+
+ return y.sum((2, 3)) / (y.shape[2] * y.shape[3])
\ No newline at end of file
diff --git a/pitch_controller/models/unet.py b/pitch_controller/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2ba05dd79cfac6e0deabf39034b834b4f701512
--- /dev/null
+++ b/pitch_controller/models/unet.py
@@ -0,0 +1,153 @@
+import math
+import torch
+
+from .base import BaseModule
+from .modules import Mish, Upsample, Downsample, Rezero, Block, ResnetBlock
+from .modules import LinearAttention, Residual, Timesteps, TimbreBlock, PitchPosEmb
+
+from einops import rearrange
+
+
+class UNetPitcher(BaseModule):
+ def __init__(self,
+ dim_base,
+ dim_cond,
+ use_ref_t,
+ use_embed,
+ dim_embed=256,
+ dim_mults=(1, 2, 4),
+ pitch_type='bins'):
+
+ super(UNetPitcher, self).__init__()
+ self.use_ref_t = use_ref_t
+ self.use_embed = use_embed
+ self.pitch_type = pitch_type
+
+ dim_in = 2
+
+ # time embedding
+ self.time_pos_emb = Timesteps(num_channels=dim_base,
+ flip_sin_to_cos=True,
+ downscale_freq_shift=0)
+
+ self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_base, dim_base * 4),
+ Mish(), torch.nn.Linear(dim_base * 4, dim_base))
+
+ # speaker embedding
+ timbre_total = 0
+ if use_ref_t:
+ self.ref_block = TimbreBlock(out_dim=dim_cond)
+ timbre_total += dim_cond
+ if use_embed:
+ timbre_total += dim_embed
+
+ if timbre_total != 0:
+ self.timbre_block = torch.nn.Sequential(
+ torch.nn.Linear(timbre_total, 4 * dim_cond),
+ Mish(),
+ torch.nn.Linear(4 * dim_cond, dim_cond))
+
+ if use_embed or use_ref_t:
+ dim_in += dim_cond
+
+ self.pitch_pos_emb = PitchPosEmb(dim_cond)
+ self.pitch_mlp = torch.nn.Sequential(
+ torch.nn.Conv1d(dim_cond, dim_cond * 4, 1, stride=1),
+ Mish(),
+ torch.nn.Conv1d(dim_cond * 4, dim_cond, 1, stride=1), )
+ dim_in += dim_cond
+
+ # pitch embedding
+ # if self.pitch_type == 'bins':
+ # print('using mel bins for f0')
+ # elif self.pitch_type == 'log':
+ # print('using log bins f0')
+
+ dims = [dim_in, *map(lambda m: dim_base * m, dim_mults)]
+ in_out = list(zip(dims[:-1], dims[1:]))
+ # blocks
+ self.downs = torch.nn.ModuleList([])
+ self.ups = torch.nn.ModuleList([])
+ num_resolutions = len(in_out)
+
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ is_last = ind >= (num_resolutions - 1)
+ self.downs.append(torch.nn.ModuleList([
+ ResnetBlock(dim_in, dim_out, time_emb_dim=dim_base),
+ ResnetBlock(dim_out, dim_out, time_emb_dim=dim_base),
+ Residual(Rezero(LinearAttention(dim_out))),
+ Downsample(dim_out) if not is_last else torch.nn.Identity()]))
+
+ mid_dim = dims[-1]
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)
+ self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)
+
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+ self.ups.append(torch.nn.ModuleList([
+ ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim_base),
+ ResnetBlock(dim_in, dim_in, time_emb_dim=dim_base),
+ Residual(Rezero(LinearAttention(dim_in))),
+ Upsample(dim_in)]))
+ self.final_block = Block(dim_base, dim_base)
+ self.final_conv = torch.nn.Conv2d(dim_base, 1, 1)
+
+ def forward(self, x, mean, f0, t, ref=None, embed=None):
+ if not torch.is_tensor(t):
+ t = torch.tensor([t], dtype=torch.long, device=x.device)
+ if len(t.shape) == 0:
+ t = t * torch.ones(x.shape[0], dtype=t.dtype, device=x.device)
+
+ t = self.time_pos_emb(t)
+ t = self.mlp(t)
+
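+ # stack the noisy mel and the content (mean) mel as two channels: [B, 2, n_mels, T]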
+ x = torch.stack([x, mean], 1)
+
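+ # embed the frame-level pitch and broadcast it across the mel-bin axis: [B, dim_cond, n_mels, T]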
+ f0 = self.pitch_pos_emb(f0)
+ f0 = self.pitch_mlp(f0)
+ f0 = f0.unsqueeze(2)
+ f0 = torch.cat(x.shape[2] * [f0], 2)
+
+ timbre = None
+ if self.use_ref_t:
+ ref = torch.stack([ref], 1)
+ timbre = self.ref_block(ref)
+ if self.use_embed:
+ if timbre is not None:
+ timbre = torch.cat([timbre, embed], 1)
+ else:
+ timbre = embed
+ if timbre is None:
+ # raise Exception("at least use one timbre condition")
+ condition = f0
+ else:
+ timbre = self.timbre_block(timbre).unsqueeze(-1).unsqueeze(-1)
+ timbre = torch.cat(x.shape[2] * [timbre], 2)
+ timbre = torch.cat(x.shape[3] * [timbre], 3)
+ condition = torch.cat([f0, timbre], 1)
+
+ x = torch.cat([x, condition], 1)
+
+ hiddens = []
+ for resnet1, resnet2, attn, downsample in self.downs:
+ x = resnet1(x, t)
+ x = resnet2(x, t)
+ x = attn(x)
+ hiddens.append(x)
+ x = downsample(x)
+
+ x = self.mid_block1(x, t)
+ x = self.mid_attn(x)
+ x = self.mid_block2(x, t)
+
+ for resnet1, resnet2, attn, upsample in self.ups:
+ x = torch.cat((x, hiddens.pop()), dim=1)
+ x = resnet1(x, t)
+ x = resnet2(x, t)
+ x = attn(x)
+ x = upsample(x)
+
+ x = self.final_block(x)
+ output = self.final_conv(x)
+
+ return output.squeeze(1)
\ No newline at end of file
diff --git a/pitch_controller/models/utils.py b/pitch_controller/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..985298131db8980f9016c3851eab4505529430b9
--- /dev/null
+++ b/pitch_controller/models/utils.py
@@ -0,0 +1,110 @@
+# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the MIT License.
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# MIT License for more details.
+
+import torch
+import torchaudio
+import numpy as np
+from librosa.filters import mel as librosa_mel_fn
+
+from .base import BaseModule
+
+
+def mse_loss(x, y, mask, n_feats):
+ loss = torch.sum(((x - y)**2) * mask)
+ return loss / (torch.sum(mask) * n_feats)
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def fix_len_compatibility(length, num_downsamplings_in_unet=2):
+ while True:
+ if length % (2**num_downsamplings_in_unet) == 0:
+ return length
+ length += 1
+
+
+class PseudoInversion(BaseModule):
+ def __init__(self, n_mels, sampling_rate, n_fft):
+ super(PseudoInversion, self).__init__()
+ self.n_mels = n_mels
+ self.sampling_rate = sampling_rate
+ self.n_fft = n_fft
+ mel_basis = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=0, fmax=8000)
+ mel_basis_inverse = np.linalg.pinv(mel_basis)
+ mel_basis_inverse = torch.from_numpy(mel_basis_inverse).float()
+ self.register_buffer("mel_basis_inverse", mel_basis_inverse)
+
+ def forward(self, log_mel_spectrogram):
+ mel_spectrogram = torch.exp(log_mel_spectrogram)
+ stftm = torch.matmul(self.mel_basis_inverse, mel_spectrogram)
+ return stftm
+
+
+class InitialReconstruction(BaseModule):
+ def __init__(self, n_fft, hop_size):
+ super(InitialReconstruction, self).__init__()
+ self.n_fft = n_fft
+ self.hop_size = hop_size
+ window = torch.hann_window(n_fft).float()
+ self.register_buffer("window", window)
+
+ def forward(self, stftm):
+ real_part = torch.ones_like(stftm, device=stftm.device)
+ imag_part = torch.zeros_like(stftm, device=stftm.device)
+ stft = torch.stack([real_part, imag_part], -1)*stftm.unsqueeze(-1)
+ istft = torch.istft(stft, n_fft=self.n_fft,
+ hop_length=self.hop_size, win_length=self.n_fft,
+ window=self.window, center=True)
+ return istft.unsqueeze(1)
+
+
+# Fast Griffin-Lim algorithm as a PyTorch module
+class FastGL(BaseModule):
+ def __init__(self, n_mels, sampling_rate, n_fft, hop_size, momentum=0.99):
+ super(FastGL, self).__init__()
+ self.n_mels = n_mels
+ self.sampling_rate = sampling_rate
+ self.n_fft = n_fft
+ self.hop_size = hop_size
+ self.momentum = momentum
+ self.pi = PseudoInversion(n_mels, sampling_rate, n_fft)
+ self.ir = InitialReconstruction(n_fft, hop_size)
+ window = torch.hann_window(n_fft).float()
+ self.register_buffer("window", window)
+
+ @torch.no_grad()
+ def forward(self, s, n_iters=32):
+ c = self.pi(s)
+ x = self.ir(c)
+ x = x.squeeze(1)
+ c = c.unsqueeze(-1)
+ prev_angles = torch.zeros_like(c, device=c.device)
+ for _ in range(n_iters):
+ s = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_size,
+ win_length=self.n_fft, window=self.window,
+ center=True)
+ real_part, imag_part = s.unbind(-1)
+ stftm = torch.sqrt(torch.clamp(real_part**2 + imag_part**2, min=1e-8))
+ angles = s / stftm.unsqueeze(-1)
+ s = c * (angles + self.momentum * (angles - prev_angles))
+ x = torch.istft(s, n_fft=self.n_fft, hop_length=self.hop_size,
+ win_length=self.n_fft, window=self.window,
+ center=True)
+ prev_angles = angles
+ return x.unsqueeze(1)
diff --git a/pitch_controller/modules/BigVGAN/LICENSE b/pitch_controller/modules/BigVGAN/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e9663595cc28938f88d6299acd3ba791542e4c0c
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 NVIDIA CORPORATION.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/README.md b/pitch_controller/modules/BigVGAN/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a6cff37786a486deb55bc070254027aa492c2e92
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/README.md
@@ -0,0 +1,95 @@
+## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
+#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
+
+
+
+
+### [Paper](https://arxiv.org/abs/2206.04658)
+### [Audio demo](https://bigvgan-demo.github.io/)
+
+## Installation
+Clone the repository and install dependencies.
+```shell
+# the codebase has been tested on Python 3.8 / 3.10 with PyTorch 1.12.1 / 1.13 conda binaries
+git clone https://github.com/NVIDIA/BigVGAN
+pip install -r requirements.txt
+```
+
+Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset.
+``` shell
+cd LibriTTS && \
+ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
+ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
+ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
+ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
+ln -s /path/to/your/LibriTTS/dev-other dev-other && \
+ln -s /path/to/your/LibriTTS/test-clean test-clean && \
+ln -s /path/to/your/LibriTTS/test-other test-other && \
+cd ..
+```
+
+## Training
+Train the BigVGAN model. Below is an example command for training BigVGAN on the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input.
+```shell
+python train.py \
+--config configs/bigvgan_24khz_100band.json \
+--input_wavs_dir LibriTTS \
+--input_training_file LibriTTS/train-full.txt \
+--input_validation_file LibriTTS/val-full.txt \
+--list_input_unseen_wavs_dir LibriTTS LibriTTS \
+--list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \
+--checkpoint_path exp/bigvgan
+```
+
+## Synthesis
+Synthesize from the BigVGAN model. Below is an example command for generating audio with the model.
+It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
+```shell
+python inference.py \
+--checkpoint_file exp/bigvgan/g_05000000 \
+--input_wavs_dir /path/to/your/input_wav \
+--output_dir /path/to/your/output_wav
+```
+
+`inference_e2e.py` supports synthesis directly from mel spectrograms saved in `.npy` format, with shape `[1, channel, frame]` or `[channel, frame]`.
+It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
+
+Make sure that the STFT hyperparameters used to compute the mel spectrograms match those of the model, as defined in the `config.json` of the corresponding checkpoint.
+```shell
+python inference_e2e.py \
+--checkpoint_file exp/bigvgan/g_05000000 \
+--input_mels_dir /path/to/your/input_mel \
+--output_dir /path/to/your/output_wav
+```
+
+## Pretrained Models
+We provide the [pretrained models](https://drive.google.com/drive/folders/1e9wdM29d-t3EHUpBb8T4dcHrkYGAXTgq).
+Download the generator (e.g., `g_05000000`) and discriminator (e.g., `do_05000000`) checkpoints from the listed folders.
+
+|Folder Name|Sampling Rate|Mel band|fmax|Params.|Dataset|Fine-Tuned|
+|------|---|---|---|---|------|---|
+|bigvgan_24khz_100band|24 kHz|100|12000|112M|LibriTTS|No|
+|bigvgan_base_24khz_100band|24 kHz|100|12000|14M|LibriTTS|No|
+|bigvgan_22khz_80band|22 kHz|80|8000|112M|LibriTTS + VCTK + LJSpeech|No|
+|bigvgan_base_22khz_80band|22 kHz|80|8000|14M|LibriTTS + VCTK + LJSpeech|No|
+
+The paper results are based on the 24 kHz BigVGAN models trained on the LibriTTS dataset.
+We also provide 22 kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications.
+Note that the latest checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.
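+
+Once downloaded, the generator can also be loaded programmatically. The following is a minimal sketch using the `load_model` helper from this module's `inference.py`, run from the root of this repository; the checkpoint path is a placeholder and the mel input here is random, standing in for a real `[1, num_mels, frames]` spectrogram.
+```python
+import torch
+
+from pitch_controller.modules.BigVGAN.inference import load_model
+
+# load the generator weights together with the matching config.json (returned as AttrDict h)
+generator, h = load_model('exp/bigvgan/g_05000000', device='cuda')
+
+# placeholder mel spectrogram with the band count expected by the checkpoint
+mel = torch.randn(1, h.num_mels, 128, device='cuda')
+with torch.no_grad():
+    audio = generator(mel)  # waveform in [-1, 1], shape [1, 1, samples]
+```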
+
+
+## TODO
+
+The current codebase only provides a plain PyTorch implementation of the filtered nonlinearity. We are working on a fast CUDA kernel implementation, which will be released in the future.
+
+
+## References
+* [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
+
+* [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
+
+* [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
+
+* [Julius](https://github.com/adefossez/julius) (for low-pass filter)
+
+* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/__pycache__/env.cpython-310.pyc b/pitch_controller/modules/BigVGAN/__pycache__/env.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b56d1a9e093afa8f993d94a084b8324456c464c3
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/__pycache__/env.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/__pycache__/inference.cpython-310.pyc b/pitch_controller/modules/BigVGAN/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8db7c95c1b8844d87c40655c5e6dc2e60d001ed6
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/__pycache__/inference.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/__pycache__/meldataset.cpython-310.pyc b/pitch_controller/modules/BigVGAN/__pycache__/meldataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27e63b2b64f00899b7ec2910443d7058cfb05570
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/__pycache__/meldataset.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/__pycache__/models.cpython-310.pyc b/pitch_controller/modules/BigVGAN/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ac2618df919f7988dfce771f50d5cef688949ec
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/__pycache__/models.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/__pycache__/utils.cpython-310.pyc b/pitch_controller/modules/BigVGAN/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb138b675d58dd1889d453e13cc0d7a2b8f72177
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/__pycache__/utils.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/activations/__pycache__/activations.cpython-310.pyc b/pitch_controller/modules/BigVGAN/activations/__pycache__/activations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce906239d790c95b8db90a1cc29ff741f0251aa1
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/activations/__pycache__/activations.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/activations/activations.py b/pitch_controller/modules/BigVGAN/activations/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..61f2808a5466b3cf4d041059700993af5527dd29
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/activations/activations.py
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ '''
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = SnakeBeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.beta = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/__init__.py b/pitch_controller/modules/BigVGAN/alias_free_torch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/alias_free_torch/__init__.py
@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
+from .act import *
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/__init__.cpython-310.pyc b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cc31973147be0273b1383745a8adc15b1668d48
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/__init__.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/act.cpython-310.pyc b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13167af90a4b6e49dabc891572964efa3ca0ae88
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/act.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/filter.cpython-310.pyc b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/filter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e67190a4c1ae41d9a06db70a0d4849e29cf4fa1
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/filter.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/resample.cpython-310.pyc b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/resample.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e33be87a2dfc070e8cdb527e9c7204a90f2f433
Binary files /dev/null and b/pitch_controller/modules/BigVGAN/alias_free_torch/__pycache__/resample.cpython-310.pyc differ
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/act.py b/pitch_controller/modules/BigVGAN/alias_free_torch/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..028debd697dd60458aae75010057df038bd3518a
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/alias_free_torch/act.py
@@ -0,0 +1,28 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+ def __init__(self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/filter.py b/pitch_controller/modules/BigVGAN/alias_free_torch/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/alias_free_torch/filter.py
@@ -0,0 +1,95 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if 'sinc' in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ # LICENSE is in incl_licenses directory.
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(x == 0,
+ torch.tensor(1., device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+ even = (kernel_size % 2 == 0)
+ half_size = kernel_size // 2
+
+ #For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.:
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+ else:
+ beta = 0.
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = (torch.arange(-half_size, half_size) + 0.5)
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = 'replicate',
+ kernel_size: int = 12):
+ # kernel_size should be an even number for the StyleGAN3 setup;
+ # in this implementation, an odd number is also possible.
+ super().__init__()
+ if cutoff < -0.:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = (kernel_size % 2 == 0)
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ #input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right),
+ mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
+ stride=self.stride, groups=C)
+
+ return out
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/alias_free_torch/resample.py b/pitch_controller/modules/BigVGAN/alias_free_torch/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/alias_free_torch/resample.py
@@ -0,0 +1,49 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+ x = x[..., self.pad_left:-self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size)
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/env.py b/pitch_controller/modules/BigVGAN/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8be238d4db710c8c9a338d336baea0138f18d1f
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/env.py
@@ -0,0 +1,18 @@
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import shutil
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def build_env(config, config_name, path):
+ t_path = os.path.join(path, config_name)
+ if config != t_path:
+ os.makedirs(path, exist_ok=True)
+ shutil.copyfile(config, os.path.join(path, config_name))
\ No newline at end of file
diff --git a/pitch_controller/modules/BigVGAN/inference.py b/pitch_controller/modules/BigVGAN/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a739344db3ec9ae08560e5477a394cca32d4a6d9
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/inference.py
@@ -0,0 +1,36 @@
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import glob
+import os
+import argparse
+import json
+import torch
+from scipy.io.wavfile import write
+from .env import AttrDict
+from .utils import MAX_WAV_VALUE
+from .models import BigVGAN as Generator
+import librosa
+
+
+def load_model(model_path, device='cuda'):
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
+ with open(config_file) as f:
+ data = f.read()
+
+ global h
+ json_config = json.loads(data)
+
+ h = AttrDict(json_config)
+
+ generator = Generator(h).to(device)
+
+ cp_dict = torch.load(model_path, map_location=device)
+ generator.load_state_dict(cp_dict['generator'])
+ generator.eval()
+ generator.remove_weight_norm()
+ del cp_dict
+ return generator, h
+
diff --git a/pitch_controller/modules/BigVGAN/models.py b/pitch_controller/modules/BigVGAN/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb40e0cff7819dcbe69555520253afd64580720
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/models.py
@@ -0,0 +1,381 @@
+# Copyright (c) 2022 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from .activations import activations
+from .utils import init_weights, get_padding
+from .alias_free_torch import *
+
+LRELU_SLOPE = 0.1
+
+
+class AMPBlock1(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
+ super(AMPBlock1, self).__init__()
+ self.h = h
+
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
+
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ else:
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
+
+ def forward(self, x):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+ xt = a1(x)
+ xt = c1(xt)
+ xt = a2(xt)
+ xt = c2(xt)
+ x = xt + x
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class AMPBlock2(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
+ super(AMPBlock2, self).__init__()
+ self.h = h
+
+ self.convs = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])))
+ ])
+ self.convs.apply(init_weights)
+
+ self.num_layers = len(self.convs) # total number of conv layers
+
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ else:
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
+
+ def forward(self, x):
+ for c, a in zip(self.convs, self.activations):
+ xt = a(x)
+ xt = c(xt)
+ x = xt + x
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class BigVGAN(torch.nn.Module):
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
+ def __init__(self, h):
+ super(BigVGAN, self).__init__()
+ self.h = h
+
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+
+ # pre conv
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
+
+ # transposed conv-based upsamplers. does not apply anti-aliasing
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(nn.ModuleList([
+ weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ k, u, padding=(k - u) // 2))
+ ]))
+
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
+
+ # post conv
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
+ activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
+ activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ else:
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+ # weight initialization
+ for i in range(len(self.ups)):
+ self.ups[i].apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ # pre conv
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ # upsampling
+ for i_up in range(len(self.ups[i])):
+ x = self.ups[i][i_up](x)
+ # AMP blocks
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ # post conv
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ for l_i in l:
+ remove_weight_norm(l_i)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ self.d_mult = h.discriminator_channel_mult
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self, h):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.mpd_reshapes = h.mpd_reshapes
+ print("mpd_reshapes: {}".format(self.mpd_reshapes))
+ discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
+ self.discriminators = nn.ModuleList(discriminators)
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(self, cfg, resolution):
+ super().__init__()
+
+ self.resolution = resolution
+ assert len(self.resolution) == 3, \
+ "MRD layer requires list with len=3, got {}".format(self.resolution)
+ self.lrelu_slope = LRELU_SLOPE
+
+ norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
+ if hasattr(cfg, "mrd_use_spectral_norm"):
+ print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
+ norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
+ self.d_mult = cfg.discriminator_channel_mult
+ if hasattr(cfg, "mrd_channel_mult"):
+ print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
+ self.d_mult = cfg.mrd_channel_mult
+
+ self.convs = nn.ModuleList([
+ norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
+ ])
+ self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.spectrogram(x)
+ x = x.unsqueeze(1)
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+ def spectrogram(self, x):
+ n_fft, hop_length, win_length = self.resolution
+ x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
+ x = x.squeeze(1)
+ x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
+ x = torch.view_as_real(x) # [B, F, TT, 2]
+ mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+ return mag
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(self, cfg, debug=False):
+ super().__init__()
+ self.resolutions = cfg.resolutions
+ assert len(self.resolutions) == 3,\
+ "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
+ format(self.resolutions)
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(x=y)
+ y_d_g, fmap_g = d(x=y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
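+# Least-squares GAN objectives and feature-matching loss, as in HiFi-GAN.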
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss*2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1-dr)**2)
+ g_loss = torch.mean(dg**2)
+ loss += (r_loss + g_loss)
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1-dg)**2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
diff --git a/pitch_controller/modules/BigVGAN/utils.py b/pitch_controller/modules/BigVGAN/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed67f356aef6ce3af01b43d97d8aafb31c57b017
--- /dev/null
+++ b/pitch_controller/modules/BigVGAN/utils.py
@@ -0,0 +1,81 @@
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import glob
+import os
+import matplotlib
+import torch
+from torch.nn.utils import weight_norm
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+from scipy.io.wavfile import write
+
+MAX_WAV_VALUE = 32768.0
+
+
+def plot_spectrogram(spectrogram):
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none')
+ plt.colorbar(im, ax=ax)
+
+ fig.canvas.draw()
+ plt.close()
+
+ return fig
+
+
+def plot_spectrogram_clipped(spectrogram, clip_max=2.):
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none', vmin=1e-6, vmax=clip_max)
+ plt.colorbar(im, ax=ax)
+
+ fig.canvas.draw()
+ plt.close()
+
+ return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size*dilation - dilation)/2)
+
+
+def load_checkpoint(filepath, device):
+ assert os.path.isfile(filepath)
+ print("Loading '{}'".format(filepath))
+ checkpoint_dict = torch.load(filepath, map_location=device)
+ print("Complete.")
+ return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj):
+ print("Saving checkpoint to {}".format(filepath))
+ torch.save(obj, filepath)
+ print("Complete.")
+
+
+def scan_checkpoint(cp_dir, prefix):
+ pattern = os.path.join(cp_dir, prefix + '????????')
+ cp_list = glob.glob(pattern)
+ if len(cp_list) == 0:
+ return None
+ return sorted(cp_list)[-1]
+
+def save_audio(audio, path, sr):
+ # audio: 1-D torch tensor with values in [-1, 1]
+ audio = audio * MAX_WAV_VALUE
+ audio = audio.cpu().numpy().astype('int16')
+ write(path, sr, audio)
\ No newline at end of file
diff --git a/pitch_controller/train_world_tuner_24k.py b/pitch_controller/train_world_tuner_24k.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b35b9692b9a9add781cfcf16f831cb6bfbff8f
--- /dev/null
+++ b/pitch_controller/train_world_tuner_24k.py
@@ -0,0 +1,237 @@
+import os, json, argparse, yaml
+import numpy as np
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.cuda.amp import autocast, GradScaler
+
+from diffusers import DDIMScheduler
+
+from dataset import VCDecLPCDataset, VCDecLPCBatchCollate, VCDecLPCTest
+from models.unet import UNetPitcher
+from modules.BigVGAN.inference import load_model
+from utils import save_plot, save_audio
+from utils import minmax_norm_diff, reverse_minmax_norm_diff
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-config', type=str, default='config/DiffWorld_24k_log.yaml')
+
+parser.add_argument('-seed', type=int, default=98)
+parser.add_argument('-amp', type=bool, default=True)
+parser.add_argument('-compile', type=bool, default=False)
+
+parser.add_argument('-data_dir', type=str, default='../24k_center/')
+parser.add_argument('-lpc_dir', type=str, default='world')
+parser.add_argument('-vocoder_dir', type=str, default='modules/BigVGAN/ckpt/bigvgan_base_24khz_100band/g_05000000')
+
+parser.add_argument('-train_frames', type=int, default=128)
+parser.add_argument('-batch_size', type=int, default=32)
+parser.add_argument('-test_size', type=int, default=1)
+parser.add_argument('-num_workers', type=int, default=4)
+parser.add_argument('-lr', type=float, default=5e-5)
+parser.add_argument('-weight_decay', type=float, default=1e-6)
+
+parser.add_argument('-epochs', type=int, default=80)
+parser.add_argument('-save_every', type=int, default=2)
+parser.add_argument('-log_step', type=int, default=200)
+parser.add_argument('-log_dir', type=str, default='logs_dec_world_24k')
+parser.add_argument('-ckpt_dir', type=str, default='ckpt_world_24k')
+
+args = parser.parse_args()
+args.save_ori = True
+config = yaml.load(open(args.config), Loader=yaml.FullLoader)
+mel_cfg = config['logmel']
+ddpm_cfg = config['ddpm']
+unet_cfg = config['unet']
+f0_type = unet_cfg['pitch_type']
+
+if __name__ == "__main__":
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ if torch.cuda.is_available():
+ args.device = 'cuda'
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if torch.backends.cudnn.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+ else:
+ args.device = 'cpu'
+
+ if os.path.exists(args.log_dir) is False:
+ os.makedirs(args.log_dir)
+
+ if os.path.exists(args.ckpt_dir) is False:
+ os.makedirs(args.ckpt_dir)
+
+ print('Initializing vocoder...')
+ hifigan, cfg = load_model(args.vocoder_dir, device=args.device)
+
+ print('Initializing data loaders...')
+ train_set = VCDecLPCDataset(args.data_dir, subset='train', content_dir=args.lpc_dir, f0_type=f0_type)
+ collate_fn = VCDecLPCBatchCollate(args.train_frames)
+ train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
+ collate_fn=collate_fn, num_workers=args.num_workers, drop_last=True)
+
+ val_set = VCDecLPCTest(args.data_dir, content_dir=args.lpc_dir, f0_type=f0_type)
+ val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
+
+ print('Initializing and loading models...')
+ model = UNetPitcher(**unet_cfg).to(args.device)
+ print('Number of parameters = %.2fm\n' % (model.nparams / 1e6))
+
+ # prepare DPM scheduler
+ noise_scheduler = DDIMScheduler(num_train_timesteps=ddpm_cfg['num_train_steps'])
+
+ print('Initializing optimizers...')
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ scaler = GradScaler()
+
+ if args.compile:
+ model = torch.compile(model)
+
+ print('Start training.')
+ global_step = 0
+ for epoch in range(1, args.epochs + 1):
+ print(f'Epoch: {epoch} [iteration: {global_step}]')
+ model.train()
+ losses = []
+
+ for step, batch in enumerate(tqdm(train_loader)):
+ optimizer.zero_grad()
+
+ # make spectrogram range from -1 to 1
+ mel = batch['mel1'].to(args.device)
+ mel = minmax_norm_diff(mel, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+
+ if unet_cfg["use_ref_t"]:
+ mel_ref = batch['mel2'].to(args.device)
+ mel_ref = minmax_norm_diff(mel_ref, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+ else:
+ mel_ref = None
+
+ f0 = batch['f0_1'].to(args.device)
+
+ mean = batch['content1'].to(args.device)
+ mean = minmax_norm_diff(mean, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+
+ noise = torch.randn(mel.shape).to(args.device)
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
+ (args.batch_size,),
+ device=args.device, ).long()
+
+ noisy_mel = noise_scheduler.add_noise(mel, noise, timesteps)
+
+ if args.amp:
+ with autocast():
+ noise_pred = model(x=noisy_mel, mean=mean, f0=f0, t=timesteps, ref=mel_ref, embed=None)
+ loss = F.mse_loss(noise_pred, noise)
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ noise_pred = model(x=noisy_mel, mean=mean, f0=f0, t=timesteps, ref=mel_ref, embed=None)
+ loss = F.mse_loss(noise_pred, noise)
+ # Backward propagation
+ loss.backward()
+ optimizer.step()
+
+ losses.append(loss.item())
+ global_step += 1
+
+ if global_step % args.log_step == 0:
+ losses = np.asarray(losses)
+ # msg = 'Epoch %d: loss = %.4f\n' % (epoch, np.mean(losses))
+ msg = '\nEpoch: [{}][{}]\t' \
+ 'Batch: [{}][{}]\tLoss: {:.6f}\n'.format(epoch,
+ args.epochs,
+ step+1,
+ len(train_loader),
+ np.mean(losses))
+ with open(f'{args.log_dir}/train_dec.log', 'a') as f:
+ f.write(msg)
+ losses = []
+
+ if epoch % args.save_every > 0:
+ continue
+
+ print('Saving model...\n')
+ ckpt = model.state_dict()
+ torch.save(ckpt, f=f"{args.ckpt_dir}/lpc_vc_{epoch}.pt")
+
+ print('Inference...\n')
+ noise = None
+ noise_scheduler.set_timesteps(ddpm_cfg['inference_steps'])
+ model.eval()
+ with torch.no_grad():
+ for i, batch in enumerate(val_loader):
+ # optimizer.zero_grad()
+ generator = torch.Generator(device=args.device).manual_seed(args.seed)
+
+ mel = batch['mel1'].to(args.device)
+ mel = minmax_norm_diff(mel, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+
+ if unet_cfg["use_ref_t"]:
+ mel_ref = batch['mel2'].to(args.device)
+ mel_ref = minmax_norm_diff(mel_ref, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+ else:
+ mel_ref = None
+
+ f0 = batch['f0_1'].to(args.device)
+ embed = batch['embed'].to(args.device)
+
+ mean = batch['content1'].to(args.device)
+ mean = minmax_norm_diff(mean, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+
+ # make spectrogram range from -1 to 1
+ if noise is None:
+ noise = torch.randn(mel.shape,
+ generator=generator,
+ device=args.device,
+ )
+ pred = noise
+
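+ # iterative DDIM denoising: start from pure noise and refine toward the reconstructed mel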
+ for t in noise_scheduler.timesteps:
+ pred = noise_scheduler.scale_model_input(pred, t)
+ model_output = model(x=pred, mean=mean, f0=f0, t=t, ref=mel_ref, embed=None)
+ pred = noise_scheduler.step(model_output=model_output,
+ timestep=t,
+ sample=pred,
+ eta=ddpm_cfg['eta'], generator=generator).prev_sample
+
+
+ if os.path.exists(f'{args.log_dir}/audio/{i}/') is False:
+ os.makedirs(f'{args.log_dir}/audio/{i}/')
+ os.makedirs(f'{args.log_dir}/pic/{i}/')
+
+ # save pred
+ pred = reverse_minmax_norm_diff(pred, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+ save_plot(pred.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_pred.png')
+ audio = hifigan(pred)
+ save_audio(f'{args.log_dir}/audio/{i}/{epoch}_pred.wav', mel_cfg['sampling_rate'], audio)
+
+ if args.save_ori is True:
+ # save ref
+ # mel_ref = reverse_minmax_norm_diff(mel_ref, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+ # save_plot(mel_ref.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_ref.png')
+ # audio = hifigan(mel_ref)
+ # save_audio(f'{args.log_dir}/audio/{i}/{epoch}_ref.wav', mel_cfg['sampling_rate'], audio)
+
+ # save source
+ mel = reverse_minmax_norm_diff(mel, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+ save_plot(mel.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_source.png')
+ audio = hifigan(mel)
+ save_audio(f'{args.log_dir}/audio/{i}/{epoch}_source.wav', mel_cfg['sampling_rate'], audio)
+
+ # save content
+ mean = reverse_minmax_norm_diff(mean, vmax=mel_cfg['max'], vmin=mel_cfg['min'])
+ save_plot(mean.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_avg.png')
+ audio = hifigan(mean)
+ save_audio(f'{args.log_dir}/audio/{i}/{epoch}_avg.wav', mel_cfg['sampling_rate'], audio)
+
+ args.save_ori = False
diff --git a/pitch_controller/utils.py b/pitch_controller/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4aee85dafce5fc9d38ea53ce698e4b0e9262a61
--- /dev/null
+++ b/pitch_controller/utils.py
@@ -0,0 +1,51 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.io import wavfile
+import torch
+from torch.nn import functional as F
+
+
+def repeat_expand_2d(content, target_len):
+ # align content with mel
+
+ src_len = content.shape[-1]
+ target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
+ temp = torch.arange(src_len+1) * target_len / src_len
+ current_pos = 0
+ for i in range(target_len):
+ if i < temp[current_pos+1]:
+ target[:, i] = content[:, current_pos]
+ else:
+ current_pos += 1
+ target[:, i] = content[:, current_pos]
+
+ return target
+
+
+def save_plot(tensor, savepath):
+ plt.style.use('default')
+ fig, ax = plt.subplots(figsize=(12, 3))
+ im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
+ plt.colorbar(im, ax=ax)
+ plt.tight_layout()
+ fig.canvas.draw()
+ plt.savefig(savepath)
+ plt.close()
+
+
+def save_audio(file_path, sampling_rate, audio):
+ audio = np.clip(audio.detach().cpu().squeeze().numpy(), -0.999, 0.999)
+ wavfile.write(file_path, sampling_rate, (audio * 32767).astype("int16"))
+
+
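+# normalize a clipped log-mel spectrogram from [vmin, vmax] to [-1, 1] for the diffusion model (inverted below)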
+def minmax_norm_diff(tensor: torch.Tensor, vmax: float = 2.5, vmin: float = -12) -> torch.Tensor:
+ tensor = torch.clip(tensor, vmin, vmax)
+ tensor = 2 * (tensor - vmin) / (vmax - vmin) - 1
+ return tensor
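+
+# Round-trip sketch with the log-mel defaults (vmin=-12, vmax=2.5):
+#   x = torch.tensor([-12.0, 2.5])
+#   reverse_minmax_norm_diff(minmax_norm_diff(x))  # -> tensor([-12.0000, 2.5000])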
+
+
+def reverse_minmax_norm_diff(tensor: torch.Tensor, vmax: float = 2.5, vmin: float = -12) -> torch.Tensor:
+ tensor = torch.clip(tensor, -1.0, 1.0)
+ tensor = (tensor + 1) / 2
+ tensor = tensor * (vmax - vmin) + vmin
+ return tensor
\ No newline at end of file
diff --git a/pitch_predictor/README.md b/pitch_predictor/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e187365c17c2ab0c25a317f8447e7a629ddfa2ae
--- /dev/null
+++ b/pitch_predictor/README.md
@@ -0,0 +1 @@
+# Pitchformer for Score-based Automatic Pitch Correction
diff --git a/pitch_predictor/config/Pitchformer.yaml b/pitch_predictor/config/Pitchformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bf5802a16e0e2dfe490c790f46d4788f94b9e37f
--- /dev/null
+++ b/pitch_predictor/config/Pitchformer.yaml
@@ -0,0 +1,32 @@
+version: 1.0
+
+unet:
+ sample_size: [1]
+ # spec_dim: 100
+ in_channels: 102
+ out_channels: 1
+ layers_per_block: 2
+ block_out_channels: [256, 256, 256]
+ down_block_types:
+ ["DownBlock1D",
+ "AttnDownBlock1D",
+ "AttnDownBlock1D",
+ ]
+ up_block_types:
+ ["AttnUpBlock1D",
+ "AttnUpBlock1D",
+ "UpBlock1D",
+ ]
+
+ddpm:
+ num_train_steps: 1000
+ inference_steps: 100
+ eta: 0.8
+
+logmel:
+ n_mels: 100
+ sampling_rate: 24000
+ n_fft: 1024
+ hop_size: 256
+ max: 2.5
+ min: -12
\ No newline at end of file
diff --git a/pitch_predictor/data/example/f0/2.npy b/pitch_predictor/data/example/f0/2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..52e268335ea4594a567c425a2746e179f887aa87
--- /dev/null
+++ b/pitch_predictor/data/example/f0/2.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b13cdfa6d0a1a6ef29b832b7348c4282abd18784ce391e6f36e8bbd54da9023e
+size 3184
diff --git a/pitch_predictor/data/example/midi/2.midi b/pitch_predictor/data/example/midi/2.midi
new file mode 100644
index 0000000000000000000000000000000000000000..f5e4a0acc1c5ba6e74416306ff5b8604b008aa41
Binary files /dev/null and b/pitch_predictor/data/example/midi/2.midi differ
diff --git a/pitch_predictor/data/example/roll/2.npy b/pitch_predictor/data/example/roll/2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..f729483ed02449fb4a4585364c92853769803d38
--- /dev/null
+++ b/pitch_predictor/data/example/roll/2.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:105bb8565c9cbf846fe8f1b93cd051c00c106bf1dd318169b2ed7dfd8707c431
+size 417920
diff --git a/pitch_predictor/data/example/roll_align/2.npy b/pitch_predictor/data/example/roll_align/2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..fa47ace58a8c46d6c7a2ebf7862bfd7081b67e9e
--- /dev/null
+++ b/pitch_predictor/data/example/roll_align/2.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10f6405ab3a91c36cfe7a6508d74445c08cfe3210a997b9b37286cda4d47639b
+size 3184
diff --git a/pitch_predictor/data/example/vocal/2.wav b/pitch_predictor/data/example/vocal/2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..0217bed7847312941b1a65a3f31f2fe306ac7f8e
Binary files /dev/null and b/pitch_predictor/data/example/vocal/2.wav differ
diff --git a/pitch_predictor/data/example/world/2.npy b/pitch_predictor/data/example/world/2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..87338585dfccfef4c385e851a94e15910f368924
--- /dev/null
+++ b/pitch_predictor/data/example/world/2.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41b53d64476b019714cacf063296021357980461c3c38eedce3d383a3cd3b4ad
+size 152928
diff --git a/pitch_predictor/data/preprocess_csd.py b/pitch_predictor/data/preprocess_csd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a045472c7b3841a574c8cb6a0cfb498b2659b47b
--- /dev/null
+++ b/pitch_predictor/data/preprocess_csd.py
@@ -0,0 +1,99 @@
+import os.path
+import json
+from tqdm import tqdm
+import pandas as pd
+import numpy as np
+import textgrid
+import pretty_midi
+import music21
+import librosa
+import soundfile as sf
+
+
+def piano_roll_to_pretty_midi(piano_roll, fs=100, program=0, bpm=120):
+ notes, frames = piano_roll.shape
+ pm = pretty_midi.PrettyMIDI()
+ instrument = pretty_midi.Instrument(program=program, )
+
+ # pad 1 column of zeros so we can acknowledge initial and ending events
+ piano_roll = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant')
+
+ # use changes in velocities to find note on / note off events
+ velocity_changes = np.nonzero(np.diff(piano_roll).T)
+
+ # keep track of velocities and note-on times
+ prev_velocities = np.zeros(notes, dtype=int)
+ note_on_time = np.zeros(notes)
+
+ for time, note in zip(*velocity_changes):
+ # use time + 1 because of padding above
+ velocity = piano_roll[note, time + 1]
+ time = time / fs * bpm / 120
+ # time = time / fs
+ if velocity > 0:
+ if prev_velocities[note] == 0:
+ note_on_time[note] = time
+ prev_velocities[note] = velocity
+ else:
+ pm_note = pretty_midi.Note(
+ velocity=prev_velocities[note],
+ pitch=note,
+ start=note_on_time[note],
+ end=time)
+ instrument.notes.append(pm_note)
+ prev_velocities[note] = 0
+ pm.instruments.append(instrument)
+
+ beats = np.array([0, int(pm.get_end_time()+1)])
+ pm.adjust_times(beats, beats * 120 / bpm)
+ # print(beats)
+ return pm
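+
+# Usage sketch (values assumed): rebuild a MIDI segment from a 100 fps piano roll, e.g.
+#   pm = piano_roll_to_pretty_midi(cur_roll, fs=100, bpm=120)
+#   pm.write('segment.midi')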
+
+
+f = open('CSD/English/metadata.json', encoding="utf8")
+meta = json.load(f)
+folder = 'CSD/English/'
+for wav in tqdm(os.listdir(folder+'wav')):
+ song_id = wav.replace('.wav', '')
+ midi_id = wav.replace('.wav', '.mid')
+ roll_id = wav.replace('.wav', '.npy')
+
+ wav, sr = librosa.load(folder+'wav/'+wav)
+ midi = pretty_midi.PrettyMIDI(folder+'mid/'+midi_id)
+ roll = midi.get_piano_roll()
+
+ bpm = meta[song_id]['tempo']
+
+ for i in range(int(roll.shape[1])//1000):
+ # print(i)
+ start = i*10
+ end = (i+1)*10
+
+ wav_seg = wav[round(start * sr):round(end * sr)]
+
+ os.makedirs('CSD_segements/'+song_id+'/vocal/', exist_ok=True)
+ os.makedirs('CSD_segements/' + song_id + '/roll/', exist_ok=True)
+ os.makedirs('CSD_segements/' + song_id + '/midi/', exist_ok=True)
+
+ sf.write('CSD_segements/'+song_id+'/vocal/'+str(i)+'.wav', wav_seg, samplerate=sr)
+
+ cur_roll = roll[:, round(100*start):round(100*end)]
+
+ if round((end-start)*100) != cur_roll.shape[1]:
+ print('piano-roll segment length mismatch:')
+ print(song_id)
+ print((end-start)*100)
+ print(cur_roll.shape)
+
+ # save npy rolls
+ np.save('CSD_segements/'+song_id+'/roll/'+str(i)+'.npy', cur_roll)
+
+ # save midi files
+ cur_midi = piano_roll_to_pretty_midi(cur_roll, fs=100, bpm=bpm)
+ # cur_midi.write('cache/'+song_id+str(num)+'.midi')
+ cur_midi.write('CSD_segements/'+song_id+'/midi/'+str(i)+'.midi')
+ # fctr = bpm/120
+ # score = music21.converter.Converter()
+ # score.parseFile('cache/'+song_id+str(num)+'.midi')
+ # newscore = score.stream.augmentOrDiminish(fctr)
+ # newscore.write('midi', 'segements/'+song_id+'/midi/'+str(num)+'.midi')
\ No newline at end of file
diff --git a/pitch_predictor/dataset/__pycache__/diffpitch.cpython-310.pyc b/pitch_predictor/dataset/__pycache__/diffpitch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ceb489e294fe2ed53c34162a7ea58a158e97bb54
Binary files /dev/null and b/pitch_predictor/dataset/__pycache__/diffpitch.cpython-310.pyc differ
diff --git a/pitch_predictor/dataset/diffpitch.py b/pitch_predictor/dataset/diffpitch.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc624bcea9332b1c5b059fbb0f9a045cac10fa35
--- /dev/null
+++ b/pitch_predictor/dataset/diffpitch.py
@@ -0,0 +1,97 @@
+import os
+import random
+import numpy as np
+import pandas as pd
+import librosa
+
+import torch
+import torchaudio
+from torch.utils.data import Dataset
+
+
+def algin_mapping(content, target_len):
+ # align content with mel
+ src_len = content.shape[-1]
+ target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
+ temp = torch.arange(src_len+1) * target_len / src_len
+
+ for i in range(target_len):
+ cur_idx = torch.argmin(torch.abs(temp-i))
+ target[:, i] = content[:, cur_idx]
+ return target
+
+
+def midi_to_hz(midi):
+ idx = torch.zeros(midi.shape[-1])
+ for frame in range(midi.shape[-1]):
+ midi_frame = midi[:, frame]
+ non_zero = midi_frame.nonzero()
+ if len(non_zero) != 0:
+ hz = librosa.midi_to_hz(non_zero[0])
+ idx[frame] = torch.tensor(hz)
+ return idx
+
+
+# training "average voice" encoder
+class DiffPitch(Dataset):
+ def __init__(self, data_dir, subset, frames, content='world', shift=True, log_scale=False):
+ meta = pd.read_csv(data_dir + 'meta.csv')
+ self.data_dir = data_dir
+ self.meta = meta[meta['subset'] == subset]
+ self.frames = frames
+ self.content = content
+ self.shift = shift
+ self.log_scale = log_scale
+
+ def __getitem__(self, index):
+ row = self.meta.iloc[index]
+ folder = row['folder']
+ subfolder = row['subfolder']
+ file_id = row['file_name']
+ folder = os.path.join(self.data_dir, folder)
+ folder = os.path.join(folder, str(subfolder))
+ folder = os.path.join(folder, 'vocal')
+ folder = os.path.join(folder, file_id)
+
+ content_folder = folder.replace('vocal', self.content).replace('.wav', '.npy')
+ content = torch.tensor(np.load(content_folder), dtype=torch.float32)
+ # print(content.shape)
+
+ midi_folder = folder.replace('vocal', 'roll_align').replace('.wav', '.npy')
+ midi = torch.tensor(np.load(midi_folder), dtype=torch.float32)
+ # print(midi.shape)
+ # midi = algin_mapping(midi, content.shape[-1])
+
+ f0_folder = folder.replace('vocal', 'f0').replace('.wav', '.npy')
+ f0 = torch.tensor(np.load(f0_folder), dtype=torch.float32)
+
+ max_start = max(content.shape[-1] - self.frames, 0)
+ start = random.choice(range(max_start)) if max_start > 0 else 0
+ end = min(int(start + self.frames), content.shape[-1])
+
+ out_content = torch.ones((content.shape[0], self.frames)) * np.log(1e-5)
+ out_midi = torch.zeros(self.frames)
+ out_f0 = torch.zeros(self.frames)
+
+ out_content[:, :end-start] = content[:, start:end]
+ out_midi[:end-start] = midi[start:end]
+ out_f0[:end-start] = f0[start:end]
+
+ # out_midi = midi_to_hz(out_midi)
+
+ if self.shift is True:
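+ # augmentation: transpose the score and the target F0 together by the same
+ # random amount in [-12, +12] semitones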
+ shift = np.random.choice(25, 1)[0]
+ shift = shift - 12
+
+ # midi[midi != 0] += shift
+ out_midi = out_midi*(2**(shift/12))
+ out_f0 = out_f0*(2**(shift/12))
+
+ if self.log_scale:
+ out_midi = 1127 * np.log(1 + out_midi / 700)
+ out_f0 = 1127 * np.log(1 + out_f0 / 700)
+
+ return out_content, out_midi, out_f0
+
+ def __len__(self):
+ return len(self.meta)
\ No newline at end of file
diff --git a/pitch_predictor/models/__pycache__/transformer.cpython-310.pyc b/pitch_predictor/models/__pycache__/transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0a83cd2f98f308d4fc3be4e8915f7ee4f617ecf
Binary files /dev/null and b/pitch_predictor/models/__pycache__/transformer.cpython-310.pyc differ
diff --git a/pitch_predictor/models/rnn.py b/pitch_predictor/models/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e91cb581b4a7c4b13b78517929c5459726fff6a
--- /dev/null
+++ b/pitch_predictor/models/rnn.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+
+class PitchRNN(nn.Module):
+ def __init__(self, n_mels, hidden_size):
+ super(PitchRNN, self).__init__()
+
+ self.sp_linear = nn.Sequential(nn.Conv1d(n_mels, hidden_size*2, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv1d(hidden_size*2, hidden_size, kernel_size=1),
+ nn.SiLU(),)
+
+ self.midi_linear = nn.Sequential(nn.Conv1d(1, hidden_size*2, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv1d(hidden_size*2, hidden_size, kernel_size=1),
+ nn.SiLU(),)
+
+ self.hidden_size = hidden_size
+
+ self.rnn = nn.GRU(input_size=hidden_size*2,
+ hidden_size=hidden_size,
+ num_layers=2,
+ batch_first=True,
+ bidirectional=True)
+ # self.silu = nn.SiLU()
+
+ self.linear = nn.Sequential(nn.Linear(2*hidden_size, hidden_size),
+ nn.SiLU(),
+ nn.Linear(hidden_size, 1))
+
+ def forward(self, midi, sp):
+ midi = midi.unsqueeze(1)
+ midi = self.midi_linear(midi)
+ sp = self.sp_linear(sp)
+
+ x = torch.cat([midi, sp], dim=1)
+ x = torch.transpose(x, 1, 2)
+ x, _ = self.rnn(x)
+ # x = self.silu(x)
+
+ x = self.linear(x)
+
+ return x.squeeze(-1)
+
+
+if __name__ == '__main__':
+
+ model = PitchRNN(100, 256)
+
+ x = torch.rand((4, 128))
+ t = torch.randint(0, 1000, (1, )).long()
+ sp = torch.rand((4, 100, 128))
+ midi = torch.rand((4, 128))
+
+ y = model(midi, sp)
\ No newline at end of file
diff --git a/pitch_predictor/models/transformer.py b/pitch_predictor/models/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ffedb36f077ddb4f81ad0ca3cf7e962b737cd11
--- /dev/null
+++ b/pitch_predictor/models/transformer.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn as nn
+
+import math
+from einops import rearrange
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=32, q_norm=True):
+ super(LinearAttention, self).__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = torch.nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = torch.nn.Conv1d(hidden_dim, dim, 1)
+ self.q_norm = q_norm
+
+ def forward(self, x):
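+ # linear attention: softmax over keys along the sequence (and queries over
+ # channels), then aggregate through a (d x d) context matrix, so the cost
+ # grows linearly with sequence length instead of quadratically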
+ # b, l, c = x.shape
+ x = x.permute(0, 2, 1)
+ # b, c, l = x.shape
+
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) l -> qkv b heads c l',
+ heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ if self.q_norm:
+ q = q.softmax(dim=-2)
+
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c l -> b (heads c) l',
+ heads=self.heads)
+ return self.to_out(out).permute(0, 2, 1)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads=4, layer_norm_first=True):
+ super(TransformerBlock, self).__init__()
+ dim_head = dim//n_heads
+ self.attention = LinearAttention(dim, heads=n_heads, dim_head=dim_head)
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.feed_forward = nn.Sequential(nn.Linear(dim, dim*2),
+ nn.SiLU(),
+ nn.Linear(dim*2, dim))
+
+ self.dropout1 = nn.Dropout(0.2)
+ self.dropout2 = nn.Dropout(0.2)
+
+ self.layer_norm_first = layer_norm_first
+
+ def forward(self, x):
+ nx = self.norm1(x)
+ x = x + self.dropout1(self.attention(nx))
+ nx = self.norm2(x)
+ nx = x + self.dropout2(self.feed_forward(nx))
+ # attention_out = self.attention(x)
+ # attention_residual_out = attention_out + x
+ # # print(attention_residual_out.shape)
+ # norm1_out = self.dropout1(self.norm1(attention_residual_out))
+ #
+ # feed_fwd_out = self.feed_forward(norm1_out)
+ # feed_fwd_residual_out = feed_fwd_out + norm1_out
+ # norm2_out = self.dropout2(self.norm2(feed_fwd_residual_out))
+ return nx
+
+
+class PitchFormer(nn.Module):
+ def __init__(self, n_mels, hidden_size, attn_layers=4):
+ super(PitchFormer, self).__init__()
+
+ self.sp_linear = nn.Sequential(nn.Conv1d(n_mels, hidden_size, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv1d(hidden_size, hidden_size//2, kernel_size=1)
+ )
+
+ self.midi_linear = nn.Sequential(nn.Conv1d(1, hidden_size, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv1d(hidden_size, hidden_size//2, kernel_size=1),
+ )
+
+ self.hidden_size = hidden_size
+
+ self.pos_conv = nn.Conv1d(hidden_size, hidden_size,
+ kernel_size=63,
+ padding=31,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (63 * hidden_size))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, nn.SiLU())
+
+ self.attn_block = nn.ModuleList([TransformerBlock(hidden_size, 4) for i in range(attn_layers)])
+
+ # self.silu = nn.SiLU()
+
+ self.linear = nn.Sequential(nn.Linear(hidden_size, hidden_size),
+ nn.SiLU(),
+ nn.Linear(hidden_size, 1))
+
+ def forward(self, midi, sp):
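+ # sp: (B, n_mels, T) spectral features; midi: (B, T) score pitch per frame
+ # (Hz, or mel-scaled if the dataset's log_scale is enabled);
+ # returns a (B, T) predicted F0 curve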
+ midi = midi.unsqueeze(1)
+ midi = self.midi_linear(midi)
+ sp = self.sp_linear(sp)
+
+ x = torch.cat([midi, sp], dim=1)
+
+ # position encoding
+ x_conv = self.pos_conv(x)
+ x = x + x_conv
+
+ # x = self.silu(x)
+ x = x.permute(0, 2, 1)
+ for layer in self.attn_block:
+ x = layer(x)
+
+ x = self.linear(x)
+
+ return x.squeeze(-1)
+
+
+if __name__ == '__main__':
+
+ model = PitchFormer(100, 256)
+
+ x = torch.rand((4, 64))
+ sp = torch.rand((4, 100, 64))
+ midi = torch.rand((4, 64))
+
+ y = model(midi, sp)
\ No newline at end of file
diff --git a/pitch_predictor/train_transformer.py b/pitch_predictor/train_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fd484402dc1562e68e9189f0c1c06aa6a141eda
--- /dev/null
+++ b/pitch_predictor/train_transformer.py
@@ -0,0 +1,246 @@
+import os, json, argparse, yaml
+import numpy as np
+from tqdm import tqdm
+import librosa
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.cuda.amp import autocast, GradScaler
+
+from dataset.diffpitch import DiffPitch
+from models.transformer import PitchFormer
+from utils import minmax_norm_diff, reverse_minmax_norm_diff, save_curve_plot
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-config', type=str, default='config/Pitchformer.yaml')
+
+parser.add_argument('-seed', type=int, default=9811)
+parser.add_argument('-amp', type=bool, default=False)
+parser.add_argument('-compile', type=bool, default=False)
+
+parser.add_argument('-data_dir', type=str, default='data/')
+parser.add_argument('-content_dir', type=str, default='world')
+
+parser.add_argument('-train_frames', type=int, default=256)
+parser.add_argument('-test_frames', type=int, default=256)
+parser.add_argument('-batch_size', type=int, default=32)
+parser.add_argument('-test_size', type=int, default=1)
+parser.add_argument('-num_workers', type=int, default=4)
+parser.add_argument('-lr', type=float, default=5e-5)
+parser.add_argument('-weight_decay', type=float, default=1e-6)
+
+parser.add_argument('-epochs', type=int, default=1)
+parser.add_argument('-save_every', type=int, default=20)
+parser.add_argument('-log_step', type=int, default=100)
+parser.add_argument('-log_dir', type=str, default='logs_transformer_pitch')
+parser.add_argument('-ckpt_dir', type=str, default='ckpt_transformer_pitch')
+
+args = parser.parse_args()
+args.save_ori = True
+config = yaml.load(open(args.config), Loader=yaml.FullLoader)
+mel_cfg = config['logmel']
+ddpm_cfg = config['ddpm']
+# unet_cfg = config['unet']
+
+
+def RMSE(gen_f0, gt_f0):
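+ # log-F0 RMSE over frames voiced in both prediction and target
+ # (log2 scale, i.e. measured in octaves); returns 0 if no such frame exists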
+ # Get voiced part
+ gt_f0 = gt_f0[0]
+ gen_f0 = gen_f0[0]
+
+ nonzero_idxs = np.where((gen_f0 != 0) & (gt_f0 != 0))[0]
+ gen_f0_voiced = np.log2(gen_f0[nonzero_idxs])
+ gt_f0_voiced = np.log2(gt_f0[nonzero_idxs])
+ # log F0 RMSE
+ if len(gen_f0_voiced) != 0:
+ f0_rmse = np.sqrt(np.mean((gen_f0_voiced - gt_f0_voiced) ** 2))
+ else:
+ f0_rmse = 0
+ return f0_rmse
+
+
+if __name__ == "__main__":
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ if torch.cuda.is_available():
+ args.device = 'cuda'
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if torch.backends.cudnn.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+ else:
+ args.device = 'cpu'
+
+ if os.path.exists(args.log_dir) is False:
+ os.makedirs(args.log_dir)
+
+ if os.path.exists(args.ckpt_dir) is False:
+ os.makedirs(args.ckpt_dir)
+
+ print('Initializing data loaders...')
+ trainset = DiffPitch('data/', 'train', args.train_frames, shift=True)
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=args.num_workers,
+ drop_last=True, shuffle=True)
+
+ val_set = DiffPitch('data/', 'val', args.test_frames, shift=True)
+ val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
+
+ test_set = DiffPitch('data/', 'test', args.test_frames, shift=True)
+ test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
+
+ real_set = DiffPitch('data/', 'real', args.test_frames, shift=False)
+ real_loader = DataLoader(real_set, batch_size=1, shuffle=False)
+
+ print('Initializing and loading models...')
+ model = PitchFormer(mel_cfg['n_mels'], 512).to(args.device)
+ # resume from a previous checkpoint if one is available
+ if os.path.exists('ckpt_transformer_pitch/transformer_pitch_460.pt'):
+ ckpt = torch.load('ckpt_transformer_pitch/transformer_pitch_460.pt', map_location=args.device)
+ model.load_state_dict(ckpt)
+
+ print('Initializing optimizers...')
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ scaler = GradScaler()
+
+ if args.compile:
+ model = torch.compile(model)
+
+ print('Start training.')
+ global_step = 0
+ for epoch in range(1, args.epochs + 1):
+ print(f'Epoch: {epoch} [iteration: {global_step}]')
+ model.train()
+ losses = []
+
+ for step, batch in enumerate(tqdm(train_loader)):
+ optimizer.zero_grad()
+ mel, midi, f0 = batch
+ mel = mel.to(args.device)
+ midi = midi.to(args.device)
+ f0 = f0.to(args.device)
+
+ if args.amp:
+ with autocast():
+ f0_pred = model(sp=mel, midi=midi)
+ loss = F.mse_loss(f0_pred, f0)
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ f0_pred = model(sp=mel, midi=midi)
+ loss = F.l1_loss(f0_pred, f0)
+ # Backward propagation
+ loss.backward()
+ optimizer.step()
+
+ losses.append(loss.item())
+ global_step += 1
+
+ if global_step % args.log_step == 0:
+ losses = np.asarray(losses)
+ # msg = 'Epoch %d: loss = %.4f\n' % (epoch, np.mean(losses))
+ msg = '\nEpoch: [{}][{}]\t' \
+ 'Batch: [{}][{}]\tLoss: {:.6f}\n'.format(epoch,
+ args.epochs,
+ step+1,
+ len(train_loader),
+ np.mean(losses))
+ with open(f'{args.log_dir}/train_dec.log', 'a') as f:
+ f.write(msg)
+ losses = []
+
+ if epoch % args.save_every > 0:
+ continue
+
+ print('Saving model...\n')
+ ckpt = model.state_dict()
+ torch.save(ckpt, f=f"{args.ckpt_dir}/transformer_pitch_{epoch}.pt")
+
+ print('Inference...\n')
+ model.eval()
+ with torch.no_grad():
+ val_loss = []
+ val_rmse = []
+ for i, batch in enumerate(val_loader):
+ # optimizer.zero_grad()
+ mel, midi, f0 = batch
+ mel = mel.to(args.device)
+ midi = midi.to(args.device)
+ f0 = f0.to(args.device)
+
+ f0_pred = model(sp=mel, midi=midi)
+
+ # save pred
+ f0_pred[f0_pred < librosa.note_to_hz('C2')] = 0
+ f0_pred[f0_pred > librosa.note_to_hz('C6')] = librosa.note_to_hz('C6')
+
+ val_loss.append(F.l1_loss(f0_pred, f0).item())
+ val_rmse.append(RMSE(f0_pred.cpu().numpy(), f0.cpu().numpy()))
+
+ if i <= 4:
+ save_path = f'{args.log_dir}/pic/{i}/{epoch}_val.png'
+ if os.path.exists(os.path.dirname(save_path)) is False:
+ os.makedirs(os.path.dirname(save_path))
+ save_curve_plot(f0_pred.cpu().squeeze(), midi.cpu().squeeze(), f0.cpu().squeeze(), save_path)
+ # else:
+ # break
+
+ msg = '\nEpoch: [{}][{}]\tLoss: {:.6f}\tRMSE:{:.6f}\n'.\
+ format(epoch, args.epochs, np.mean(val_loss), np.mean(val_rmse))
+ with open(f'{args.log_dir}/eval_dec.log', 'a') as f:
+ f.write(msg)
+
+ test_loss = []
+ test_rmse = []
+ for i, batch in enumerate(test_loader):
+ # optimizer.zero_grad()
+ mel, midi, f0 = batch
+ mel = mel.to(args.device)
+ midi = midi.to(args.device)
+ f0 = f0.to(args.device)
+
+ f0_pred = model(sp=mel, midi=midi)
+
+ # save pred
+ f0_pred[f0_pred < librosa.note_to_hz('C2')] = 0
+ f0_pred[f0_pred > librosa.note_to_hz('C6')] = librosa.note_to_hz('C6')
+
+ test_loss.append(F.l1_loss(f0_pred, f0).item())
+ test_rmse.append(RMSE(f0_pred.cpu().numpy(), f0.cpu().numpy()))
+
+ if i <= 4:
+ save_path = f'{args.log_dir}/pic/{i}/{epoch}_test.png'
+ if os.path.exists(os.path.dirname(save_path)) is False:
+ os.makedirs(os.path.dirname(save_path))
+ save_curve_plot(f0_pred.cpu().squeeze(), midi.cpu().squeeze(), f0.cpu().squeeze(), save_path)
+
+ msg = '\nEpoch: [{}][{}]\tLoss: {:.6f}\tRMSE:{:.6f}\n'. \
+ format(epoch, args.epochs, np.mean(test_loss), np.mean(test_rmse))
+ with open(f'{args.log_dir}/test_dec.log', 'a') as f:
+ f.write(msg)
+
+ for i, batch in enumerate(real_loader):
+ # optimizer.zero_grad()
+ mel, midi, f0 = batch
+ mel = mel.to(args.device)
+ midi = midi.to(args.device)
+ f0 = f0.to(args.device)
+
+ f0_pred = model(sp=mel, midi=midi)
+ f0_pred[f0 == 0] = 0
+
+ # save pred
+ f0_pred[f0_pred < librosa.note_to_hz('C2')] = 0
+ f0_pred[f0_pred > librosa.note_to_hz('C6')] = librosa.note_to_hz('C6')
+
+ save_path = f'{args.log_dir}/pic/{i}/{epoch}_real.png'
+ if os.path.exists(os.path.dirname(save_path)) is False:
+ os.makedirs(os.path.dirname(save_path))
+ save_curve_plot(f0_pred.cpu().squeeze(), midi.cpu().squeeze(), f0.cpu().squeeze(), save_path)
+
+
+
+
diff --git a/pitch_predictor/utils.py b/pitch_predictor/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f39c5304f5463853870f19e5925c94655af4fc61
--- /dev/null
+++ b/pitch_predictor/utils.py
@@ -0,0 +1,46 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.io import wavfile
+import torch
+import librosa
+from torch.nn import functional as F
+
+
+def save_curve_plot(pred, midi, gt, savepath):
+ plt.style.use('default')
+ fig, ax = plt.subplots(figsize=(12, 3))
+
+ pred[pred == 0] = np.nan
+ midi[midi == 0] = np.nan
+ gt[gt == 0] = np.nan
+
+ # im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
+ ax.plot(range(len(pred)), pred, color='tab:green', label='pred')
+ ax.plot(range(len(midi)), midi, color='tab:blue', label='midi')
+ ax.plot(range(len(gt)), gt, color='grey', label='gt')
+ # plt.colorbar(im, ax=ax)
+ plt.tight_layout()
+ fig.canvas.draw()
+ plt.legend()
+ plt.savefig(savepath)
+ plt.close()
+#
+#
+# def save_audio(file_path, sampling_rate, audio):
+# audio = np.clip(audio.detach().cpu().squeeze().numpy(), -0.999, 0.999)
+# wavfile.write(file_path, sampling_rate, (audio * 32767).astype("int16"))
+
+
+def minmax_norm_diff(tensor: torch.Tensor, vmax: float = librosa.note_to_hz('C6'),
+ vmin: float = 0) -> torch.Tensor:
+ tensor = torch.clip(tensor, vmin, vmax)
+ tensor = 2 * (tensor - vmin) / (vmax - vmin) - 1
+ return tensor
+
+
+def reverse_minmax_norm_diff(tensor: torch.Tensor, vmax: float = librosa.note_to_hz('C6'),
+ vmin: float = 0) -> torch.Tensor:
+ tensor = torch.clip(tensor, -1.0, 1.0)
+ tensor = (tensor + 1) / 2
+ tensor = tensor * (vmax - vmin) + vmin
+ return tensor
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..976247a9d4f771489046246c2f07767b41a84721
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+diffusers
+einops
+fastdtw
+librosa
+matplotlib
+music21
+numpy
+pandas
+pretty_midi
+pysptk
+pyworld
+scipy
+soundfile
+tgt
+torch
+torchaudio
+tqdm
diff --git a/score_based_apc.py b/score_based_apc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9259d641045d7bfdb6d5aede00f6cebcc28f477
--- /dev/null
+++ b/score_based_apc.py
@@ -0,0 +1,159 @@
+import os.path
+
+import numpy as np
+import pandas as pd
+import torch
+import yaml
+import librosa
+import soundfile as sf
+from tqdm import tqdm
+
+from diffusers import DDIMScheduler
+from pitch_controller.models.unet import UNetPitcher
+from pitch_controller.utils import minmax_norm_diff, reverse_minmax_norm_diff
+from pitch_controller.modules.BigVGAN.inference import load_model
+from utils import get_mel, get_world_mel, get_f0, f0_to_coarse, show_plot, get_matched_f0, log_f0
+from pitch_predictor.models.transformer import PitchFormer
+import pretty_midi
+
+
+def prepare_midi_wav(wav_id, midi_id, sr=24000):
+ midi = pretty_midi.PrettyMIDI(midi_id)
+ roll = midi.get_piano_roll()
+ roll = np.pad(roll, ((0, 0), (0, 1000)), constant_values=0)
+ roll[roll > 0] = 100
+
+ onset = midi.get_onsets()
+ before_onset = list(np.round(onset * 100 - 1).astype(int))
+ roll[:, before_onset] = 0
+
+ wav, sr = librosa.load(wav_id, sr=sr)
+
+ start = 0
+ end = round(100 * len(wav) / sr) / 100
+ # save audio
+ wav_seg = wav[round(start * sr):round(end * sr)]
+ cur_roll = roll[:, round(100 * start):round(100 * end)]
+ return wav_seg, cur_roll
+
+
+def algin_mapping(content, target_len):
+ # align content with mel
+ src_len = content.shape[-1]
+ target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
+ temp = torch.arange(src_len+1) * target_len / src_len
+
+ for i in range(target_len):
+ cur_idx = torch.argmin(torch.abs(temp-i))
+ target[:, i] = content[:, cur_idx]
+ return target
+
+
+def midi_to_hz(midi):
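+ # per frame, take the lowest active note of the piano roll and convert it
+ # to Hz; frames with no active note remain 0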
+ idx = torch.zeros(midi.shape[-1])
+ for frame in range(midi.shape[-1]):
+ midi_frame = midi[:, frame]
+ non_zero = midi_frame.nonzero()
+ if len(non_zero) != 0:
+ hz = librosa.midi_to_hz(non_zero[0])
+ idx[frame] = torch.tensor(hz)
+ return idx
+
+
+@torch.no_grad()
+def score_pitcher(source, pitch_ref, model, hifigan, pitcher, steps=50, shift_semi=0, mask_with_source=False):
+ wav, midi = prepare_midi_wav(source, pitch_ref, sr=sr)
+
+ source_mel = get_world_mel(None, sr=sr, wav=wav)
+
+ midi = torch.tensor(midi, dtype=torch.float32)
+ midi = algin_mapping(midi, source_mel.shape[-1])
+ midi = midi_to_hz(midi)
+
+ f0_ori = np.nan_to_num(get_f0(source))
+
+ source_mel = torch.from_numpy(source_mel).float().unsqueeze(0).to(device)
+ f0_ori = torch.from_numpy(f0_ori).float().unsqueeze(0).to(device)
+ midi = midi.unsqueeze(0).to(device)
+
+ f0_pred = pitcher(sp=source_mel, midi=midi)
+ if mask_with_source:
+ # mask unvoiced frames based on original pitch estimation
+ f0_pred[f0_ori == 0] = 0
+ f0_pred = f0_pred.cpu().numpy()[0]
+ # limit range
+ f0_pred[f0_pred < librosa.note_to_hz('C2')] = 0
+ f0_pred[f0_pred > librosa.note_to_hz('C6')] = librosa.note_to_hz('C6')
+
+ f0_pred = f0_pred * (2 ** (shift_semi / 12))
+
+ f0_pred = log_f0(f0_pred, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
+ f0_pred = torch.from_numpy(f0_pred).float().unsqueeze(0).to(device)
+
+ noise_scheduler = DDIMScheduler(num_train_timesteps=1000)
+ generator = torch.Generator(device=device).manual_seed(2024)
+
+ noise_scheduler.set_timesteps(steps)
+ noise = torch.randn(source_mel.shape, generator=generator, device=device)
+ pred = noise
+ source_x = minmax_norm_diff(source_mel, vmax=max_mel, vmin=min_mel)
+
+ for t in tqdm(noise_scheduler.timesteps):
+ pred = noise_scheduler.scale_model_input(pred, t)
+ model_output = model(x=pred, mean=source_x, f0=f0_pred, t=t, ref=None, embed=None)
+ pred = noise_scheduler.step(model_output=model_output,
+ timestep=t,
+ sample=pred,
+ eta=1, generator=generator).prev_sample
+
+ pred = reverse_minmax_norm_diff(pred, vmax=max_mel, vmin=min_mel)
+
+ pred_audio = hifigan(pred)
+ pred_audio = pred_audio.cpu().squeeze().clamp(-1, 1)
+
+ return pred_audio
+
+
+if __name__ == '__main__':
+ min_mel = np.log(1e-5)
+ max_mel = 2.5
+ sr = 24000
+
+ use_gpu = torch.cuda.is_available()
+ device = 'cuda' if use_gpu else 'cpu'
+
+ # load diffusion model
+ config = yaml.load(open('pitch_controller/config/DiffWorld_24k.yaml'), Loader=yaml.FullLoader)
+ mel_cfg = config['logmel']
+ ddpm_cfg = config['ddpm']
+ unet_cfg = config['unet']
+ model = UNetPitcher(**unet_cfg)
+ unet_path = 'ckpts/world_fixed_40.pt'
+
+ state_dict = torch.load(unet_path, map_location=device)
+ for key in list(state_dict.keys()):
+ state_dict[key.replace('_orig_mod.', '')] = state_dict.pop(key)
+ model.load_state_dict(state_dict)
+ if use_gpu:
+ model.cuda()
+ model.eval()
+
+ # load vocoder
+ hifi_path = 'ckpts/bigvgan_24khz_100band/g_05000000.pt'
+ hifigan, cfg = load_model(hifi_path, device=device)
+ hifigan.eval()
+
+ # load pitch predictor
+ pitcher = PitchFormer(100, 512).to(device)
+ ckpt = torch.load('ckpts/ckpt_transformer_pitch/transformer_pitch_360.pt', map_location=device)
+ pitcher.load_state_dict(ckpt)
+ pitcher.eval()
+
+ pred_audio = score_pitcher('examples/score_vocal.wav', 'examples/score_midi.midi', model, hifigan, pitcher, steps=50)
+ sf.write('output_score.wav', pred_audio, samplerate=sr)
+
+
+
+
diff --git a/template_based_apc.py b/template_based_apc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ecdf7a6d92d1e763656e003e4718ea2e5853d3d
--- /dev/null
+++ b/template_based_apc.py
@@ -0,0 +1,89 @@
+import os.path
+
+import numpy as np
+import pandas as pd
+import torch
+import yaml
+import librosa
+import soundfile as sf
+from tqdm import tqdm
+
+from diffusers import DDIMScheduler
+from pitch_controller.models.unet import UNetPitcher
+from pitch_controller.utils import minmax_norm_diff, reverse_minmax_norm_diff
+from pitch_controller.modules.BigVGAN.inference import load_model
+from utils import get_mel, get_world_mel, get_f0, f0_to_coarse, show_plot, get_matched_f0, log_f0
+
+
+@torch.no_grad()
+def template_pitcher(source, pitch_ref, model, hifigan, steps=50, shift_semi=0):
+
+ source_mel = get_world_mel(source, sr=sr)
+
+ f0_ref = get_matched_f0(source, pitch_ref, 'world')
+ f0_ref = f0_ref * 2 ** (shift_semi / 12)
+
+ f0_ref = log_f0(f0_ref, {'f0_bin': 345,
+ 'f0_min': librosa.note_to_hz('C2'),
+ 'f0_max': librosa.note_to_hz('C#6')})
+
+ source_mel = torch.from_numpy(source_mel).float().unsqueeze(0).to(device)
+ f0_ref = torch.from_numpy(f0_ref).float().unsqueeze(0).to(device)
+
+ noise_scheduler = DDIMScheduler(num_train_timesteps=1000)
+ generator = torch.Generator(device=device).manual_seed(2024)
+
+ noise_scheduler.set_timesteps(steps)
+ noise = torch.randn(source_mel.shape, generator=generator, device=device)
+ pred = noise
+ source_x = minmax_norm_diff(source_mel, vmax=max_mel, vmin=min_mel)
+
+ for t in tqdm(noise_scheduler.timesteps):
+ pred = noise_scheduler.scale_model_input(pred, t)
+ model_output = model(x=pred, mean=source_x, f0=f0_ref, t=t, ref=None, embed=None)
+ pred = noise_scheduler.step(model_output=model_output,
+ timestep=t,
+ sample=pred,
+ eta=1, generator=generator).prev_sample
+
+ pred = reverse_minmax_norm_diff(pred, vmax=max_mel, vmin=min_mel)
+
+ pred_audio = hifigan(pred)
+ pred_audio = pred_audio.cpu().squeeze().clamp(-1, 1)
+
+ return pred_audio
+
+
+if __name__ == '__main__':
+ min_mel = np.log(1e-5)
+ max_mel = 2.5
+ sr = 24000
+
+ use_gpu = torch.cuda.is_available()
+ device = 'cuda' if use_gpu else 'cpu'
+
+ # load diffusion model
+ config = yaml.load(open('pitch_controller/config/DiffWorld_24k.yaml'), Loader=yaml.FullLoader)
+ mel_cfg = config['logmel']
+ ddpm_cfg = config['ddpm']
+ unet_cfg = config['unet']
+ model = UNetPitcher(**unet_cfg)
+ unet_path = 'ckpts/world_fixed_40.pt'
+
+ state_dict = torch.load(unet_path, map_location=device)
+ for key in list(state_dict.keys()):
+ state_dict[key.replace('_orig_mod.', '')] = state_dict.pop(key)
+ model.load_state_dict(state_dict)
+ if use_gpu:
+ model.cuda()
+ model.eval()
+
+ # load vocoder
+ hifi_path = 'ckpts/bigvgan_24khz_100band/g_05000000.pt'
+ hifigan, cfg = load_model(hifi_path, device=device)
+ hifigan.eval()
+
+ pred_audio = template_pitcher('examples/off-key.wav', 'examples/reference.wav', model, hifigan, steps=50, shift_semi=0)
+ sf.write('output_template.wav', pred_audio, samplerate=sr)
+
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b1a669482a9314876a5c80a78127bc786b992da
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,206 @@
+import numpy as np
+import torch
+import librosa
+from librosa.core import load
+import matplotlib.pyplot as plt
+import pysptk
+import pyworld as pw
+from fastdtw import fastdtw
+from scipy import spatial
+
+from librosa.filters import mel as librosa_mel_fn
+mel_basis = librosa_mel_fn(sr=24000, n_fft=1024, n_mels=100, fmin=0, fmax=12000)
+
+
+def _get_best_mcep_params(fs):
+ if fs == 16000:
+ return 23, 0.42
+ elif fs == 22050:
+ return 34, 0.45
+ elif fs == 24000:
+ return 34, 0.46
+ elif fs == 44100:
+ return 39, 0.53
+ elif fs == 48000:
+ return 39, 0.55
+ else:
+ raise ValueError(f"Not found the setting for {fs}.")
+
+
+def get_mel(wav_path):
+ wav, _ = load(wav_path, sr=24000)
+ wav = wav[:(wav.shape[0] // 256)*256]
+ wav = np.pad(wav, 384, mode='reflect')
+ stft = librosa.core.stft(wav, n_fft=1024, hop_length=256, win_length=1024, window='hann', center=False)
+ stftm = np.sqrt(np.real(stft) ** 2 + np.imag(stft) ** 2 + (1e-9))
+ mel_spectrogram = np.matmul(mel_basis, stftm)
+ if mel_spectrogram.shape[-1] % 8 != 0:
+ mel_spectrogram = np.pad(mel_spectrogram, ((0, 0), (0, 8 - mel_spectrogram.shape[-1] % 8)), 'minimum')
+
+ log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-5, a_max=None))
+ return log_mel_spectrogram
+
+
+def get_world_mel(wav_path=None, sr=24000, wav=None):
+ if wav_path is not None:
+ wav, _ = librosa.load(wav_path, sr=24000)
+ wav = (wav * 32767).astype(np.int16)
+ wav = (wav / 32767).astype(np.float64)
+ # wav = wav.astype(np.float64)
+ wav = wav[:(wav.shape[0] // 256) * 256]
+
+ # _f0, t = pw.dio(wav, sr, frame_period=256/sr*1000)
+ _f0, t = pw.dio(wav, sr)
+ f0 = pw.stonemask(wav, _f0, t, sr)
+ sp = pw.cheaptrick(wav, f0, t, sr)
+ ap = pw.d4c(wav, f0, t, sr)
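+ # re-synthesize with F0 forced to 0 so the waveform keeps the spectral
+ # envelope / aperiodicity but carries no pitch contour; its log-mel serves
+ # as the pitch-independent "content" input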
+ wav_hat = pw.synthesize(f0 * 0, sp, ap, sr)
+ # wav_hat = pw.synthesize(f0 * 0, sp, ap, sr, frame_period=256/sr*1000)
+
+ # pyworld output does not pad left
+ wav_hat = wav_hat[:len(wav)]
+ # wav_hat = wav_hat[256//2: len(wav)+256//2]
+ assert len(wav_hat) == len(wav)
+ wav = wav_hat.astype(np.float32)
+ wav = np.pad(wav, 384, mode='reflect')
+ stft = librosa.core.stft(wav, n_fft=1024, hop_length=256, win_length=1024, window='hann', center=False)
+ stftm = np.sqrt(np.real(stft) ** 2 + np.imag(stft) ** 2 + (1e-9))
+ mel_spectrogram = np.matmul(mel_basis, stftm)
+ if mel_spectrogram.shape[-1] % 8 != 0:
+ mel_spectrogram = np.pad(mel_spectrogram, ((0, 0), (0, 8 - mel_spectrogram.shape[-1] % 8)), 'minimum')
+
+ log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-5, a_max=None))
+ return log_mel_spectrogram
+
+
+def get_f0(wav_path, method='pyin', padding=True):
+ if method == 'pyin':
+ wav, sr = load(wav_path, sr=24000)
+ wav = wav[:(wav.shape[0] // 256) * 256]
+ wav = np.pad(wav, 384, mode='reflect')
+ f0, _, _ = librosa.pyin(wav, frame_length=1024, hop_length=256, center=False, sr=24000,
+ fmin=librosa.note_to_hz('C2'),
+ fmax=librosa.note_to_hz('C6'), fill_na=0)
+ elif method == 'world':
+ wav, sr = librosa.load(wav_path, sr=24000)
+ wav = (wav * 32767).astype(np.int16)
+ wav = (wav / 32767).astype(np.float64)
+ _f0, t = pw.dio(wav, fs=24000, frame_period=256/sr*1000,
+ f0_floor=librosa.note_to_hz('C2'),
+ f0_ceil=librosa.note_to_hz('C6'))
+ f0 = pw.stonemask(wav, _f0, t, sr)
+ f0 = f0[:-1]
+
+ if padding is True:
+ if f0.shape[-1] % 8 !=0:
+ f0 = np.pad(f0, ((0, 8-f0.shape[-1] % 8)), 'constant', constant_values=0)
+
+ return f0
+
+
+def get_mcep(x, n_fft=1024, n_shift=256, sr=24000):
+ x, sr = load(x, sr=24000)
+ n_frame = (x.shape[0] // 256)
+ x = np.pad(x, 384, mode='reflect')
+ # n_frame = (len(x) - n_fft) // n_shift + 1
+ win = pysptk.sptk.hamming(n_fft)
+ mcep_dim, mcep_alpha = _get_best_mcep_params(sr)
+ mcep = [pysptk.mcep(x[n_shift * i: n_shift * i + n_fft] * win,
+ mcep_dim, mcep_alpha,
+ eps=1e-6, etype=1,)
+ for i in range(n_frame)
+ ]
+ mcep = np.stack(mcep)
+ return mcep
+
+
+def get_matched_f0(x, y, method='world', n_fft=1024, n_shift=256):
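+ # warp the reference F0 (from y) onto the source's (x) frame timeline:
+ # DTW over mel-cepstra gives a frame-to-frame path, and each source frame
+ # takes the F0 of its matched reference frame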
+ # f0_x = get_f0(x, method='pyin', padding=False)
+ f0_y = get_f0(y, method=method, padding=False)
+ # print(f0_y.max())
+ # print(f0_y.min())
+
+ mcep_x = get_mcep(x, n_fft=n_fft, n_shift=n_shift)
+ mcep_y = get_mcep(y, n_fft=n_fft, n_shift=n_shift)
+
+ _, path = fastdtw(mcep_x, mcep_y, dist=spatial.distance.euclidean)
+ twf = np.array(path).T
+ # f0_x = gen_mcep[twf[0]]
+ nearest = []
+ for i in range(len(f0_y)):
+ idx = np.argmax(1 * twf[0] == i)
+ nearest.append(twf[1][idx])
+
+ f0_y = f0_y[nearest]
+
+ # f0_y = f0_y.astype(np.float32)
+
+ if f0_y.shape[-1] % 8 != 0:
+ f0_y = np.pad(f0_y, ((0, 8 - f0_y.shape[-1] % 8)), 'constant', constant_values=0)
+
+ return f0_y
+
+
+def f0_to_coarse(f0, hparams):
+
+ f0_bin = hparams['f0_bin']
+ f0_max = hparams['f0_max']
+ f0_min = hparams['f0_min']
+ is_torch = isinstance(f0, torch.Tensor)
+ # to mel scale
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+
+ unvoiced = (f0_mel == 0)
+
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+
+ f0_mel[unvoiced] = 0
+
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= (f0_bin - 1) and f0_coarse.min() >= 0, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
+
+
+def log_f0(f0, hparams):
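+ # quantize F0 on a semitone scale: voiced frames map to bins 1..f0_bin-1,
+ # unvoiced frames stay 0; with f0_min=C2, f0_max=C#6 (49 semitones) and
+ # f0_bin=345 this is roughly 7 bins per semitone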
+ f0_bin = hparams['f0_bin']
+ f0_max = hparams['f0_max']
+ f0_min = hparams['f0_min']
+
+ f0_mel = np.zeros_like(f0)
+ f0_mel[f0 != 0] = 12*np.log2(f0[f0 != 0]/f0_min) + 1
+ f0_mel_min = 12*np.log2(f0_min/f0_min) + 1
+ f0_mel_max = 12*np.log2(f0_max/f0_min) + 1
+
+ unvoiced = (f0_mel == 0)
+
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+
+ f0_mel[unvoiced] = 0
+
+ f0_coarse = np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= (f0_bin-1) and f0_coarse.min() >= 0, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
+
+
+def show_plot(tensor):
+ tensor = tensor.squeeze().cpu()
+ # plt.style.use('default')
+ fig, ax = plt.subplots(figsize=(12, 3))
+ im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
+ plt.colorbar(im, ax=ax)
+ plt.tight_layout()
+ fig.canvas.draw()
+ plt.show()
+
+
+if __name__ == '__main__':
+ mel = get_mel('target.wav')
+ f0 = get_f0('target.wav')
\ No newline at end of file