Spaces: Running on Zero
Serhiy Stetskovych committed
Commit 78e32cc
0 Parent(s)
Initial code

- .gitattributes +35 -0
- .gitignore +4 -0
- README.md +10 -0
- app.py +142 -0
- configs/apollo.yaml +106 -0
- inference.py +150 -0
- look2hear/__init__.py +0 -0
- look2hear/datas/__init__.py +11 -0
- look2hear/datas/musdb_moisesdb_datamodule.py +215 -0
- look2hear/discriminators/__init__.py +47 -0
- look2hear/discriminators/frequencydis.py +81 -0
- look2hear/losses/__init__.py +14 -0
- look2hear/losses/gan_losses.py +58 -0
- look2hear/losses/matrix.py +46 -0
- look2hear/metrics/__init__.py +9 -0
- look2hear/metrics/wrapper.py +86 -0
- look2hear/models/__init__.py +49 -0
- look2hear/models/apollo.py +303 -0
- look2hear/models/base_model.py +96 -0
- look2hear/system/__init__.py +17 -0
- look2hear/system/audio_litmodule.py +245 -0
- look2hear/system/optimizers.py +113 -0
- look2hear/system/schedulers.py +129 -0
- look2hear/utils/__init__.py +53 -0
- look2hear/utils/complex_utils.py +191 -0
- look2hear/utils/get_layer_from_string.py +43 -0
- look2hear/utils/inversible_interface.py +13 -0
- look2hear/utils/lightning_utils.py +110 -0
- look2hear/utils/nets_utils.py +503 -0
- look2hear/utils/parser_utils.py +178 -0
- look2hear/utils/pylogger.py +54 -0
- look2hear/utils/separator.py +138 -0
- look2hear/utils/stft.py +797 -0
- look2hear/utils/torch_utils.py +49 -0
- requirements.txt +11 -0
- weights/apollo.bin +3 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
*.pyc
__pycache__
.venv
.DS_Store
README.md
ADDED
@@ -0,0 +1,10 @@
---
title: Apollo
emoji: 💻
colorFrom: green
colorTo: purple
sdk: gradio
sdk_version: 5.5.0
app_file: app.py
pinned: false
---
app.py
ADDED
@@ -0,0 +1,142 @@
import os
import torchaudio
import torch
import numpy as np
import gradio as gr
import spaces  # needed for the @spaces.GPU (ZeroGPU) decorator used below
import yaml
import librosa
from tqdm import tqdm  # the bare module is not callable; import the tqdm class

import look2hear.models
from ml_collections import ConfigDict

def load_audio(file_path):
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate


def get_config(config_path):
    with open(config_path) as f:
        #config = OmegaConf.load(config_path)
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config


def _getWindowingArray(window_size, fade_size):
    # IMPORTANT NOTE:
    # no fades here in the end, only removing the failed ending of the chunk
    fadein = torch.linspace(1, 1, fade_size)
    fadeout = torch.linspace(0, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window


description = f'''
texts
'''


apollo_config = get_config('configs/apollo.yaml')
apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).cuda()

models = [
    ('MP3 restore', apollo_model)
]

@spaces.GPU
def enchance(model, audio):
    test_data, samplerate = load_audio(audio)
    C = 10 * samplerate  # chunk_size seconds to samples
    N = 2
    step = C // N
    fade_size = 3 * 44100  # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        chunk = part.unsqueeze(0).cuda()
        with torch.no_grad():
            out = model(chunk).squeeze(0).squeeze(0).cpu()

        window = windowingArray
        if i == 0:  # First audio chunk, no fadein
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fadeout
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]

    return samplerate, final_output.T


if __name__ == "__main__":
    i = gr.Interface(
        fn=enchance,
        description=description,
        inputs=[
            gr.Dropdown(label="Model", choices=models, value=models[0]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Enhancer',
    )
    i.queue(max_size=20, default_concurrency_limit=4)
    i.launch(share=False, server_name="0.0.0.0")
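For orientation, here is a minimal sketch (not part of the commit, with illustrative values) of the overlap-add bookkeeping that `enchance` performs above: chunks of C samples are taken every step = C // N samples, each weighted by the fade window, and a per-sample weight sum is tracked so the final division undoes the overlap.

import torch

sr = 44100
C, N = 10 * sr, 2            # 10-second chunks, hop of half a chunk
step, fade = C // N, 3 * sr

window = torch.ones(C)
window[-fade:] = 0           # same shape as _getWindowingArray(C, fade)

total = 30 * sr              # pretend we have 30 s of mono audio
counter = torch.zeros(total)
i = 0
while i < total:
    length = min(C, total - i)
    counter[i:i + length] += window[:length]
    i += step

# counter records how many chunks contributed at each sample (here 1 or 2),
# so result / counter averages overlapping predictions without changing scale
print(counter.min().item(), counter.max().item())  # 1.0 2.0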
configs/apollo.yaml
ADDED
@@ -0,0 +1,106 @@
exp:
  dir: ./Exps
  name: Apollo

# seed: 614020

datas:
  _target_: look2hear.datas.MusdbMoisesdbDataModule
  train_dir: ./hdf5_datas
  eval_dir: ./eval
  codec_type: mp3
  codec_options:
    bitrate: random
    compression: random
    complexity: random
    vbr: random
  sr: 44100
  segments: 3
  num_stems: 8
  snr_range: [-10, 10]
  num_samples: 40000
  batch_size: 1
  num_workers: 8

model:
  sr: 44100
  win: 20 # ms
  feature_dim: 256
  layer: 6

discriminator:
  _target_: look2hear.discriminators.frequencydis.MultiFrequencyDiscriminator
  nch: 2
  window: [32, 64, 128, 256, 512, 1024, 2048]

optimizer_g:
  _target_: torch.optim.AdamW
  lr: 0.001
  weight_decay: 0.01

optimizer_d:
  _target_: torch.optim.AdamW
  lr: 0.0001
  weight_decay: 0.01
  betas: [0.5, 0.99]

scheduler_g:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 2
  gamma: 0.98

scheduler_d:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 2
  gamma: 0.98

loss_g:
  _target_: look2hear.losses.gan_losses.MultiFrequencyGenLoss
  eps: 1e-8

loss_d:
  _target_: look2hear.losses.gan_losses.MultiFrequencyDisLoss
  eps: 1e-8

metrics:
  _target_: look2hear.losses.MultiSrcNegSDR
  sdr_type: sisdr

system:
  _target_: look2hear.system.audio_litmodule.AudioLightningModule

early_stopping:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: val_loss
  patience: 20
  mode: min
  verbose: true

checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  dirpath: ${exp.dir}/${exp.name}/checkpoints
  monitor: val_loss
  mode: min
  verbose: true
  save_top_k: 5
  save_last: true
  filename: '{epoch}-{val_loss:.4f}'

logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  name: ${exp.name}
  save_dir: ${exp.dir}/${exp.name}/logs
  offline: false
  project: Audio-Restoration

trainer:
  _target_: pytorch_lightning.Trainer
  devices: [0,1,2,3,4,5,6,7]
  max_epochs: 500
  sync_batchnorm: true
  default_root_dir: ${exp.dir}/${exp.name}/
  accelerator: cuda
  limit_train_batches: 1.0
  fast_dev_run: false
inference.py
ADDED
@@ -0,0 +1,150 @@
import os
import torch
import librosa
import look2hear.models
import soundfile as sf
from tqdm.auto import tqdm
import argparse
import numpy as np
import yaml
from ml_collections import ConfigDict
#from omegaconf import OmegaConf

import warnings
warnings.filterwarnings("ignore")

def get_config(config_path):
    with open(config_path) as f:
        #config = OmegaConf.load(config_path)
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config

def load_audio(file_path):
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate

def save_audio(file_path, audio, samplerate=44100):
    #audio = dBgain(audio, +6)
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")

def process_chunk(chunk):
    chunk = chunk.unsqueeze(0).cpu()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()

def _getWindowingArray(window_size, fade_size):
    # IMPORTANT NOTE:
    # no fades here in the end, only removing the failed ending of the chunk
    fadein = torch.linspace(1, 1, fade_size)
    fadeout = torch.linspace(0, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window

def dBgain(audio, volume_gain_dB):
    gain = 10 ** (volume_gain_dB / 20)
    gained_audio = audio * gain
    return gained_audio


def main(input_wav, output_wav, ckpt_path):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    global model
    feature_dim = config['model']['feature_dim']
    sr = config['model']['sr']
    win = config['model']['win']
    layer = config['model']['layer']
    model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cpu()

    test_data, samplerate = load_audio(input_wav)

    C = chunk_size * samplerate  # chunk_size seconds to samples
    N = overlap
    step = C // N
    fade_size = 3 * 44100  # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        window = windowingArray
        if i == 0:  # First audio chunk, no fadein
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fadeout
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success! Output file saved as {output_wav}')

    # Memory clearing
    model.cpu()
    del model
    torch.cuda.empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file", default="model/pytorch_model.bin")
    parser.add_argument("--config", type=str, help="Path to model config file", default="config/apollo.yaml")
    parser.add_argument("--chunk_size", type=int, help="chunk size value in seconds", default=10)
    parser.add_argument("--overlap", type=int, help="Overlap", default=2)
    args = parser.parse_args()

    ckpt_path = args.ckpt
    chunk_size = args.chunk_size
    overlap = args.overlap
    config = get_config(args.config)
    print(config['model'])
    print(f'ckpt_path = {ckpt_path}')
    #print(f'config = {config}')
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')


    main(args.in_wav, args.out_wav, ckpt_path)
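A usage sketch (not part of the commit): inference.py is driven entirely by argparse, so a typical offline run against the files added in this commit would look like

python inference.py --in_wav input.wav --out_wav restored.wav --ckpt weights/apollo.bin --config configs/apollo.yaml --chunk_size 10 --overlap 2

The input and output file names above are placeholders. Note that the parser defaults (model/pytorch_model.bin and config/apollo.yaml) do not match the paths this commit actually adds (weights/apollo.bin and configs/apollo.yaml), so --ckpt and --config should be passed explicitly.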
look2hear/__init__.py
ADDED
File without changes
look2hear/datas/__init__.py
ADDED
@@ -0,0 +1,11 @@
###
# Author: Kai Li
# Date: 2021-06-03 18:29:46
# LastEditors: Please set LastEditors
# LastEditTime: 2022-07-29 06:23:03
###
from .musdb_moisesdb_datamodule import MusdbMoisesdbDataModule

__all__ = [
    "MusdbMoisesdbDataModule"
]
look2hear/datas/musdb_moisesdb_datamodule.py
ADDED
@@ -0,0 +1,215 @@
import os
import h5py
import numpy as np
from typing import Any, Tuple
import torch
import random
from pytorch_lightning import LightningDataModule
import torchaudio
from torchaudio.functional import apply_codec
from torch.utils.data import DataLoader, Dataset
from typing import Any, Dict, Optional, Tuple

def compute_mch_rms_dB(mch_wav, fs=16000, energy_thresh=-50):
    """Return the wav RMS calculated only in the active portions"""
    mean_square = max(1e-20, torch.mean(mch_wav ** 2))
    return 10 * np.log10(mean_square)

def match2(x, d):
    assert x.dim()==2, x.shape
    assert d.dim()==2, d.shape
    minlen = min(x.shape[-1], d.shape[-1])
    x, d = x[:,0:minlen], d[:,0:minlen]
    Fx = torch.fft.rfft(x, dim=-1)
    Fd = torch.fft.rfft(d, dim=-1)
    Phi = Fd*Fx.conj()
    Phi = Phi / (Phi.abs() + 1e-3)
    Phi[:,0] = 0
    tmp = torch.fft.irfft(Phi, dim=-1)
    tau = torch.argmax(tmp.abs(),dim=-1).tolist()
    return tau

def codec_simu(wav, sr=16000, options={'bitrate':'random','compression':'random', 'complexity':'random', 'vbr':'random'}):

    if options['bitrate'] == 'random':
        options['bitrate'] = random.choice([24000, 32000, 48000, 64000, 96000, 128000])
    compression = int(options['bitrate']//1000)
    param = {'format': "mp3", "compression": compression}
    wav_encdec = apply_codec(wav, sr, **param)
    if wav_encdec.shape[-1] >= wav.shape[-1]:
        wav_encdec = wav_encdec[...,:wav.shape[-1]]
    else:
        wav_encdec = torch.cat([wav_encdec, wav[..., wav_encdec.shape[-1]:]], -1)
    tau = match2(wav, wav_encdec)
    wav_encdec = torch.roll(wav_encdec, -tau[0], -1)

    return wav_encdec

def get_wav_files(root_dir):
    wav_files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.endswith('.wav'):
                if "musdb18hq" in dirpath and "mixture" not in filename:
                    wav_files.append(os.path.join(dirpath, filename))
                elif "moisesdb" in dirpath:
                    wav_files.append(os.path.join(dirpath, filename))
    return wav_files

class MusdbMoisesdbDataset(Dataset):
    def __init__(
        self,
        data_dir: str,
        codec_type: str,
        codec_options: dict,
        sr: int = 16000,
        segments: int = 10,
        num_stems: int = 4,
        snr_range: Tuple[int, int] = (-10, 10),
        num_samples: int = 1000,
    ) -> None:

        self.data_dir = data_dir
        self.codec_type = codec_type
        self.codec_options = codec_options
        self.segments = int(segments * sr)
        self.sr = sr
        self.num_stems = num_stems
        self.snr_range = snr_range
        self.num_samples = num_samples

        self.instruments = [
            "bass",
            "bowed_strings",
            "drums",
            "guitar",
            "other",
            "other_keys",
            "other_plucked",
            "percussion",
            "piano",
            "vocals",
            "wind"
        ]

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if random.random() > 0.5:
            select_stems = random.randint(1, self.num_stems)
            select_stems = random.choices(self.instruments, k=select_stems)
            ori_wav = []
            for stem in select_stems:
                h5path = random.choice(os.listdir(os.path.join(self.data_dir, stem)))
                datas = h5py.File(os.path.join(self.data_dir, stem, h5path), 'r')['data']
                random_index = random.randint(0, datas.shape[0]-1)
                music_wav = torch.FloatTensor(datas[random_index])
                start = random.randint(0, music_wav.shape[-1] - self.segments)
                music_wav = music_wav[:, start:start+self.segments]

                rescale_snr = random.randint(self.snr_range[0], self.snr_range[1])
                music_wav = music_wav * np.sqrt(10**(rescale_snr/10))
                ori_wav.append(music_wav)
            ori_wav = torch.stack(ori_wav).sum(0)
        else:
            h5path = random.choice(os.listdir(os.path.join(self.data_dir, "mixture")))
            datas = h5py.File(os.path.join(self.data_dir, "mixture", h5path), 'r')['data']
            random_index = random.randint(0, datas.shape[0]-1)
            music_wav = torch.FloatTensor(datas[random_index])
            start = random.randint(0, music_wav.shape[-1] - self.segments)
            ori_wav = music_wav[:, start:start+self.segments]

        codec_wav = codec_simu(ori_wav, sr=self.sr, options=self.codec_options)

        max_scale = max(ori_wav.abs().max(), codec_wav.abs().max())

        if max_scale > 0:
            ori_wav = ori_wav / max_scale
            codec_wav = codec_wav / max_scale

        return ori_wav, codec_wav


class MusdbMoisesdbEval(Dataset):
    def __init__(
        self,
        data_dir: str
    ) -> None:
        self.data_path = os.listdir(data_dir)
        self.data_path = [os.path.join(data_dir, i) for i in self.data_path]

    def __len__(self) -> int:
        return len(self.data_path)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        ori_wav = torchaudio.load(self.data_path[idx]+"/ori_wav.wav")[0]
        codec_wav = torchaudio.load(self.data_path[idx]+"/codec_wav.wav")[0]

        return ori_wav, codec_wav, self.data_path[idx]

class MusdbMoisesdbDataModule(LightningDataModule):
    def __init__(
        self,
        train_dir: str,
        eval_dir: str,
        codec_type: str,
        codec_options: dict,
        sr: int = 16000,
        segments: int = 10,
        num_stems: int = 4,
        snr_range: Tuple[int, int] = (-10, 10),
        num_samples: int = 1000,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val:
            self.data_train = MusdbMoisesdbDataset(
                data_dir=self.hparams.train_dir,
                codec_type=self.hparams.codec_type,
                codec_options=self.hparams.codec_options,
                sr=self.hparams.sr,
                segments=self.hparams.segments,
                num_stems=self.hparams.num_stems,
                snr_range=self.hparams.snr_range,
                num_samples=self.hparams.num_samples,
            )

            self.data_val = MusdbMoisesdbEval(
                data_dir=self.hparams.eval_dir
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=False,
            pin_memory=True,
        )
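A minimal sketch (illustrative only; it assumes the ./hdf5_datas and ./eval directories referenced in configs/apollo.yaml exist locally with the expected HDF5 stem files) of how the `datas:` block instantiates this datamodule:

from look2hear.datas import MusdbMoisesdbDataModule

# mirrors the values in configs/apollo.yaml
dm = MusdbMoisesdbDataModule(
    train_dir="./hdf5_datas",
    eval_dir="./eval",
    codec_type="mp3",
    codec_options={"bitrate": "random", "compression": "random",
                   "complexity": "random", "vbr": "random"},
    sr=44100,
    segments=3,
    num_stems=8,
    snr_range=(-10, 10),
    num_samples=40000,
    batch_size=1,
    num_workers=8,
)
dm.setup("fit")
# each training item is a (clean, codec-degraded) pair of 3-second clips
ori_wav, codec_wav = next(iter(dm.train_dataloader()))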
look2hear/discriminators/__init__.py
ADDED
@@ -0,0 +1,47 @@
###
# Author: Kai Li
# Date: 2022-02-12 15:16:35
# Email: [email protected]
# LastEditTime: 2022-10-04 16:24:53
###
from .frequencydis import MultiFrequencyDiscriminator, FrequencyDiscriminator

__all__ = [
    "MultiFrequencyDiscriminator",
    "FrequencyDiscriminator"
]


def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    Args:
        custom_model: Custom model to register.

    """
    if (
        custom_model.__name__ in globals().keys()
        or custom_model.__name__.lower() in globals().keys()
    ):
        raise ValueError(
            f"Model {custom_model.__name__} already exists. Choose another name."
        )
    globals().update({custom_model.__name__: custom_model})


def get(identifier):
    """Returns an model class from a string (case-insensitive).

    Args:
        identifier (str): the model name.

    Returns:
        :class:`torch.nn.Module`
    """
    if isinstance(identifier, str):
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        if cls is None:
            raise ValueError(f"Could not interpret model name : {str(identifier)}")
        return cls
    raise ValueError(f"Could not interpret model name : {str(identifier)}")
look2hear/discriminators/frequencydis.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import numpy as np

class MultiFrequencyDiscriminator(nn.Module):
    def __init__(self, nch, window):
        super(MultiFrequencyDiscriminator, self).__init__()

        self.nch = nch
        self.window = window
        self.hidden_channels = 8
        self.eps = torch.finfo(torch.float32).eps
        self.discriminators = nn.ModuleList([FrequencyDiscriminator(2*nch, self.hidden_channels) for _ in range(len(self.window))])

    def forward(self, est, sample_rate=44100):

        B, nch, _ = est.shape
        assert nch == self.nch

        # normalize power
        est = est / (est.pow(2).sum((1,2)) + self.eps).sqrt().reshape(B, 1, 1)
        est = est.view(-1, est.shape[-1])

        est_outputs = []
        est_feature_maps = []

        for i in range(len(self.discriminators)):
            est_spec = torch.stft(est.float(), self.window[i], self.window[i]//2,
                                  window=torch.hann_window(self.window[i]).to(est.device).float(),
                                  return_complex=True)
            est_RI = torch.stack([est_spec.real, est_spec.imag], dim=1)
            est_RI = est_RI.view(B, nch*2, est_RI.shape[-2], est_RI.shape[-1]).type(est.type())

            valid_enc = int(est_RI.shape[2] * sample_rate / 44100)
            est_out, est_feat_map = self.discriminators[i](est_RI[:,:,:valid_enc].contiguous())
            est_outputs.append(est_out)
            est_feature_maps.append(est_feat_map)

        return est_outputs, est_feature_maps


class FrequencyDiscriminator(nn.Module):
    def __init__(self, in_channels, hidden_channels=512):
        super(FrequencyDiscriminator, self).__init__()

        self.eps = torch.finfo(torch.float32).eps
        self.discriminator = nn.ModuleList()
        self.discriminator += [
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(in_channels, hidden_channels, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))),
                nn.LeakyReLU(0.2, True)
            ),
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))),
                nn.LeakyReLU(0.2, True)
            ),
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))),
                nn.LeakyReLU(0.2, True)
            ),
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))),
                nn.LeakyReLU(0.2, True)
            ),
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(hidden_channels*8, hidden_channels*16, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))),
                nn.LeakyReLU(0.2, True)
            ),
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(hidden_channels*16, hidden_channels*32, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))),
                nn.LeakyReLU(0.2, True)
            ),
            nn.Conv2d(hidden_channels*32, 1, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
        ]

    def forward(self, x):
        hiddens = []
        for layer in self.discriminator:
            x = layer(x)
            hiddens.append(x)
        return x, hiddens[:-1]
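A minimal sketch (illustrative only) of how the discriminator configured in configs/apollo.yaml (nch: 2, window: [32 ... 2048]) consumes audio: it returns one patch score map and one list of intermediate feature maps per STFT resolution.

import torch
from look2hear.discriminators import MultiFrequencyDiscriminator

disc = MultiFrequencyDiscriminator(nch=2, window=[32, 64, 128, 256, 512, 1024, 2048])
fake_audio = torch.randn(1, 2, 44100)     # (batch, channels, samples): 1 s of stereo
scores, feature_maps = disc(fake_audio)
print(len(scores), scores[0].shape)       # 7 resolutions, each score is a 2-D map over (freq, frames)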
look2hear/losses/__init__.py
ADDED
@@ -0,0 +1,14 @@
###
# Author: Kai Li
# Date: 2021-06-09 16:34:19
# LastEditors: Kai Li
# LastEditTime: 2021-07-12 20:55:35
###
from .gan_losses import MultiFrequencyDisLoss, MultiFrequencyGenLoss
from .matrix import MultiSrcNegSDR

__all__ = [
    "MultiFrequencyDisLoss",
    "MultiFrequencyGenLoss",
    "MultiSrcNegSDR"
]
look2hear/losses/gan_losses.py
ADDED
@@ -0,0 +1,58 @@
###
# Author: Kai Li
# Date: 2021-06-09 16:43:09
# LastEditors: Please set LastEditors
# LastEditTime: 2024-01-24 00:00:52
###

import torch
from torch.nn.modules.loss import _Loss

def freq_MAE(output, target):
    loss = 0.
    eps = torch.finfo(torch.float32).eps
    all_win = [32, 64, 128, 256, 512, 1024, 2048]
    for win in all_win:
        est_spec = torch.stft(output.view(-1, output.shape[-1]), n_fft=win, hop_length=win//2,
                              window=torch.hann_window(win).to(output.device).float(),
                              return_complex=True)
        target_spec = torch.stft(target.view(-1, target.shape[-1]), n_fft=win, hop_length=win//2,
                                 window=torch.hann_window(win).to(target.device).float(),
                                 return_complex=True)

        loss = loss + (est_spec.abs() - target_spec.abs()).abs().mean() / (target_spec.abs().mean() + eps)

    return loss / len(all_win)

class MultiFrequencyDisLoss(_Loss):
    def __init__(self, eps=1e-8):
        super(MultiFrequencyDisLoss, self).__init__()

    def forward(self, target_outputs, est_outputs):
        D_real = 0
        D_fake = 0
        for i in range(len(target_outputs)):
            D_real = D_real + (target_outputs[i] - 1).pow(2).mean() / len(target_outputs)
            D_fake = D_fake + (est_outputs[i]).pow(2).mean() / len(est_outputs)
        return D_real + D_fake

class MultiFrequencyGenLoss(_Loss):
    def __init__(self, eps=1e-8):
        super(MultiFrequencyGenLoss, self).__init__()
        self.eps = eps

    def forward(self, est_outputs, est_feature_maps, targets_feature_maps, output, ori_data):
        G_fake = 0
        feature_matching = 0
        eps = self.eps

        for i in range(len(est_outputs)):
            G_fake = G_fake + (est_outputs[i] - 1).pow(2).mean() / len(est_outputs)
            for j in range(len(est_feature_maps[i])):
                feature_matching = feature_matching + (est_feature_maps[i][j] - targets_feature_maps[i][j].detach()).abs().mean() / (targets_feature_maps[i][j].detach().abs().mean() + eps)

        feature_matching = feature_matching / (len(est_outputs) * len(est_feature_maps[0]))
        freq_loss = freq_MAE(output, ori_data.unsqueeze(1))
        total_loss = freq_loss + G_fake + feature_matching

        return total_loss
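A minimal sketch (illustrative only) of the multi-resolution magnitude term used inside MultiFrequencyGenLoss: freq_MAE averages a relative STFT-magnitude error over the seven window sizes, so identical signals give a loss of zero.

import torch
from look2hear.losses.gan_losses import freq_MAE

target = torch.randn(1, 2, 44100)                        # a dummy stereo clip
print(freq_MAE(target, target).item())                   # 0.0 for identical signals
print(freq_MAE(torch.randn_like(target), target).item()) # > 0 for unrelated signals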
look2hear/losses/matrix.py
ADDED
@@ -0,0 +1,46 @@
import torch
from torch.nn.modules.loss import _Loss

class MultiSrcNegSDR(_Loss):
    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()

        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = 1e-8

    def forward(self, ests, targets):
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_est = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_est
        # Step 2. Pair-wise SI-SDR.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src]
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
            # [batch, n_src]
            s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            # [batch, n_src, time]
            scaled_targets = pair_wise_dot * targets / s_target_energy
        else:
            # [batch, n_src, time]
            scaled_targets = targets
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_targets
        # [batch, n_src]
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -torch.mean(pair_wise_sdr, dim=-1).mean(0)
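A minimal sketch (illustrative only) of this SI-SDR wrapper, which the config wires in under `metrics:`; inputs must be shaped [batch, n_src, time] and the return value is the negative SI-SDR in dB, so a near-perfect estimate yields a strongly negative number.

import torch
from look2hear.losses import MultiSrcNegSDR

loss_fn = MultiSrcNegSDR(sdr_type="sisdr")
targets = torch.randn(4, 2, 44100)                 # [batch, n_src, time]
ests = targets + 0.01 * torch.randn_like(targets)  # lightly perturbed estimate
print(loss_fn(ests, targets))                      # large negative value = high SI-SDR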
look2hear/metrics/__init__.py
ADDED
@@ -0,0 +1,9 @@
###
# Author: Kai Li
# Date: 2021-06-22 12:22:41
# LastEditors: Kai Li
# LastEditTime: 2021-07-14 19:15:22
###
from .wrapper import MetricsTracker

__all__ = ["MetricsTracker"]
look2hear/metrics/wrapper.py
ADDED
@@ -0,0 +1,86 @@
###
# Author: Kai Li
# Date: 2021-06-22 12:41:36
# LastEditors: Please set LastEditors
# LastEditTime: 2022-06-05 14:48:00
###
import csv
from sympy import im
import torch
import numpy as np
import logging
import os
import librosa
from torch_mir_eval.separation import bss_eval_sources
import fast_bss_eval
from visqol import visqol_lib_py
from visqol.pb2 import visqol_config_pb2
from visqol.pb2 import similarity_result_pb2

logger = logging.getLogger(__name__)

def is_silent(wav, threshold=1e-4):
    return torch.sum(wav ** 2) / wav.numel() < threshold

class MetricsTracker:
    def __init__(self, save_file: str = ""):
        self.all_sdrs = []
        self.all_sisnrs = []
        self.all_visqols = []

        csv_columns = ["snt_id", "sdr", "si-snr", "visqol"]
        self.visqol_config = visqol_config_pb2.VisqolConfig()
        self.visqol_config.audio.sample_rate = 48000
        self.visqol_config.options.use_speech_scoring = False
        svr_model_path = "libsvm_nu_svr_model.txt"
        self.visqol_config.options.svr_model_path = os.path.join(os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path)
        self.visqol_api = visqol_lib_py.VisqolApi()
        self.visqol_api.Create(self.visqol_config)

        self.results_csv = open(save_file, "w")
        self.writer = csv.DictWriter(self.results_csv, fieldnames=csv_columns)
        self.writer.writeheader()

    def __call__(self, clean, estimate, key):
        sisnr = fast_bss_eval.si_sdr(clean.unsqueeze(0), estimate.unsqueeze(0), zero_mean=True).mean()
        sdr = fast_bss_eval.sdr(clean.unsqueeze(0), estimate.unsqueeze(0), zero_mean=True).mean()

        clean = librosa.resample(clean.squeeze(0).mean(0).cpu().numpy(), orig_sr=44100, target_sr=48000).astype(np.float64)
        estimate = librosa.resample(estimate.squeeze(0).mean(0).cpu().numpy(), orig_sr=44100, target_sr=48000).astype(np.float64)

        visqol = self.visqol_api.Measure(clean, estimate).moslqo
        # import pdb; pdb.set_trace()
        row = {
            "snt_id": key,
            "sdr": sdr.item(),
            "si-snr": sisnr.item(),
            "visqol": visqol
        }

        self.writer.writerow(row)
        # Metric Accumulation
        self.all_sdrs.append(sdr.item())
        self.all_sisnrs.append(sisnr.item())
        self.all_visqols.append(visqol)

    def update(self, ):
        return {"sdr": np.array(self.all_sdrs).mean(),
                "si-snr": np.array(self.all_sisnrs).mean(),
                "visqol": np.array(self.all_visqols).mean()}

    def final(self,):
        row = {
            "snt_id": "avg",
            "sdr": np.array(self.all_sdrs).mean(),
            "si-snr": np.array(self.all_sisnrs).mean(),
            "visqol": np.array(self.all_visqols).mean()
        }
        self.writer.writerow(row)
        row = {
            "snt_id": "std",
            "sdr": np.array(self.all_sdrs).std(),
            "si-snr": np.array(self.all_sisnrs).std(),
            "visqol": np.array(self.all_visqols).std()
        }
        self.writer.writerow(row)
        self.results_csv.close()
look2hear/models/__init__.py
ADDED
@@ -0,0 +1,49 @@
###
# Author: Kai Li
# Date: 2022-02-12 15:16:35
# Email: [email protected]
# LastEditTime: 2022-10-04 16:24:53
###
from .base_model import BaseModel
from .apollo import Apollo

__all__ = [
    "BaseModel",
    "GullFullband",
    "Apollo"
]


def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    Args:
        custom_model: Custom model to register.

    """
    if (
        custom_model.__name__ in globals().keys()
        or custom_model.__name__.lower() in globals().keys()
    ):
        raise ValueError(
            f"Model {custom_model.__name__} already exists. Choose another name."
        )
    globals().update({custom_model.__name__: custom_model})


def get(identifier):
    """Returns an model class from a string (case-insensitive).

    Args:
        identifier (str): the model name.

    Returns:
        :class:`torch.nn.Module`
    """
    if isinstance(identifier, str):
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        if cls is None:
            raise ValueError(f"Could not interpret model name : {str(identifier)}")
        return cls
    raise ValueError(f"Could not interpret model name : {str(identifier)}")
look2hear/models/apollo.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from .base_model import BaseModel
|
6 |
+
|
7 |
+
class RMSNorm(nn.Module):
|
8 |
+
def __init__(self, dimension, groups=1):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.weight = nn.Parameter(torch.ones(dimension))
|
12 |
+
self.groups = groups
|
13 |
+
self.eps = 1e-5
|
14 |
+
|
15 |
+
def forward(self, input):
|
16 |
+
# input size: (B, N, T)
|
17 |
+
B, N, T = input.shape
|
18 |
+
assert N % self.groups == 0
|
19 |
+
|
20 |
+
input_float = input.reshape(B, self.groups, -1, T).float()
|
21 |
+
input_norm = input_float * torch.rsqrt(input_float.pow(2).mean(-2, keepdim=True) + self.eps)
|
22 |
+
|
23 |
+
return input_norm.type_as(input).reshape(B, N, T) * self.weight.reshape(1, -1, 1)
|
24 |
+
|
25 |
+
class RMVN(nn.Module):
|
26 |
+
"""
|
27 |
+
Rescaled MVN.
|
28 |
+
"""
|
29 |
+
def __init__(self, dimension, groups=1):
|
30 |
+
super(RMVN, self).__init__()
|
31 |
+
|
32 |
+
self.mean = nn.Parameter(torch.zeros(dimension))
|
33 |
+
self.std = nn.Parameter(torch.ones(dimension))
|
34 |
+
self.groups = groups
|
35 |
+
self.eps = 1e-5
|
36 |
+
|
37 |
+
def forward(self, input):
|
38 |
+
# input size: (B, N, *)
|
39 |
+
B, N = input.shape[:2]
|
40 |
+
assert N % self.groups == 0
|
41 |
+
input_reshape = input.reshape(B, self.groups, N // self.groups, -1)
|
42 |
+
T = input_reshape.shape[-1]
|
43 |
+
|
44 |
+
input_norm = (input_reshape - input_reshape.mean(2).unsqueeze(2)) / (input_reshape.var(2).unsqueeze(2) + self.eps).sqrt()
|
45 |
+
input_norm = input_norm.reshape(B, N, T) * self.std.reshape(1, -1, 1) + self.mean.reshape(1, -1, 1)
|
46 |
+
|
47 |
+
return input_norm.reshape(input.shape)
|
48 |
+
|
49 |
+
class Roformer(nn.Module):
|
50 |
+
"""
|
51 |
+
Transformer with rotary positional embedding.
|
52 |
+
"""
|
53 |
+
def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000,
|
54 |
+
input_drop=0., attention_drop=0., causal=True):
|
55 |
+
super().__init__()
|
56 |
+
|
57 |
+
self.input_size = input_size
|
58 |
+
self.hidden_size = hidden_size // num_head
|
59 |
+
self.num_head = num_head
|
60 |
+
self.theta = theta # base frequency for RoPE
|
61 |
+
self.window = window
|
62 |
+
# pre-calculate rotary embeddings
|
63 |
+
cos_freq, sin_freq = self._calc_rotary_emb()
|
64 |
+
self.register_buffer("cos_freq", cos_freq) # win, N
|
65 |
+
self.register_buffer("sin_freq", sin_freq) # win, N
|
66 |
+
|
67 |
+
self.attention_drop = attention_drop
|
68 |
+
self.causal = causal
|
69 |
+
self.eps = 1e-5
|
70 |
+
|
71 |
+
self.input_norm = RMSNorm(self.input_size)
|
72 |
+
self.input_drop = nn.Dropout(p=input_drop)
|
73 |
+
self.weight = nn.Conv1d(self.input_size, self.hidden_size*self.num_head*3, 1, bias=False)
|
74 |
+
self.output = nn.Conv1d(self.hidden_size*self.num_head, self.input_size, 1, bias=False)
|
75 |
+
|
76 |
+
self.MLP = nn.Sequential(RMSNorm(self.input_size),
|
77 |
+
nn.Conv1d(self.input_size, self.input_size*8, 1, bias=False),
|
78 |
+
nn.SiLU()
|
79 |
+
)
|
80 |
+
self.MLP_output = nn.Conv1d(self.input_size*4, self.input_size, 1, bias=False)
|
81 |
+
|
82 |
+
def _calc_rotary_emb(self):
|
83 |
+
freq = 1. / (self.theta ** (torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size)) # theta_i
|
84 |
+
freq = freq.reshape(1, -1) # 1, N//2
|
85 |
+
pos = torch.arange(0, self.window).reshape(-1, 1) # win, 1
|
86 |
+
cos_freq = torch.cos(pos*freq) # win, N//2
|
87 |
+
sin_freq = torch.sin(pos*freq) # win, N//2
|
88 |
+
cos_freq = torch.stack([cos_freq]*2, -1).reshape(self.window, self.hidden_size) # win, N
|
89 |
+
sin_freq = torch.stack([sin_freq]*2, -1).reshape(self.window, self.hidden_size) # win, N
|
90 |
+
|
91 |
+
return cos_freq, sin_freq
|
92 |
+
|
93 |
+
def _add_rotary_emb(self, feature, pos):
|
94 |
+
# feature shape: ..., N
|
95 |
+
N = feature.shape[-1]
|
96 |
+
|
97 |
+
feature_reshape = feature.reshape(-1, N)
|
98 |
+
pos = min(pos, self.window-1)
|
99 |
+
cos_freq = self.cos_freq[pos]
|
100 |
+
sin_freq = self.sin_freq[pos]
|
101 |
+
reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
|
102 |
+
feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, N)
|
103 |
+
feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
|
104 |
+
|
105 |
+
return feature_rope.reshape(feature.shape)
|
106 |
+
|
107 |
+
def _add_rotary_sequence(self, feature):
|
108 |
+
# feature shape: ..., T, N
|
109 |
+
T, N = feature.shape[-2:]
|
110 |
+
feature_reshape = feature.reshape(-1, T, N)
|
111 |
+
|
112 |
+
cos_freq = self.cos_freq[:T]
|
113 |
+
sin_freq = self.sin_freq[:T]
|
114 |
+
reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
|
115 |
+
feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, T, N)
|
116 |
+
feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
|
117 |
+
|
118 |
+
return feature_rope.reshape(feature.shape)
|
119 |
+
|
120 |
+
def forward(self, input):
|
121 |
+
# input shape: B, N, T
|
122 |
+
|
123 |
+
B, _, T = input.shape
|
124 |
+
|
125 |
+
weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size*3, T).mT
|
126 |
+
Q, K, V = torch.split(weight, self.hidden_size, dim=-1) # B, num_head, T, N
|
127 |
+
|
128 |
+
# rotary positional embedding
|
129 |
+
Q_rot = self._add_rotary_sequence(Q)
|
130 |
+
K_rot = self._add_rotary_sequence(K)
|
131 |
+
|
132 |
+
attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(), dropout_p=self.attention_drop, is_causal=self.causal) # B, num_head, T, N
|
133 |
+
attention_output = attention_output.mT.reshape(B, -1, T)
|
134 |
+
output = self.output(attention_output) + input
|
135 |
+
|
136 |
+
gate, z = self.MLP(output).chunk(2, dim=1)
|
137 |
+
output = output + self.MLP_output(F.silu(gate) * z)
|
138 |
+
|
139 |
+
return output, (K_rot, V)
|
140 |
+
|
141 |
+
class ConvActNorm1d(nn.Module):
|
142 |
+
def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
|
143 |
+
super(ConvActNorm1d, self).__init__()
|
144 |
+
|
145 |
+
self.in_channel = in_channel
|
146 |
+
self.kernel = kernel
|
147 |
+
self.causal = causal
|
148 |
+
if not causal:
|
149 |
+
self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2, groups=in_channel),
|
150 |
+
RMSNorm(in_channel),
|
151 |
+
nn.Conv1d(in_channel, hidden_channel, 1),
|
152 |
+
nn.SiLU(),
|
153 |
+
nn.Conv1d(hidden_channel, in_channel, 1)
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1, groups=in_channel),
|
157 |
+
RMSNorm(in_channel),
|
158 |
+
nn.Conv1d(in_channel, hidden_channel, 1),
|
159 |
+
nn.SiLU(),
|
160 |
+
nn.Conv1d(hidden_channel, in_channel, 1)
|
161 |
+
)
|
162 |
+
|
163 |
+
def forward(self, input):
|
164 |
+
|
165 |
+
output = self.conv(input)
|
166 |
+
if self.causal:
|
167 |
+
output = output[...,:-self.kernel+1]
|
168 |
+
return input + output
|
169 |
+
|
170 |
+
class ICB(nn.Module):
|
171 |
+
def __init__(self, in_channel, kernel=7, causal=False):
|
172 |
+
super(ICB, self).__init__()
|
173 |
+
|
174 |
+
self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal),
|
175 |
+
ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal),
|
176 |
+
ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal)
|
177 |
+
)
|
178 |
+
|
179 |
+
def forward(self, input):
|
180 |
+
|
181 |
+
return self.blocks(input)
|
182 |
+
|
183 |
+
class BSNet(nn.Module):
|
184 |
+
def __init__(self, feature_dim, kernel=7):
|
185 |
+
super(BSNet, self).__init__()
|
186 |
+
|
187 |
+
self.feature_dim = feature_dim
|
188 |
+
|
189 |
+
self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False)
|
190 |
+
self.seq_net = ICB(self.feature_dim, kernel=kernel)
|
191 |
+
|
192 |
+
def forward(self, input):
|
193 |
+
# input shape: B, nband, N, T
|
194 |
+
|
195 |
+
B, nband, N, T = input.shape
|
196 |
+
|
197 |
+
# band comm
|
198 |
+
band_input = input.permute(0,3,2,1).reshape(B*T, -1, nband)
|
199 |
+
band_output, _ = self.band_net(band_input)
|
200 |
+
band_output = band_output.reshape(B, T, -1, nband).permute(0,3,2,1)
|
201 |
+
|
202 |
+
# sequence modeling
|
203 |
+
output = self.seq_net(band_output.reshape(B*nband, -1, T)).reshape(B, nband, -1, T) # B, nband, N, T
|
204 |
+
|
205 |
+
return output
|
206 |
+
|
207 |
+
class Apollo(BaseModel):
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
sr: int,
|
211 |
+
win: int,
|
212 |
+
feature_dim: int,
|
213 |
+
layer: int
|
214 |
+
):
|
215 |
+
super().__init__(sample_rate=sr)
|
216 |
+
|
217 |
+
self.sr = sr
|
218 |
+
self.win = int(sr * win // 1000)
|
219 |
+
self.stride = self.win // 2
|
220 |
+
self.enc_dim = self.win // 2 + 1
|
221 |
+
self.feature_dim = feature_dim
|
222 |
+
self.eps = torch.finfo(torch.float32).eps
|
223 |
+
|
224 |
+
# 80 bands
|
225 |
+
bandwidth = int(self.win / 160)
|
226 |
+
self.band_width = [bandwidth]*79
|
227 |
+
self.band_width.append(self.enc_dim - np.sum(self.band_width))
|
228 |
+
self.nband = len(self.band_width)
|
229 |
+
print(self.band_width, self.nband)
|
230 |
+
|
231 |
+
self.BN = nn.ModuleList([])
|
232 |
+
for i in range(self.nband):
|
233 |
+
self.BN.append(nn.Sequential(RMSNorm(self.band_width[i]*2+1),
|
234 |
+
nn.Conv1d(self.band_width[i]*2+1, self.feature_dim, 1))
|
235 |
+
)
|
236 |
+
|
237 |
+
self.net = []
|
238 |
+
for _ in range(layer):
|
239 |
+
self.net.append(BSNet(self.feature_dim))
|
240 |
+
self.net = nn.Sequential(*self.net)
|
241 |
+
|
242 |
+
self.output = nn.ModuleList([])
|
243 |
+
for i in range(self.nband):
|
244 |
+
self.output.append(nn.Sequential(RMSNorm(self.feature_dim),
|
245 |
+
nn.Conv1d(self.feature_dim, self.band_width[i]*4, 1),
|
246 |
+
nn.GLU(dim=1)
|
247 |
+
)
|
248 |
+
)
|
249 |
+
|
250 |
+
def spec_band_split(self, input):
|
251 |
+
|
252 |
+
B, nch, nsample = input.shape
|
253 |
+
|
254 |
+
spec = torch.stft(input.view(B*nch, nsample), n_fft=self.win, hop_length=self.stride,
|
255 |
+
window=torch.hann_window(self.win).to(input.device), return_complex=True)
|
256 |
+
|
257 |
+
subband_spec = []
|
258 |
+
subband_spec_norm = []
|
259 |
+
subband_power = []
|
260 |
+
band_idx = 0
|
261 |
+
for i in range(self.nband):
|
262 |
+
this_spec = spec[:,band_idx:band_idx+self.band_width[i]]
|
263 |
+
subband_spec.append(this_spec) # B, BW, T
|
264 |
+
subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T
|
265 |
+
subband_spec_norm.append(torch.complex(this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1])) # B, BW, T
|
266 |
+
band_idx += self.band_width[i]
|
267 |
+
subband_power = torch.cat(subband_power, 1) # B, nband, T
|
268 |
+
|
269 |
+
return subband_spec_norm, subband_power
|
270 |
+
|
271 |
+
def feature_extractor(self, input):
|
272 |
+
|
273 |
+
subband_spec_norm, subband_power = self.spec_band_split(input)
|
274 |
+
|
275 |
+
# normalization and bottleneck
|
276 |
+
subband_feature = []
|
277 |
+
for i in range(self.nband):
|
278 |
+
concat_spec = torch.cat([subband_spec_norm[i].real, subband_spec_norm[i].imag, torch.log(subband_power[:,i].unsqueeze(1))], 1)
|
279 |
+
subband_feature.append(self.BN[i](concat_spec))
|
280 |
+
subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
|
281 |
+
|
282 |
+
return subband_feature
|
283 |
+
|
284 |
+
def forward(self, input):
|
285 |
+
|
286 |
+
B, nch, nsample = input.shape
|
287 |
+
|
288 |
+
subband_feature = self.feature_extractor(input)
|
289 |
+
feature = self.net(subband_feature)
|
290 |
+
|
291 |
+
est_spec = []
|
292 |
+
for i in range(self.nband):
|
293 |
+
this_RI = self.output[i](feature[:,i]).view(B*nch, 2, self.band_width[i], -1)
|
294 |
+
est_spec.append(torch.complex(this_RI[:,0], this_RI[:,1]))
|
295 |
+
est_spec = torch.cat(est_spec, 1)
|
296 |
+
output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
|
297 |
+
window=torch.hann_window(self.win).to(input.device), length=nsample).view(B, nch, -1)
|
298 |
+
|
299 |
+
return output
|
300 |
+
|
301 |
+
def get_model_args(self):
|
302 |
+
model_args = {"n_sample_rate": 2}
|
303 |
+
return model_args
|
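Usage note: the snippet below is a minimal inference sketch for the Apollo class defined above. The hyperparameters (sr, win in milliseconds, feature_dim, layer) are illustrative assumptions; the values actually used by this Space come from configs/apollo.yaml and the pretrained weights in weights/apollo.bin.

import torch
from look2hear.models.apollo import Apollo

# Assumed hyperparameters, for illustration only.
model = Apollo(sr=44100, win=20, feature_dim=256, layer=6).eval()

wav = torch.randn(1, 2, 44100)      # batch, channels, samples (1 s of stereo audio)
with torch.no_grad():
    restored = model(wav)           # STFT -> band split -> BSNet stack -> iSTFT
print(restored.shape)               # torch.Size([1, 2, 44100]), same shape as the input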
look2hear/models/base_model.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
###
|
2 |
+
# Author: Kai Li
|
3 |
+
# Date: 2021-06-17 23:08:32
|
4 |
+
# LastEditors: Please set LastEditors
|
5 |
+
# LastEditTime: 2022-05-26 18:06:22
|
6 |
+
###
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from huggingface_hub import PyTorchModelHubMixin
|
11 |
+
|
12 |
+
|
13 |
+
def _unsqueeze_to_3d(x):
|
14 |
+
"""Normalize shape of `x` to [batch, n_chan, time]."""
|
15 |
+
if x.ndim == 1:
|
16 |
+
return x.reshape(1, 1, -1)
|
17 |
+
elif x.ndim == 2:
|
18 |
+
return x.unsqueeze(1)
|
19 |
+
else:
|
20 |
+
return x
|
21 |
+
|
22 |
+
|
23 |
+
def pad_to_appropriate_length(x, lcm):
|
24 |
+
values_to_pad = int(x.shape[-1]) % lcm
|
25 |
+
if values_to_pad:
|
26 |
+
appropriate_shape = x.shape
|
27 |
+
padded_x = torch.zeros(
|
28 |
+
list(appropriate_shape[:-1])
|
29 |
+
+ [appropriate_shape[-1] + lcm - values_to_pad],
|
30 |
+
dtype=torch.float32,
|
31 |
+
).to(x.device)
|
32 |
+
padded_x[..., : x.shape[-1]] = x
|
33 |
+
return padded_x
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
|
38 |
+
def __init__(self, sample_rate, in_chan=1):
|
39 |
+
super().__init__()
|
40 |
+
self._sample_rate = sample_rate
|
41 |
+
self._in_chan = in_chan
|
42 |
+
|
43 |
+
def forward(self, *args, **kwargs):
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
def sample_rate(self,):
|
47 |
+
return self._sample_rate
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def load_state_dict_in_audio(model, pretrained_dict):
|
51 |
+
model_dict = model.state_dict()
|
52 |
+
update_dict = {}
|
53 |
+
for k, v in pretrained_dict.items():
|
54 |
+
if "audio_model" in k:
|
55 |
+
update_dict[k[12:]] = v
|
56 |
+
model_dict.update(update_dict)
|
57 |
+
model.load_state_dict(model_dict)
|
58 |
+
return model
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
|
62 |
+
from . import get
|
63 |
+
|
64 |
+
conf = torch.load(
|
65 |
+
pretrained_model_conf_or_path, map_location="cpu"
|
66 |
+
) # Attempt to find the model and instantiate it.
|
67 |
+
|
68 |
+
model_class = get(conf["model_name"])
|
69 |
+
# model_class = get("Conv_TasNet")
|
70 |
+
model = model_class(*args, **kwargs)
|
71 |
+
model.load_state_dict(conf["state_dict"])
|
72 |
+
return model
|
73 |
+
|
74 |
+
def serialize(self):
|
75 |
+
import pytorch_lightning as pl # Not used in torch.hub
|
76 |
+
|
77 |
+
model_conf = dict(
|
78 |
+
model_name=self.__class__.__name__,
|
79 |
+
state_dict=self.get_state_dict(),
|
80 |
+
model_args=self.get_model_args(),
|
81 |
+
)
|
82 |
+
# Additional infos
|
83 |
+
infos = dict()
|
84 |
+
infos["software_versions"] = dict(
|
85 |
+
torch_version=torch.__version__, pytorch_lightning_version=pl.__version__,
|
86 |
+
)
|
87 |
+
model_conf["infos"] = infos
|
88 |
+
return model_conf
|
89 |
+
|
90 |
+
def get_state_dict(self):
|
91 |
+
"""In case the state dict needs to be modified before sharing the model."""
|
92 |
+
return self.state_dict()
|
93 |
+
|
94 |
+
def get_model_args(self):
|
95 |
+
"""Should return args to re-instantiate the class."""
|
96 |
+
raise NotImplementedError
|
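The serialize() / from_pretrain() pair above is meant to round-trip a model through a single checkpoint dict. A hedged sketch follows; the concrete subclass, hyperparameters, and file name are assumptions (the Space itself loads weights/apollo.bin).

import torch
from look2hear.models.apollo import Apollo

kwargs = dict(sr=44100, win=20, feature_dim=256, layer=6)   # assumed hyperparameters
model = Apollo(**kwargs)

# serialize() packs model_name, the state dict and get_model_args() into one dict.
torch.save(model.serialize(), "apollo_demo.bin")

# from_pretrain() looks the class up by model_name and restores the weights;
# constructor arguments still have to be passed explicitly.
restored = Apollo.from_pretrain("apollo_demo.bin", **kwargs)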
look2hear/system/__init__.py
ADDED
@@ -0,0 +1,17 @@
1 |
+
###
|
2 |
+
# Author: Kai Li
|
3 |
+
# Date: 2021-06-20 17:52:35
|
4 |
+
# LastEditors: Please set LastEditors
|
5 |
+
# LastEditTime: 2022-05-26 18:27:43
|
6 |
+
###
|
7 |
+
|
8 |
+
|
9 |
+
from .optimizers import make_optimizer
|
10 |
+
from .audio_litmodule import AudioLightningModule
|
11 |
+
from .schedulers import DPTNetScheduler
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"make_optimizer",
|
15 |
+
"AudioLightningModule",
|
16 |
+
"DPTNetScheduler"
|
17 |
+
]
|
look2hear/system/audio_litmodule.py
ADDED
@@ -0,0 +1,245 @@
1 |
+
###
|
2 |
+
# Author: Kai Li
|
3 |
+
# Date: 2022-05-26 18:09:54
|
4 |
+
# Email: [email protected]
|
5 |
+
# LastEditTime: 2024-01-24 00:00:28
|
6 |
+
###
|
7 |
+
import gc
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
import torch
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
12 |
+
from collections.abc import MutableMapping
|
13 |
+
from omegaconf import ListConfig
|
14 |
+
|
15 |
+
def flatten_dict(d, parent_key="", sep="_"):
|
16 |
+
"""Flattens a dictionary into a single-level dictionary while preserving
|
17 |
+
parent keys. Taken from
|
18 |
+
`SO <https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys>`_
|
19 |
+
|
20 |
+
Args:
|
21 |
+
d (MutableMapping): Dictionary to be flattened.
|
22 |
+
parent_key (str): String to use as a prefix to all subsequent keys.
|
23 |
+
sep (str): String to use as a separator between two key levels.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
dict: Single-level dictionary, flattened.
|
27 |
+
"""
|
28 |
+
items = []
|
29 |
+
for k, v in d.items():
|
30 |
+
new_key = parent_key + sep + k if parent_key else k
|
31 |
+
if isinstance(v, MutableMapping):
|
32 |
+
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
33 |
+
else:
|
34 |
+
items.append((new_key, v))
|
35 |
+
return dict(items)
|
36 |
+
|
37 |
+
|
38 |
+
class AudioLightningModule(pl.LightningModule):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
model=None,
|
42 |
+
discriminator=None,
|
43 |
+
optimizer=None,
|
44 |
+
loss_func=None,
|
45 |
+
metrics=None,
|
46 |
+
scheduler=None,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.audio_model = model
|
50 |
+
self.discriminator = discriminator
|
51 |
+
self.optimizer = list(optimizer)
|
52 |
+
self.loss_func = loss_func
|
53 |
+
self.metrics = metrics
|
54 |
+
self.scheduler = list(scheduler)
|
55 |
+
|
56 |
+
# Save lightning"s AttributeDict under self.hparams
|
57 |
+
self.default_monitor = "val_loss"
|
58 |
+
# self.print(self.audio_model)
|
59 |
+
self.validation_step_outputs = []
|
60 |
+
self.test_step_outputs = []
|
61 |
+
self.automatic_optimization = False
|
62 |
+
|
63 |
+
def forward(self, wav):
|
64 |
+
"""Applies forward pass of the model.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
:class:`torch.Tensor`
|
68 |
+
"""
|
69 |
+
return self.audio_model(wav)
|
70 |
+
|
71 |
+
def training_step(self, batch, batch_nb):
|
72 |
+
ori_data, codec_data = batch
|
73 |
+
optimizer_g, optimizer_d = self.optimizers()
|
74 |
+
# multiple schedulers
|
75 |
+
scheduler_g, scheduler_d = self.lr_schedulers()
|
76 |
+
|
77 |
+
# train discriminator
|
78 |
+
optimizer_g.zero_grad()
|
79 |
+
output = self(codec_data)
|
80 |
+
|
81 |
+
optimizer_d.zero_grad()
|
82 |
+
est_outputs, _ = self.discriminator(output.detach(), sample_rate=44100)
|
83 |
+
target_outputs, _ = self.discriminator(ori_data, sample_rate=44100)
|
84 |
+
|
85 |
+
loss_d = self.loss_func["d"](target_outputs, est_outputs)
|
86 |
+
self.manual_backward(loss_d)
|
87 |
+
self.clip_gradients(optimizer_d, gradient_clip_val=5, gradient_clip_algorithm="norm")
|
88 |
+
optimizer_d.step()
|
89 |
+
# train generator
|
90 |
+
est_outputs, est_feature_maps = self.discriminator(output, sample_rate=44100)
|
91 |
+
_, targets_feature_maps = self.discriminator(ori_data, sample_rate=44100)
|
92 |
+
|
93 |
+
loss_g = self.loss_func["g"](est_outputs, est_feature_maps, targets_feature_maps, output, ori_data)
|
94 |
+
self.manual_backward(loss_g)
|
95 |
+
self.clip_gradients(optimizer_g, gradient_clip_val=5, gradient_clip_algorithm="norm")
|
96 |
+
optimizer_g.step()
|
97 |
+
# print(loss)
|
98 |
+
|
99 |
+
if self.trainer.is_last_batch:
|
100 |
+
scheduler_g.step()
|
101 |
+
scheduler_d.step()
|
102 |
+
|
103 |
+
self.log(
|
104 |
+
"train_loss_d",
|
105 |
+
loss_d,
|
106 |
+
on_epoch=True,
|
107 |
+
prog_bar=True,
|
108 |
+
sync_dist=True,
|
109 |
+
logger=True,
|
110 |
+
)
|
111 |
+
|
112 |
+
self.log(
|
113 |
+
"train_loss_g",
|
114 |
+
loss_g,
|
115 |
+
on_epoch=True,
|
116 |
+
prog_bar=True,
|
117 |
+
sync_dist=True,
|
118 |
+
logger=True,
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
def validation_step(self, batch, batch_nb):
|
123 |
+
# cal val loss
|
124 |
+
ori_data, codec_data = batch
|
125 |
+
# print(mixtures.shape)
|
126 |
+
est_sources = self(codec_data)
|
127 |
+
loss = self.metrics(est_sources, ori_data)
|
128 |
+
|
129 |
+
self.log(
|
130 |
+
"val_loss",
|
131 |
+
loss,
|
132 |
+
on_epoch=True,
|
133 |
+
prog_bar=True,
|
134 |
+
sync_dist=True,
|
135 |
+
logger=True,
|
136 |
+
)
|
137 |
+
|
138 |
+
self.validation_step_outputs.append(loss)
|
139 |
+
|
140 |
+
return {"val_loss": loss}
|
141 |
+
|
142 |
+
def on_validation_epoch_end(self):
|
143 |
+
# val
|
144 |
+
avg_loss = torch.stack(self.validation_step_outputs).mean()
|
145 |
+
val_loss = torch.mean(self.all_gather(avg_loss))
|
146 |
+
self.log(
|
147 |
+
"lr",
|
148 |
+
self.optimizer[0].param_groups[0]["lr"],
|
149 |
+
on_epoch=True,
|
150 |
+
prog_bar=True,
|
151 |
+
sync_dist=True,
|
152 |
+
)
|
153 |
+
self.logger.experiment.log(
|
154 |
+
{"learning_rate": self.optimizer[0].param_groups[0]["lr"], "epoch": self.current_epoch}
|
155 |
+
)
|
156 |
+
self.logger.experiment.log(
|
157 |
+
{"val_pit_sisnr": -val_loss, "epoch": self.current_epoch}
|
158 |
+
)
|
159 |
+
|
160 |
+
self.validation_step_outputs.clear() # free memory
|
161 |
+
torch.cuda.empty_cache()
|
162 |
+
|
163 |
+
def test_step(self, batch, batch_nb):
|
164 |
+
mixtures, targets = batch
|
165 |
+
est_sources = self(mixtures)
|
166 |
+
loss = self.metrics(est_sources, targets)
|
167 |
+
self.log(
|
168 |
+
"test_loss",
|
169 |
+
loss,
|
170 |
+
on_epoch=True,
|
171 |
+
prog_bar=True,
|
172 |
+
sync_dist=True,
|
173 |
+
logger=True,
|
174 |
+
)
|
175 |
+
self.test_step_outputs.append(loss)
|
176 |
+
return {"test_loss": loss}
|
177 |
+
|
178 |
+
def on_test_epoch_end(self):
|
179 |
+
# val
|
180 |
+
avg_loss = torch.stack(self.test_step_outputs).mean()
|
181 |
+
test_loss = torch.mean(self.all_gather(avg_loss))
|
182 |
+
self.log(
|
183 |
+
"lr",
|
184 |
+
self.optimizer[0].param_groups[0]["lr"],
|
185 |
+
on_epoch=True,
|
186 |
+
prog_bar=True,
|
187 |
+
sync_dist=True,
|
188 |
+
)
|
189 |
+
self.logger.experiment.log(
|
190 |
+
{"learning_rate": self.optimizer.param_groups[0]["lr"], "epoch": self.current_epoch}
|
191 |
+
)
|
192 |
+
self.logger.experiment.log(
|
193 |
+
{"test_pit_sisnr": -test_loss, "epoch": self.current_epoch}
|
194 |
+
)
|
195 |
+
|
196 |
+
self.test_step_outputs.clear()
|
197 |
+
|
198 |
+
def configure_optimizers(self):
|
199 |
+
"""Initialize optimizers, batch-wise and epoch-wise schedulers."""
|
200 |
+
if self.scheduler is None:
|
201 |
+
return self.optimizer
|
202 |
+
if not isinstance(self.scheduler, (list, tuple)):
|
203 |
+
self.scheduler = [self.scheduler] # support multiple schedulers
|
204 |
+
|
205 |
+
if not isinstance(self.optimizer, (list, tuple)):
|
206 |
+
self.optimizer = [self.optimizer] # support multiple schedulers
|
207 |
+
|
208 |
+
epoch_schedulers = []
|
209 |
+
for sched in self.scheduler:
|
210 |
+
if not isinstance(sched, dict):
|
211 |
+
if isinstance(sched, ReduceLROnPlateau):
|
212 |
+
sched = {"scheduler": sched, "monitor": self.default_monitor}
|
213 |
+
epoch_schedulers.append(sched)
|
214 |
+
else:
|
215 |
+
sched.setdefault("monitor", self.default_monitor)
|
216 |
+
sched.setdefault("frequency", 1)
|
217 |
+
# Backward compat
|
218 |
+
if sched["interval"] == "batch":
|
219 |
+
sched["interval"] = "step"
|
220 |
+
assert sched["interval"] in [
|
221 |
+
"epoch",
|
222 |
+
"step",
|
223 |
+
], "Scheduler interval should be either step or epoch"
|
224 |
+
epoch_schedulers.append(sched)
|
225 |
+
return self.optimizer, epoch_schedulers
|
226 |
+
|
227 |
+
@staticmethod
|
228 |
+
def config_to_hparams(dic):
|
229 |
+
"""Sanitizes the config dict to be handled correctly by torch
|
230 |
+
SummaryWriter. It flattens the config dict, converts ``None`` to
|
231 |
+
``"None"`` and any list and tuple into torch.Tensors.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
dic (dict): Dictionary to be transformed.
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
dict: Transformed dictionary.
|
238 |
+
"""
|
239 |
+
dic = flatten_dict(dic)
|
240 |
+
for k, v in dic.items():
|
241 |
+
if v is None:
|
242 |
+
dic[k] = str(v)
|
243 |
+
elif isinstance(v, (list, tuple)):
|
244 |
+
dic[k] = torch.tensor(v)
|
245 |
+
return dic
|
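To make the manual-optimization flow above concrete, here is a hedged wiring sketch. The Conv1d discriminator, the None losses/metric, and the optimizer settings are placeholders, not the project's actual configuration (which is instantiated from configs/apollo.yaml); a real discriminator must return (outputs, feature_maps) as used in training_step.

import torch
import pytorch_lightning as pl
from look2hear.models.apollo import Apollo
from look2hear.system import AudioLightningModule, make_optimizer

generator = Apollo(sr=44100, win=20, feature_dim=256, layer=6)   # assumed hyperparameters
discriminator = torch.nn.Conv1d(2, 1, 3)                         # placeholder module only

opt_g = make_optimizer(generator.parameters(), optim_name="adamw", lr=1e-3)
opt_d = make_optimizer(discriminator.parameters(), optim_name="adamw", lr=1e-3)
sched_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=2, gamma=0.98)
sched_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=2, gamma=0.98)

system = AudioLightningModule(
    model=generator,
    discriminator=discriminator,
    optimizer=[opt_g, opt_d],          # generator optimizer first, discriminator second
    loss_func={"g": None, "d": None},  # placeholders for the GAN generator/discriminator losses
    metrics=None,                      # placeholder for the validation metric
    scheduler=[sched_g, sched_d],
)
# trainer = pl.Trainer(max_epochs=..., devices=1)
# trainer.fit(system, datamodule=...)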
look2hear/system/optimizers.py
ADDED
@@ -0,0 +1,113 @@
1 |
+
###
|
2 |
+
# Author: Kai Li
|
3 |
+
# Date: 2021-06-20 00:21:33
|
4 |
+
# LastEditors: Please set LastEditors
|
5 |
+
# LastEditTime: 2022-05-27 11:19:51
|
6 |
+
###
|
7 |
+
|
8 |
+
from torch.optim.optimizer import Optimizer
|
9 |
+
from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
|
10 |
+
from torch_optimizer import (
|
11 |
+
AccSGD,
|
12 |
+
AdaBound,
|
13 |
+
AdaMod,
|
14 |
+
DiffGrad,
|
15 |
+
Lamb,
|
16 |
+
NovoGrad,
|
17 |
+
PID,
|
18 |
+
QHAdam,
|
19 |
+
QHM,
|
20 |
+
RAdam,
|
21 |
+
SGDW,
|
22 |
+
Yogi,
|
23 |
+
Ranger,
|
24 |
+
RangerQH,
|
25 |
+
RangerVA,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
__all__ = [
|
30 |
+
"AccSGD",
|
31 |
+
"AdaBound",
|
32 |
+
"AdaMod",
|
33 |
+
"DiffGrad",
|
34 |
+
"Lamb",
|
35 |
+
"NovoGrad",
|
36 |
+
"PID",
|
37 |
+
"QHAdam",
|
38 |
+
"QHM",
|
39 |
+
"RAdam",
|
40 |
+
"SGDW",
|
41 |
+
"Yogi",
|
42 |
+
"Ranger",
|
43 |
+
"RangerQH",
|
44 |
+
"RangerVA",
|
45 |
+
"Adam",
|
46 |
+
"RMSprop",
|
47 |
+
"SGD",
|
48 |
+
"Adadelta",
|
49 |
+
"Adagrad",
|
50 |
+
"Adamax",
|
51 |
+
"AdamW",
|
52 |
+
"ASGD",
|
53 |
+
"make_optimizer",
|
54 |
+
"get",
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
+
def make_optimizer(params, optim_name="adam", **kwargs):
|
59 |
+
"""
|
60 |
+
|
61 |
+
Args:
|
62 |
+
params (iterable): Output of `nn.Module.parameters()`.
|
63 |
+
optim_name (str or :class:`torch.optim.Optimizer`): Identifier understood
|
64 |
+
by :func:`~.get`.
|
65 |
+
**kwargs (dict): keyword arguments for the optimizer.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
torch.optim.Optimizer
|
69 |
+
Examples
|
70 |
+
>>> from torch import nn
|
71 |
+
>>> model = nn.Sequential(nn.Linear(10, 10))
|
72 |
+
>>> optimizer = make_optimizer(model.parameters(), optim_name='sgd',
|
73 |
+
...                            lr=1e-3)
|
74 |
+
"""
|
75 |
+
return get(optim_name)(params, **kwargs)
|
76 |
+
|
77 |
+
|
78 |
+
def register_optimizer(custom_opt):
|
79 |
+
"""Register a custom opt, gettable with `optimzers.get`.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
custom_opt: Custom optimizer to register.
|
83 |
+
|
84 |
+
"""
|
85 |
+
if (
|
86 |
+
custom_opt.__name__ in globals().keys()
|
87 |
+
or custom_opt.__name__.lower() in globals().keys()
|
88 |
+
):
|
89 |
+
raise ValueError(
|
90 |
+
f"Activation {custom_opt.__name__} already exists. Choose another name."
|
91 |
+
)
|
92 |
+
globals().update({custom_opt.__name__: custom_opt})
|
93 |
+
|
94 |
+
|
95 |
+
def get(identifier):
|
96 |
+
"""Returns an optimizer function from a string. Returns its input if it
|
97 |
+
is callable (already a :class:`torch.optim.Optimizer` for example).
|
98 |
+
|
99 |
+
Args:
|
100 |
+
identifier (str or Callable): the optimizer identifier.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
:class:`torch.optim.Optimizer` or None
|
104 |
+
"""
|
105 |
+
if isinstance(identifier, Optimizer):
|
106 |
+
return identifier
|
107 |
+
elif isinstance(identifier, str):
|
108 |
+
to_get = {k.lower(): v for k, v in globals().items()}
|
109 |
+
cls = to_get.get(identifier.lower())
|
110 |
+
if cls is None:
|
111 |
+
raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
|
112 |
+
return cls
|
113 |
+
raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
|
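A short, hedged example of the helpers above: make_optimizer() resolves a name case-insensitively, and register_optimizer() makes a new class resolvable by get(). MyAdam is a toy stand-in, not part of the repository.

import torch
from look2hear.system.optimizers import make_optimizer, register_optimizer, get

model = torch.nn.Linear(16, 16)
opt = make_optimizer(model.parameters(), optim_name="adamw", lr=1e-3)

class MyAdam(torch.optim.Adam):     # toy custom optimizer for illustration
    pass

register_optimizer(MyAdam)
opt2 = get("myadam")(model.parameters(), lr=1e-4)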
look2hear/system/schedulers.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
import torch
|
2 |
+
from torch.optim.optimizer import Optimizer
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
5 |
+
|
6 |
+
|
7 |
+
class BaseScheduler(object):
|
8 |
+
"""Base class for the step-wise scheduler logic.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
optimizer (Optimizer): Optimizer instance to apply lr schedule on.
|
12 |
+
|
13 |
+
Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, optimizer):
|
17 |
+
self.optimizer = optimizer
|
18 |
+
self.step_num = 0
|
19 |
+
|
20 |
+
def zero_grad(self):
|
21 |
+
self.optimizer.zero_grad()
|
22 |
+
|
23 |
+
def _get_lr(self):
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
def _set_lr(self, lr):
|
27 |
+
for param_group in self.optimizer.param_groups:
|
28 |
+
param_group["lr"] = lr
|
29 |
+
|
30 |
+
def step(self, metrics=None, epoch=None):
|
31 |
+
"""Update step-wise learning rate before optimizer.step."""
|
32 |
+
self.step_num += 1
|
33 |
+
lr = self._get_lr()
|
34 |
+
self._set_lr(lr)
|
35 |
+
|
36 |
+
def load_state_dict(self, state_dict):
|
37 |
+
self.__dict__.update(state_dict)
|
38 |
+
|
39 |
+
def state_dict(self):
|
40 |
+
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
|
41 |
+
|
42 |
+
def as_tensor(self, start=0, stop=100_000):
|
43 |
+
"""Returns the scheduler values from start to stop."""
|
44 |
+
lr_list = []
|
45 |
+
for _ in range(start, stop):
|
46 |
+
self.step_num += 1
|
47 |
+
lr_list.append(self._get_lr())
|
48 |
+
self.step_num = 0
|
49 |
+
return torch.tensor(lr_list)
|
50 |
+
|
51 |
+
def plot(self, start=0, stop=100_000): # noqa
|
52 |
+
"""Plot the scheduler values from start to stop."""
|
53 |
+
import matplotlib.pyplot as plt
|
54 |
+
|
55 |
+
all_lr = self.as_tensor(start=start, stop=stop)
|
56 |
+
plt.plot(all_lr.numpy())
|
57 |
+
plt.show()
|
58 |
+
|
59 |
+
class DPTNetScheduler(BaseScheduler):
|
60 |
+
"""Dual Path Transformer Scheduler used in [1]
|
61 |
+
|
62 |
+
Args:
|
63 |
+
optimizer (Optimizer): Optimizer instance to apply lr schedule on.
|
64 |
+
steps_per_epoch (int): Number of steps per epoch.
|
65 |
+
d_model(int): The number of units in the layer output.
|
66 |
+
warmup_steps (int): The number of steps in the warmup stage of training.
|
67 |
+
noam_scale (float): Linear increase rate in first phase.
|
68 |
+
exp_max (float): Max learning rate in second phase.
|
69 |
+
exp_base (float): Exp learning rate base in second phase.
|
70 |
+
|
71 |
+
Schedule:
|
72 |
+
This scheduler increases the learning rate linearly for the first
|
73 |
+
``warmup_steps``, and then decay it by 0.98 for every two epochs.
|
74 |
+
|
75 |
+
References
|
76 |
+
[1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
|
77 |
+
Aware Modeling for End-to-End Monaural Speech Separation" Interspeech 2020.
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
optimizer,
|
83 |
+
steps_per_epoch,
|
84 |
+
d_model,
|
85 |
+
warmup_steps=4000,
|
86 |
+
noam_scale=1.0,
|
87 |
+
exp_max=0.0004,
|
88 |
+
exp_base=0.98,
|
89 |
+
):
|
90 |
+
super().__init__(optimizer)
|
91 |
+
self.noam_scale = noam_scale
|
92 |
+
self.d_model = d_model
|
93 |
+
self.warmup_steps = warmup_steps
|
94 |
+
self.exp_max = exp_max
|
95 |
+
self.exp_base = exp_base
|
96 |
+
self.steps_per_epoch = steps_per_epoch
|
97 |
+
self.epoch = 0
|
98 |
+
|
99 |
+
def _get_lr(self):
|
100 |
+
if self.step_num % self.steps_per_epoch == 0:
|
101 |
+
self.epoch += 1
|
102 |
+
|
103 |
+
if self.step_num > self.warmup_steps:
|
104 |
+
# exp decaying
|
105 |
+
lr = self.exp_max * (self.exp_base ** ((self.epoch - 1) // 2))
|
106 |
+
else:
|
107 |
+
# noam
|
108 |
+
lr = (
|
109 |
+
self.noam_scale
|
110 |
+
* self.d_model ** (-0.5)
|
111 |
+
* min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
|
112 |
+
)
|
113 |
+
return lr
|
114 |
+
|
115 |
+
class CustomExponentialLR(_LRScheduler):
|
116 |
+
def __init__(self, optimizer, gamma, step_size, last_epoch=-1):
|
117 |
+
self.gamma = gamma
|
118 |
+
self.step_size = step_size
|
119 |
+
self.base_lrs = list(map(lambda group: group['lr'], optimizer.param_groups))
|
120 |
+
super(CustomExponentialLR, self).__init__(optimizer, last_epoch)
|
121 |
+
|
122 |
+
def get_lr(self):
|
123 |
+
if self.last_epoch == 0 or (self.last_epoch + 1) % self.step_size != 0:
|
124 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
125 |
+
return [lr * self.gamma for lr in self.base_lrs]
|
126 |
+
|
127 |
+
|
128 |
+
# Backward compat
|
129 |
+
_BaseScheduler = BaseScheduler
|
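An illustrative way to drive the step-wise DPTNetScheduler by hand; the toy model and hyperparameters below are assumptions, not the settings used to train Apollo.

import torch
from look2hear.system.schedulers import DPTNetScheduler

model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=0.0)   # lr is overwritten on every step
sched = DPTNetScheduler(opt, steps_per_epoch=100, d_model=256, warmup_steps=4000)

for step in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    sched.step()          # noam warmup first, then 0.98 decay every two epochs
    opt.step()
    opt.zero_grad()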
look2hear/utils/__init__.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
###
|
2 |
+
# Author: Kai Li
|
3 |
+
# Date: 2021-06-18 16:53:49
|
4 |
+
# LastEditors: Please set LastEditors
|
5 |
+
# LastEditTime: 2024-01-22 01:01:02
|
6 |
+
###
|
7 |
+
from .stft import STFT
|
8 |
+
from .torch_utils import pad_x_to_y, shape_reconstructed, tensors_to_device
|
9 |
+
from .parser_utils import (
|
10 |
+
prepare_parser_from_dict,
|
11 |
+
parse_args_as_dict,
|
12 |
+
str_int_float,
|
13 |
+
str2bool,
|
14 |
+
str2bool_arg,
|
15 |
+
isfloat,
|
16 |
+
isint,
|
17 |
+
instantiate
|
18 |
+
)
|
19 |
+
from .lightning_utils import print_only, RichProgressBarTheme, MyRichProgressBar, BatchesProcessedColumn, MyMetricsTextColumn
|
20 |
+
from .complex_utils import is_complex, is_torch_complex_tensor, new_complex_like
|
21 |
+
from .get_layer_from_string import get_layer
|
22 |
+
from .inversible_interface import InversibleInterface
|
23 |
+
from .nets_utils import make_pad_mask
|
24 |
+
from .pylogger import RankedLogger
|
25 |
+
from .separator import wav_chunk_inference
|
26 |
+
|
27 |
+
__all__ = [
|
28 |
+
"wav_chunk_inference",
|
29 |
+
"RankedLogger",
|
30 |
+
"instantiate",
|
31 |
+
"STFT",
|
32 |
+
"pad_x_to_y",
|
33 |
+
"shape_reconstructed",
|
34 |
+
"tensors_to_device",
|
35 |
+
"prepare_parser_from_dict",
|
36 |
+
"parse_args_as_dict",
|
37 |
+
"str_int_float",
|
38 |
+
"str2bool",
|
39 |
+
"str2bool_arg",
|
40 |
+
"isfloat",
|
41 |
+
"isint",
|
42 |
+
"print_only",
|
43 |
+
"RichProgressBarTheme",
|
44 |
+
"MyRichProgressBar",
|
45 |
+
"BatchesProcessedColumn",
|
46 |
+
"MyMetricsTextColumn",
|
47 |
+
"is_complex",
|
48 |
+
"is_torch_complex_tensor",
|
49 |
+
"new_complex_like",
|
50 |
+
"get_layer",
|
51 |
+
"InversibleInterface",
|
52 |
+
"make_pad_mask",
|
53 |
+
]
|
look2hear/utils/complex_utils.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
"""Beamformer module."""
|
2 |
+
from typing import Sequence, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from packaging.version import parse as V
|
6 |
+
from torch_complex import functional as FC
|
7 |
+
from torch_complex.tensor import ComplexTensor
|
8 |
+
|
9 |
+
EPS = torch.finfo(torch.double).eps
|
10 |
+
is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0")
|
11 |
+
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
|
12 |
+
|
13 |
+
|
14 |
+
def new_complex_like(
|
15 |
+
ref: Union[torch.Tensor, ComplexTensor],
|
16 |
+
real_imag: Tuple[torch.Tensor, torch.Tensor],
|
17 |
+
):
|
18 |
+
if isinstance(ref, ComplexTensor):
|
19 |
+
return ComplexTensor(*real_imag)
|
20 |
+
elif is_torch_complex_tensor(ref):
|
21 |
+
return torch.complex(*real_imag)
|
22 |
+
else:
|
23 |
+
raise ValueError(
|
24 |
+
"Please update your PyTorch version to 1.9+ for complex support."
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
def is_torch_complex_tensor(c):
|
29 |
+
return (
|
30 |
+
not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def is_complex(c):
|
35 |
+
return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
|
36 |
+
|
37 |
+
|
38 |
+
def to_double(c):
|
39 |
+
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
|
40 |
+
return c.to(dtype=torch.complex128)
|
41 |
+
else:
|
42 |
+
return c.double()
|
43 |
+
|
44 |
+
|
45 |
+
def to_float(c):
|
46 |
+
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
|
47 |
+
return c.to(dtype=torch.complex64)
|
48 |
+
else:
|
49 |
+
return c.float()
|
50 |
+
|
51 |
+
|
52 |
+
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
|
53 |
+
if not isinstance(seq, (list, tuple)):
|
54 |
+
raise TypeError(
|
55 |
+
"cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
|
56 |
+
"not Tensor"
|
57 |
+
)
|
58 |
+
if isinstance(seq[0], ComplexTensor):
|
59 |
+
return FC.cat(seq, *args, **kwargs)
|
60 |
+
else:
|
61 |
+
return torch.cat(seq, *args, **kwargs)
|
62 |
+
|
63 |
+
|
64 |
+
def complex_norm(
|
65 |
+
c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
|
66 |
+
) -> torch.Tensor:
|
67 |
+
if not is_complex(c):
|
68 |
+
raise TypeError("Input is not a complex tensor.")
|
69 |
+
if is_torch_complex_tensor(c):
|
70 |
+
return torch.norm(c, dim=dim, keepdim=keepdim)
|
71 |
+
else:
|
72 |
+
if dim is None:
|
73 |
+
return torch.sqrt((c.real**2 + c.imag**2).sum() + EPS)
|
74 |
+
else:
|
75 |
+
return torch.sqrt(
|
76 |
+
(c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
def einsum(equation, *operands):
|
81 |
+
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
|
82 |
+
# NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
|
83 |
+
# mixed input with complex and real tensors.
|
84 |
+
if len(operands) == 1:
|
85 |
+
if isinstance(operands[0], (tuple, list)):
|
86 |
+
operands = operands[0]
|
87 |
+
complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
|
88 |
+
return complex_module.einsum(equation, *operands)
|
89 |
+
elif len(operands) != 2:
|
90 |
+
op0 = operands[0]
|
91 |
+
same_type = all(op.dtype == op0.dtype for op in operands[1:])
|
92 |
+
if same_type:
|
93 |
+
_einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
|
94 |
+
return _einsum(equation, *operands)
|
95 |
+
else:
|
96 |
+
raise ValueError("0 or More than 2 operands are not supported.")
|
97 |
+
a, b = operands
|
98 |
+
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
|
99 |
+
return FC.einsum(equation, a, b)
|
100 |
+
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
|
101 |
+
if not torch.is_complex(a):
|
102 |
+
o_real = torch.einsum(equation, a, b.real)
|
103 |
+
o_imag = torch.einsum(equation, a, b.imag)
|
104 |
+
return torch.complex(o_real, o_imag)
|
105 |
+
elif not torch.is_complex(b):
|
106 |
+
o_real = torch.einsum(equation, a.real, b)
|
107 |
+
o_imag = torch.einsum(equation, a.imag, b)
|
108 |
+
return torch.complex(o_real, o_imag)
|
109 |
+
else:
|
110 |
+
return torch.einsum(equation, a, b)
|
111 |
+
else:
|
112 |
+
return torch.einsum(equation, a, b)
|
113 |
+
|
114 |
+
|
115 |
+
def inverse(
|
116 |
+
c: Union[torch.Tensor, ComplexTensor]
|
117 |
+
) -> Union[torch.Tensor, ComplexTensor]:
|
118 |
+
if isinstance(c, ComplexTensor):
|
119 |
+
return c.inverse2()
|
120 |
+
else:
|
121 |
+
return c.inverse()
|
122 |
+
|
123 |
+
|
124 |
+
def matmul(
|
125 |
+
a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
|
126 |
+
) -> Union[torch.Tensor, ComplexTensor]:
|
127 |
+
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
|
128 |
+
# NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
|
129 |
+
# multiplication between complex and real tensors.
|
130 |
+
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
|
131 |
+
return FC.matmul(a, b)
|
132 |
+
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
|
133 |
+
if not torch.is_complex(a):
|
134 |
+
o_real = torch.matmul(a, b.real)
|
135 |
+
o_imag = torch.matmul(a, b.imag)
|
136 |
+
return torch.complex(o_real, o_imag)
|
137 |
+
elif not torch.is_complex(b):
|
138 |
+
o_real = torch.matmul(a.real, b)
|
139 |
+
o_imag = torch.matmul(a.imag, b)
|
140 |
+
return torch.complex(o_real, o_imag)
|
141 |
+
else:
|
142 |
+
return torch.matmul(a, b)
|
143 |
+
else:
|
144 |
+
return torch.matmul(a, b)
|
145 |
+
|
146 |
+
|
147 |
+
def trace(a: Union[torch.Tensor, ComplexTensor]):
|
148 |
+
# NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
|
149 |
+
# support batch processing. Use FC.trace() as fallback.
|
150 |
+
return FC.trace(a)
|
151 |
+
|
152 |
+
|
153 |
+
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
|
154 |
+
if isinstance(a, ComplexTensor):
|
155 |
+
return FC.reverse(a, dim=dim)
|
156 |
+
else:
|
157 |
+
return torch.flip(a, dims=(dim,))
|
158 |
+
|
159 |
+
|
160 |
+
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
|
161 |
+
"""Solve the linear equation ax = b."""
|
162 |
+
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
|
163 |
+
# NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
|
164 |
+
# mixed input with complex and real tensors.
|
165 |
+
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
|
166 |
+
if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
|
167 |
+
return FC.solve(b, a, return_LU=False)
|
168 |
+
else:
|
169 |
+
return matmul(inverse(a), b)
|
170 |
+
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
|
171 |
+
if torch.is_complex(a) and torch.is_complex(b):
|
172 |
+
return torch.linalg.solve(a, b)
|
173 |
+
else:
|
174 |
+
return matmul(inverse(a), b)
|
175 |
+
else:
|
176 |
+
if is_torch_1_8_plus:
|
177 |
+
return torch.linalg.solve(a, b)
|
178 |
+
else:
|
179 |
+
return torch.solve(b, a)[0]
|
180 |
+
|
181 |
+
|
182 |
+
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
|
183 |
+
if not isinstance(seq, (list, tuple)):
|
184 |
+
raise TypeError(
|
185 |
+
"stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
|
186 |
+
"not Tensor"
|
187 |
+
)
|
188 |
+
if isinstance(seq[0], ComplexTensor):
|
189 |
+
return FC.stack(seq, *args, **kwargs)
|
190 |
+
else:
|
191 |
+
return torch.stack(seq, *args, **kwargs)
|
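A small sketch of the mixed real/complex helpers above, using native torch complex tensors (so the ComplexTensor fallback paths are not exercised); the shapes are arbitrary.

import torch
from look2hear.utils.complex_utils import einsum, matmul, new_complex_like, complex_norm

spec = torch.randn(2, 4, 8, dtype=torch.complex64)   # e.g. batch, freq, time
mask = torch.rand(2, 4, 8)                           # real-valued mask

masked = einsum("bft,bft->bft", spec, mask)          # complex x real handled explicitly
proj = matmul(mask, spec.transpose(1, 2))            # (2,4,8) @ (2,8,4) -> complex (2,4,4)
rebuilt = new_complex_like(spec, (masked.real, masked.imag))
print(complex_norm(masked, dim=-1).shape)            # torch.Size([2, 4])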
look2hear/utils/get_layer_from_string.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
import difflib
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def get_layer(l_name, library=torch.nn):
|
7 |
+
"""Return layer object handler from library e.g. from torch.nn
|
8 |
+
|
9 |
+
E.g. if l_name=="elu", returns torch.nn.ELU.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
l_name (string): Case-insensitive name for layer in library (e.g. 'elu').
|
13 |
+
library (module): Name of library/module where to search for object handler
|
14 |
+
with l_name e.g. "torch.nn".
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
layer_handler (object): handler for the requested layer e.g. (torch.nn.ELU)
|
18 |
+
|
19 |
+
"""
|
20 |
+
|
21 |
+
all_torch_layers = [x for x in dir(torch.nn)]
|
22 |
+
match = [x for x in all_torch_layers if l_name.lower() == x.lower()]
|
23 |
+
if len(match) == 0:
|
24 |
+
close_matches = difflib.get_close_matches(
|
25 |
+
l_name, [x.lower() for x in all_torch_layers]
|
26 |
+
)
|
27 |
+
raise NotImplementedError(
|
28 |
+
"Layer with name {} not found in {}.\n Closest matches: {}".format(
|
29 |
+
l_name, str(library), close_matches
|
30 |
+
)
|
31 |
+
)
|
32 |
+
elif len(match) > 1:
|
33 |
+
close_matches = difflib.get_close_matches(
|
34 |
+
l_name, [x.lower() for x in all_torch_layers]
|
35 |
+
)
|
36 |
+
raise NotImplementedError(
|
37 |
+
"Multiple matchs for layer with name {} not found in {}.\n "
|
38 |
+
"All matches: {}".format(l_name, str(library), close_matches)
|
39 |
+
)
|
40 |
+
else:
|
41 |
+
# valid
|
42 |
+
layer_handler = getattr(library, match[0])
|
43 |
+
return layer_handler
|
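For example, resolving a layer class by its case-insensitive name and instantiating it:

from look2hear.utils.get_layer_from_string import get_layer

act_cls = get_layer("prelu")          # -> torch.nn.PReLU
activation = act_cls(num_parameters=1)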
look2hear/utils/inversible_interface.py
ADDED
@@ -0,0 +1,13 @@
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class InversibleInterface(ABC):
|
8 |
+
@abstractmethod
|
9 |
+
def inverse(
|
10 |
+
self, input: torch.Tensor, input_lengths: torch.Tensor = None
|
11 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
12 |
+
# return output, output_lengths
|
13 |
+
raise NotImplementedError
|
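A minimal sketch of a subclass of this interface; the sign-flip transform is purely illustrative (an STFT-style front-end would be a more realistic implementer).

from typing import Tuple
import torch
from look2hear.utils.inversible_interface import InversibleInterface

class SignFlip(InversibleInterface):
    """Toy invertible transform: apply() negates the signal, inverse() undoes it."""

    def apply(self, input: torch.Tensor, input_lengths: torch.Tensor = None):
        return -input, input_lengths

    def inverse(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return -input, input_lengths

x = torch.randn(2, 100)
flip = SignFlip()
y, _ = flip.apply(x)
x_rec, _ = flip.inverse(y)
assert torch.allclose(x, x_rec)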
look2hear/utils/lightning_utils.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
###
|
2 |
+
# Author: Kai Li
|
3 |
+
# Date: 2022-05-27 10:27:56
|
4 |
+
# Email: [email protected]
|
5 |
+
# LastEditTime: 2022-06-13 12:11:15
|
6 |
+
###
|
7 |
+
from rich import print
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from pytorch_lightning.utilities import rank_zero_only
|
10 |
+
from typing import Union
|
11 |
+
from pytorch_lightning.callbacks.progress.rich_progress import *
|
12 |
+
from rich.console import Console, RenderableType
|
13 |
+
from rich.progress_bar import ProgressBar
|
14 |
+
from rich.style import Style
|
15 |
+
from rich.text import Text
|
16 |
+
from rich.progress import (
|
17 |
+
BarColumn,
|
18 |
+
DownloadColumn,
|
19 |
+
Progress,
|
20 |
+
TaskID,
|
21 |
+
TextColumn,
|
22 |
+
TimeRemainingColumn,
|
23 |
+
TransferSpeedColumn,
|
24 |
+
ProgressColumn
|
25 |
+
)
|
26 |
+
from rich import print, reconfigure
|
27 |
+
|
28 |
+
@rank_zero_only
|
29 |
+
def print_only(message: str):
|
30 |
+
print(message)
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class RichProgressBarTheme:
|
34 |
+
"""Styles to associate to different base components.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
|
38 |
+
progress_bar: Style for the bar in progress.
|
39 |
+
progress_bar_finished: Style for the finished progress bar.
|
40 |
+
progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
|
41 |
+
batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
|
42 |
+
time: Style for the processed time and estimate time remaining.
|
43 |
+
processing_speed: Style for the speed of the batches being processed.
|
44 |
+
metrics: Style for the metrics
|
45 |
+
|
46 |
+
https://rich.readthedocs.io/en/stable/style.html
|
47 |
+
"""
|
48 |
+
|
49 |
+
description: Union[str, Style] = "#FF4500"
|
50 |
+
progress_bar: Union[str, Style] = "#f92672"
|
51 |
+
progress_bar_finished: Union[str, Style] = "#b7cc8a"
|
52 |
+
progress_bar_pulse: Union[str, Style] = "#f92672"
|
53 |
+
batch_progress: Union[str, Style] = "#fc608a"
|
54 |
+
time: Union[str, Style] = "#45ada2"
|
55 |
+
processing_speed: Union[str, Style] = "#DC143C"
|
56 |
+
metrics: Union[str, Style] = "#228B22"
|
57 |
+
|
58 |
+
class BatchesProcessedColumn(ProgressColumn):
|
59 |
+
def __init__(self, style: Union[str, Style]):
|
60 |
+
self.style = style
|
61 |
+
super().__init__()
|
62 |
+
|
63 |
+
def render(self, task) -> RenderableType:
|
64 |
+
total = task.total if task.total != float("inf") else "--"
|
65 |
+
return Text(f"{int(task.completed)}/{int(total)}", style=self.style)
|
66 |
+
|
67 |
+
class MyMetricsTextColumn(ProgressColumn):
|
68 |
+
"""A column containing text."""
|
69 |
+
|
70 |
+
def __init__(self, style):
|
71 |
+
self._tasks = {}
|
72 |
+
self._current_task_id = 0
|
73 |
+
self._metrics = {}
|
74 |
+
self._style = style
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
def update(self, metrics):
|
78 |
+
# Called when metrics are ready to be rendered.
|
79 |
+
# This is to prevent render from causing deadlock issues by requesting metrics
|
80 |
+
# in separate threads.
|
81 |
+
self._metrics = metrics
|
82 |
+
|
83 |
+
def render(self, task) -> Text:
|
84 |
+
text = ""
|
85 |
+
for k, v in self._metrics.items():
|
86 |
+
text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
|
87 |
+
return Text(text, justify="left", style=self._style)
|
88 |
+
|
89 |
+
class MyRichProgressBar(RichProgressBar):
|
90 |
+
"""A progress bar prints metrics at the end of each epoch
|
91 |
+
"""
|
92 |
+
|
93 |
+
def _init_progress(self, trainer):
|
94 |
+
if self.is_enabled and (self.progress is None or self._progress_stopped):
|
95 |
+
self._reset_progress_bar_ids()
|
96 |
+
reconfigure(**self._console_kwargs)
|
97 |
+
# file = open("/home/likai/data/Look2Hear/Experiments/run_logs/EdgeFRCNN-Noncausal.log", 'w')
|
98 |
+
self._console: Console = Console(force_terminal=True)
|
99 |
+
self._console.clear_live()
|
100 |
+
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
|
101 |
+
self.progress = CustomProgress(
|
102 |
+
*self.configure_columns(trainer),
|
103 |
+
self._metric_component,
|
104 |
+
auto_refresh=False,
|
105 |
+
disable=self.is_disabled,
|
106 |
+
console=self._console,
|
107 |
+
)
|
108 |
+
self.progress.start()
|
109 |
+
# progress has started
|
110 |
+
self._progress_stopped = False
|
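A hedged sketch of plugging the customised progress bar into a Trainer; the theme values are the dataclass defaults above and the Trainer arguments are placeholders (fit() is not invoked here).

import pytorch_lightning as pl
from look2hear.utils.lightning_utils import MyRichProgressBar, RichProgressBarTheme

progress_bar = MyRichProgressBar(theme=RichProgressBarTheme())
trainer = pl.Trainer(max_epochs=1, callbacks=[progress_bar])
# trainer.fit(system, datamodule=...)   # `system` would be an AudioLightningModule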
look2hear/utils/nets_utils.py
ADDED
@@ -0,0 +1,503 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
"""Network related utility tools."""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
from typing import Dict
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def to_device(m, x):
|
13 |
+
"""Send tensor into the device of the module.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
m (torch.nn.Module): Torch module.
|
17 |
+
x (Tensor): Torch tensor.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Tensor: Torch tensor located in the same place as torch module.
|
21 |
+
|
22 |
+
"""
|
23 |
+
if isinstance(m, torch.nn.Module):
|
24 |
+
device = next(m.parameters()).device
|
25 |
+
elif isinstance(m, torch.Tensor):
|
26 |
+
device = m.device
|
27 |
+
else:
|
28 |
+
raise TypeError(
|
29 |
+
"Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
|
30 |
+
)
|
31 |
+
return x.to(device)
|
32 |
+
|
33 |
+
|
34 |
+
def pad_list(xs, pad_value):
|
35 |
+
"""Perform padding for the list of tensors.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
39 |
+
pad_value (float): Value for padding.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Tensor: Padded tensor (B, Tmax, `*`).
|
43 |
+
|
44 |
+
Examples:
|
45 |
+
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
46 |
+
>>> x
|
47 |
+
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
48 |
+
>>> pad_list(x, 0)
|
49 |
+
tensor([[1., 1., 1., 1.],
|
50 |
+
[1., 1., 0., 0.],
|
51 |
+
[1., 0., 0., 0.]])
|
52 |
+
|
53 |
+
"""
|
54 |
+
n_batch = len(xs)
|
55 |
+
max_len = max(x.size(0) for x in xs)
|
56 |
+
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
57 |
+
|
58 |
+
for i in range(n_batch):
|
59 |
+
pad[i, : xs[i].size(0)] = xs[i]
|
60 |
+
|
61 |
+
return pad
|
62 |
+
|
63 |
+
|
64 |
+
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
|
65 |
+
"""Make mask tensor containing indices of padded part.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
69 |
+
xs (Tensor, optional): The reference tensor.
|
70 |
+
If set, masks will be the same shape as this tensor.
|
71 |
+
length_dim (int, optional): Dimension indicator of the above tensor.
|
72 |
+
See the example.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Tensor: Mask tensor containing indices of padded part.
|
76 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
77 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
78 |
+
|
79 |
+
Examples:
|
80 |
+
With only lengths.
|
81 |
+
|
82 |
+
>>> lengths = [5, 3, 2]
|
83 |
+
>>> make_pad_mask(lengths)
|
84 |
+
masks = [[0, 0, 0, 0 ,0],
|
85 |
+
[0, 0, 0, 1, 1],
|
86 |
+
[0, 0, 1, 1, 1]]
|
87 |
+
|
88 |
+
With the reference tensor.
|
89 |
+
|
90 |
+
>>> xs = torch.zeros((3, 2, 4))
|
91 |
+
>>> make_pad_mask(lengths, xs)
|
92 |
+
tensor([[[0, 0, 0, 0],
|
93 |
+
[0, 0, 0, 0]],
|
94 |
+
[[0, 0, 0, 1],
|
95 |
+
[0, 0, 0, 1]],
|
96 |
+
[[0, 0, 1, 1],
|
97 |
+
[0, 0, 1, 1]]], dtype=torch.uint8)
|
98 |
+
>>> xs = torch.zeros((3, 2, 6))
|
99 |
+
>>> make_pad_mask(lengths, xs)
|
100 |
+
tensor([[[0, 0, 0, 0, 0, 1],
|
101 |
+
[0, 0, 0, 0, 0, 1]],
|
102 |
+
[[0, 0, 0, 1, 1, 1],
|
103 |
+
[0, 0, 0, 1, 1, 1]],
|
104 |
+
[[0, 0, 1, 1, 1, 1],
|
105 |
+
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
106 |
+
|
107 |
+
With the reference tensor and dimension indicator.
|
108 |
+
|
109 |
+
>>> xs = torch.zeros((3, 6, 6))
|
110 |
+
>>> make_pad_mask(lengths, xs, 1)
|
111 |
+
tensor([[[0, 0, 0, 0, 0, 0],
|
112 |
+
[0, 0, 0, 0, 0, 0],
|
113 |
+
[0, 0, 0, 0, 0, 0],
|
114 |
+
[0, 0, 0, 0, 0, 0],
|
115 |
+
[0, 0, 0, 0, 0, 0],
|
116 |
+
[1, 1, 1, 1, 1, 1]],
|
117 |
+
[[0, 0, 0, 0, 0, 0],
|
118 |
+
[0, 0, 0, 0, 0, 0],
|
119 |
+
[0, 0, 0, 0, 0, 0],
|
120 |
+
[1, 1, 1, 1, 1, 1],
|
121 |
+
[1, 1, 1, 1, 1, 1],
|
122 |
+
[1, 1, 1, 1, 1, 1]],
|
123 |
+
[[0, 0, 0, 0, 0, 0],
|
124 |
+
[0, 0, 0, 0, 0, 0],
|
125 |
+
[1, 1, 1, 1, 1, 1],
|
126 |
+
[1, 1, 1, 1, 1, 1],
|
127 |
+
[1, 1, 1, 1, 1, 1],
|
128 |
+
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
129 |
+
>>> make_pad_mask(lengths, xs, 2)
|
130 |
+
tensor([[[0, 0, 0, 0, 0, 1],
|
131 |
+
[0, 0, 0, 0, 0, 1],
|
132 |
+
[0, 0, 0, 0, 0, 1],
|
133 |
+
[0, 0, 0, 0, 0, 1],
|
134 |
+
[0, 0, 0, 0, 0, 1],
|
135 |
+
[0, 0, 0, 0, 0, 1]],
|
136 |
+
[[0, 0, 0, 1, 1, 1],
|
137 |
+
[0, 0, 0, 1, 1, 1],
|
138 |
+
[0, 0, 0, 1, 1, 1],
|
139 |
+
[0, 0, 0, 1, 1, 1],
|
140 |
+
[0, 0, 0, 1, 1, 1],
|
141 |
+
[0, 0, 0, 1, 1, 1]],
|
142 |
+
[[0, 0, 1, 1, 1, 1],
|
143 |
+
[0, 0, 1, 1, 1, 1],
|
144 |
+
[0, 0, 1, 1, 1, 1],
|
145 |
+
[0, 0, 1, 1, 1, 1],
|
146 |
+
[0, 0, 1, 1, 1, 1],
|
147 |
+
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
148 |
+
|
149 |
+
"""
|
150 |
+
if length_dim == 0:
|
151 |
+
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
|
152 |
+
|
153 |
+
if not isinstance(lengths, list):
|
154 |
+
lengths = lengths.long().tolist()
|
155 |
+
|
156 |
+
bs = int(len(lengths))
|
157 |
+
if maxlen is None:
|
158 |
+
if xs is None:
|
159 |
+
maxlen = int(max(lengths))
|
160 |
+
else:
|
161 |
+
maxlen = xs.size(length_dim)
|
162 |
+
else:
|
163 |
+
assert xs is None
|
164 |
+
assert maxlen >= int(max(lengths))
|
165 |
+
|
166 |
+
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
167 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
168 |
+
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
169 |
+
mask = seq_range_expand >= seq_length_expand
|
170 |
+
|
171 |
+
if xs is not None:
|
172 |
+
assert xs.size(0) == bs, (xs.size(0), bs)
|
173 |
+
|
174 |
+
if length_dim < 0:
|
175 |
+
length_dim = xs.dim() + length_dim
|
176 |
+
# ind = (:, None, ..., None, :, , None, ..., None)
|
177 |
+
ind = tuple(
|
178 |
+
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
|
179 |
+
)
|
180 |
+
mask = mask[ind].expand_as(xs).to(xs.device)
|
181 |
+
return mask
|
182 |
+
|
183 |
+
|
184 |
+
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
185 |
+
"""Make mask tensor containing indices of non-padded part.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
189 |
+
xs (Tensor, optional): The reference tensor.
|
190 |
+
If set, masks will be the same shape as this tensor.
|
191 |
+
length_dim (int, optional): Dimension indicator of the above tensor.
|
192 |
+
See the example.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
ByteTensor: mask tensor containing indices of padded part.
|
196 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
197 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
198 |
+
|
199 |
+
Examples:
|
200 |
+
With only lengths.
|
201 |
+
|
202 |
+
>>> lengths = [5, 3, 2]
|
203 |
+
>>> make_non_pad_mask(lengths)
|
204 |
+
masks = [[1, 1, 1, 1 ,1],
|
205 |
+
[1, 1, 1, 0, 0],
|
206 |
+
[1, 1, 0, 0, 0]]
|
207 |
+
|
208 |
+
With the reference tensor.
|
209 |
+
|
210 |
+
>>> xs = torch.zeros((3, 2, 4))
|
211 |
+
>>> make_non_pad_mask(lengths, xs)
|
212 |
+
tensor([[[1, 1, 1, 1],
|
213 |
+
[1, 1, 1, 1]],
|
214 |
+
[[1, 1, 1, 0],
|
215 |
+
[1, 1, 1, 0]],
|
216 |
+
[[1, 1, 0, 0],
|
217 |
+
[1, 1, 0, 0]]], dtype=torch.uint8)
|
218 |
+
>>> xs = torch.zeros((3, 2, 6))
|
219 |
+
>>> make_non_pad_mask(lengths, xs)
|
220 |
+
tensor([[[1, 1, 1, 1, 1, 0],
|
221 |
+
[1, 1, 1, 1, 1, 0]],
|
222 |
+
[[1, 1, 1, 0, 0, 0],
|
223 |
+
[1, 1, 1, 0, 0, 0]],
|
224 |
+
[[1, 1, 0, 0, 0, 0],
|
225 |
+
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
226 |
+
|
227 |
+
With the reference tensor and dimension indicator.
|
228 |
+
|
229 |
+
>>> xs = torch.zeros((3, 6, 6))
|
230 |
+
>>> make_non_pad_mask(lengths, xs, 1)
|
231 |
+
tensor([[[1, 1, 1, 1, 1, 1],
|
232 |
+
[1, 1, 1, 1, 1, 1],
|
233 |
+
[1, 1, 1, 1, 1, 1],
|
234 |
+
[1, 1, 1, 1, 1, 1],
|
235 |
+
[1, 1, 1, 1, 1, 1],
|
236 |
+
[0, 0, 0, 0, 0, 0]],
|
237 |
+
[[1, 1, 1, 1, 1, 1],
|
238 |
+
[1, 1, 1, 1, 1, 1],
|
239 |
+
[1, 1, 1, 1, 1, 1],
|
240 |
+
[0, 0, 0, 0, 0, 0],
|
241 |
+
[0, 0, 0, 0, 0, 0],
|
242 |
+
[0, 0, 0, 0, 0, 0]],
|
243 |
+
[[1, 1, 1, 1, 1, 1],
|
244 |
+
[1, 1, 1, 1, 1, 1],
|
245 |
+
[0, 0, 0, 0, 0, 0],
|
246 |
+
[0, 0, 0, 0, 0, 0],
|
247 |
+
[0, 0, 0, 0, 0, 0],
|
248 |
+
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
249 |
+
>>> make_non_pad_mask(lengths, xs, 2)
|
250 |
+
tensor([[[1, 1, 1, 1, 1, 0],
|
251 |
+
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag;
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x["real"], x["imag"])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = (
            "x must be numpy.ndarray, torch.Tensor or a dict like "
            "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
            "but got {}".format(type(x))
        )
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)


def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == "transformer":
        return np.array([1])

    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayer)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif (
        (mode == "asr" and arch in ("rnn", "rnn-t"))
        or (mode == "mt" and arch == "rnn")
        or (mode == "st" and arch == "rnn")
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
        )
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(
                min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
            ):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mulenc":
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    "Encoder %d: Subsampling is not performed for vgg*. "
                    "It is performed in max pooling layers at CNN.",
                    idx + 1,
                )
            logging.info("subsample: " + " ".join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))


def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v


def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from espnet.nets.pytorch_backend.conformer.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }

    return activation_funcs[act]()
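A minimal usage sketch of the masking helpers above (shapes and lengths are illustrative; it assumes the repo root is on PYTHONPATH so `look2hear.utils.nets_utils` is importable):

import torch
from look2hear.utils.nets_utils import make_non_pad_mask, mask_by_length

lengths = [5, 3, 2]
feats = torch.randn(3, 5, 8)                        # (B, Tmax, D) padded batch
valid = make_non_pad_mask(lengths, feats[..., 0])   # (B, Tmax) mask, True on valid frames
feats = feats * valid.unsqueeze(-1)                 # zero out padded frames
x = torch.arange(5).repeat(3, 1) + 1
x_masked = mask_by_length(x, lengths)               # padded positions replaced by 0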
look2hear/utils/parser_utils.py
ADDED
@@ -0,0 +1,178 @@
###
# Author: Kai Li
# Date: 2021-06-20 00:36:46
# LastEditors: Please set LastEditors
# LastEditTime: 2024-01-22 03:02:57
###
import sys
import argparse
import importlib
from omegaconf import DictConfig


def prepare_parser_from_dict(dic, parser=None):
    """Prepare an argparser from a dictionary.

    Args:
        dic (dict): Two-level config dictionary with unique bottom-level keys.
        parser (argparse.ArgumentParser, optional): If a parser already
            exists, add the keys from the dictionary on the top of it.

    Returns:
        argparse.ArgumentParser:
            Parser instance with groups corresponding to the first level keys
            and arguments corresponding to the second level keys with default
            values given by the values.
    """

    def standardized_entry_type(value):
        """If the default value is None, replace NoneType by str_int_float.
        If the default value is boolean, look for boolean strings."""
        if value is None:
            return str_int_float
        if isinstance(str2bool(value), bool):
            return str2bool_arg
        return type(value)

    if parser is None:
        parser = argparse.ArgumentParser()
    for k in dic.keys():
        group = parser.add_argument_group(k)
        if isinstance(dic[k], list):
            entry_type = standardized_entry_type(dic[k])
            group.add_argument("--" + k, default=dic[k], type=entry_type)
        elif isinstance(dic[k], dict):
            for kk in dic[k].keys():
                entry_type = standardized_entry_type(dic[k][kk])
                group.add_argument("--" + kk, default=dic[k][kk], type=entry_type)
        elif isinstance(dic[k], str):
            entry_type = standardized_entry_type(dic[k])
            group.add_argument("--" + k, default=dic[k], type=entry_type)
    return parser


def str_int_float(value):
    """Type to convert strings to int, float (in this order) if possible.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: Converted value.
    """
    if isint(value):
        return int(value)
    if isfloat(value):
        return float(value)
    elif isinstance(value, str):
        return value


def str2bool(value):
    """Type to convert strings to Boolean (returns input if not boolean)"""
    if not isinstance(value, str):
        return value
    if value.lower() in ("yes", "true", "y", "1"):
        return True
    elif value.lower() in ("no", "false", "n", "0"):
        return False
    else:
        return value


def str2bool_arg(value):
    """Argparse type to convert strings to Boolean"""
    value = str2bool(value)
    if isinstance(value, bool):
        return value
    raise argparse.ArgumentTypeError("Boolean value expected.")


def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to a float.

    """
    try:
        float(value)
        return True
    except ValueError:
        return False


def isint(value):
    """Computes whether `value` can be cast to an int

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to an int.

    """
    try:
        int(value)
        return True
    except ValueError:
        return False


def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of process `parser.parse_args()`

    Top-level keys corresponding to groups and bottom-level keys corresponding
    to arguments. Under `'main_args'`, the arguments which don't belong to a
    argparse group (i.e main arguments defined before parsing from a dict) can
    be found.

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to return the output or
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. Optionally the
            direct output `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    args_dic = {}
    for group in parser._action_groups:
        group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        args_dic[group.title] = group_dict
    if sys.version_info.minor == 10:
        args_dic["main_args"] = args_dic["positional arguments"]
        del args_dic["positional arguments"]
    else:
        args_dic["main_args"] = args_dic["optional arguments"]
        del args_dic["optional arguments"]
    if return_plain_args:
        return args_dic, args
    return args_dic


def instantiate(config, **kwargs):
    if '__target__' in config:
        module_path, class_name = config['__target__'].rsplit('.', 1)
        module = importlib.import_module(module_path)
        cls = getattr(module, class_name)
        # Resolve nested configs first
        params = {}
        for key, value in config.items():
            if key != '__target__':
                if isinstance(value, DictConfig) and '__target__' in value:
                    params[key] = instantiate(value)
                else:
                    params[key] = value
        # Add the extra keyword arguments
        params.update(kwargs)
        return cls(**params)
    else:
        # For dicts without '__target__', recursively process each value
        return {k: instantiate(v, **kwargs) if isinstance(v, DictConfig) else v for k, v in config.items()}
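A hypothetical usage sketch of the parser helpers above. The config dict and override values are made up, and the example assumes a Python version (3.10 or older) whose default argparse group titles match what `parse_args_as_dict` looks up:

import argparse
from look2hear.utils.parser_utils import prepare_parser_from_dict, parse_args_as_dict

conf = {
    "datamodule": {"sample_rate": 44100, "segments": 4.0},
    "training": {"batch_size": 1, "half_lr": True},
}
parser = argparse.ArgumentParser()
parser = prepare_parser_from_dict(conf, parser=parser)
# command-line flags override the dict defaults; booleans accept "yes"/"no" etc.
arg_dic, plain_args = parse_args_as_dict(
    parser, return_plain_args=True, args=["--batch_size", "2", "--half_lr", "no"]
)
# arg_dic["training"] -> {"batch_size": 2, "half_lr": False}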
look2hear/utils/pylogger.py
ADDED
@@ -0,0 +1,54 @@
import logging
from typing import Mapping, Optional

from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
        log_file: str = "log.txt",  # name of the log file to write to
    ) -> None:
        logger = logging.getLogger(name)

        # Configure the log format
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # Add a file handler
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                if rank is None:
                    self.logger.log(level, msg, *args, **kwargs)
                elif current_rank == rank:
                    self.logger.log(level, msg, *args, **kwargs)
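A minimal sketch of how this logger might be used. In a Lightning run `rank_zero_only.rank` is set by the framework; it is set manually here only so the sketch is self-contained:

import logging
from lightning_utilities.core.rank_zero import rank_zero_only
from look2hear.utils.pylogger import RankedLogger

rank_zero_only.rank = 0          # normally populated by Lightning at startup
log = RankedLogger(__name__, rank_zero_only=True)   # also appends to log.txt
log.info("Instantiating model...")                   # emitted only on rank 0
log.log(logging.WARNING, "message for one process", rank=0)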
look2hear/utils/separator.py
ADDED
@@ -0,0 +1,138 @@
###
# Author: Kai Li
# Date: 2021-06-18 16:32:50
# LastEditors: Kai Li
# LastEditTime: 2021-06-19 01:02:04
###
import os
import warnings
import torch
import numpy as np
import soundfile as sf


def get_device(tensor_or_module, default=None):
    if hasattr(tensor_or_module, "device"):
        return tensor_or_module.device
    elif hasattr(tensor_or_module, "parameters"):
        return next(tensor_or_module.parameters()).device
    elif default is None:
        raise TypeError(
            f"Don't know how to get device of {type(tensor_or_module)} object"
        )
    else:
        return torch.device(default)


class Separator:
    def forward_wav(self, wav, **kwargs):
        raise NotImplementedError

    def sample_rate(self):
        raise NotImplementedError


def separate(model, wav, **kwargs):
    if isinstance(wav, np.ndarray):
        return numpy_separate(model, wav, **kwargs)
    elif isinstance(wav, torch.Tensor):
        return torch_separate(model, wav, **kwargs)
    else:
        raise ValueError(
            f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
        )


@torch.no_grad()
def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
    """Core logic of `separate`."""
    if model.in_channels is not None and wav.shape[-2] != model.in_channels:
        raise RuntimeError(
            f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels."
            f"Please match the number of channels."
        )
    # Handle device placement
    input_device = get_device(wav, default="cpu")
    model_device = get_device(model, default="cpu")
    wav = wav.to(model_device)
    # Forward
    separate_func = getattr(model, "forward_wav", model)
    out_wavs = separate_func(wav, **kwargs)

    # FIXME: for now this is the best we can do.
    out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())

    # Back to input device (and numpy if necessary)
    out_wavs = out_wavs.to(input_device)
    return out_wavs


def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
    """Numpy interface to `separate`."""
    wav = torch.from_numpy(wav)
    out_wavs = torch_separate(model, wav, **kwargs)
    out_wavs = out_wavs.data.numpy()
    return out_wavs


def wav_chunk_inference(model, mixture_tensor, sr=16000, target_length=12.0, hop_length=4.0, batch_size=10, n_tracks=3):
    """
    Input:
        mixture_tensor: Tensor, [nch, input_length]

    Output:
        all_target_tensor: Tensor, [nch, n_track, input_length]
    """
    batch_mixture = mixture_tensor

    # split data into segments
    batch_length = batch_mixture.shape[-1]

    session = int(sr * target_length)
    target = int(sr * target_length)
    ignore = (session - target) // 2
    hop = int(sr * hop_length)
    tr_ratio = target_length / hop_length
    if ignore > 0:
        zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], ignore).type(batch_mixture.type()).to(batch_mixture.device)
        batch_mixture_pad = torch.cat([zero_pad, batch_mixture, zero_pad], -1)
    else:
        batch_mixture_pad = batch_mixture
    if target - hop > 0:
        hop_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], target-hop).type(batch_mixture.type()).to(batch_mixture.device)
        batch_mixture_pad = torch.cat([hop_pad, batch_mixture_pad, hop_pad], -1)

    skip_idx = ignore + target - hop
    zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], session).type(batch_mixture.type()).to(batch_mixture.device)
    num_session = (batch_mixture_pad.shape[-1] - session) // hop + 2
    all_target = torch.zeros(batch_mixture_pad.shape[0], n_tracks, batch_mixture_pad.shape[1], batch_mixture_pad.shape[2]).to(batch_mixture_pad.device)
    all_input = []
    all_segment_length = []

    for i in range(num_session):
        this_input = batch_mixture_pad[:,:,i*hop:i*hop+session]
        segment_length = this_input.shape[-1]
        if segment_length < session:
            this_input = torch.cat([this_input, zero_pad[:,:,:session-segment_length]], -1)
        all_input.append(this_input)
        all_segment_length.append(segment_length)

    all_input = torch.cat(all_input, 0)
    num_batch = num_session // batch_size
    if num_session % batch_size > 0:
        num_batch += 1

    for i in range(num_batch):
        this_input = all_input[i*batch_size:(i+1)*batch_size]
        actual_batch_size = this_input.shape[0]
        with torch.no_grad():
            est_target = model(this_input)
        for j in range(actual_batch_size):
            this_est_target = est_target[j,:,:,:all_segment_length[i*batch_size+j]][:,:,ignore:ignore+target].unsqueeze(0)
            all_target[:,:,:,ignore+(i*batch_size+j)*hop:ignore+(i*batch_size+j)*hop+target] += this_est_target

    all_target = all_target[:,:,:,skip_idx:skip_idx+batch_length].contiguous() / tr_ratio

    return all_target.squeeze(0)
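A sketch of how `wav_chunk_inference` can be driven. The model here is a stand-in identity network, and the sample rate, chunk/hop lengths and track count are placeholders, not the values used by the Apollo config:

import torch
from look2hear.utils.separator import wav_chunk_inference

class DummyModel(torch.nn.Module):
    # stand-in with the expected interface: (B, nch, T) -> (B, n_tracks, nch, T)
    def forward(self, x):
        return x.unsqueeze(1)

model = DummyModel()
mixture = torch.randn(2, 44100 * 30)      # (nch, T): 30 s of stereo audio
out = wav_chunk_inference(model, mixture.unsqueeze(0), sr=44100,
                          target_length=10.0, hop_length=5.0,
                          batch_size=4, n_tracks=1)
# out: (n_tracks, nch, T); overlapping windows are summed and divided by
# target_length / hop_length to stitch the chunks back together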
look2hear/utils/stft.py
ADDED
@@ -0,0 +1,797 @@
# Copyright 2019 Jian Wu
# License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as tf
import librosa.filters as filters

from typing import Optional, Tuple
from distutils.version import LooseVersion

EPSILON = float(np.finfo(np.float32).eps)
TORCH_VERSION = th.__version__

if TORCH_VERSION >= LooseVersion("1.7"):
    from torch.fft import fft as fft_func
else:
    pass


def export_jit(transform: nn.Module) -> nn.Module:
    """
    Export transform module for inference
    """
    export_out = [module for module in transform if module.exportable()]
    return nn.Sequential(*export_out)


def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
    """
    Return window coefficient
    Args:
        wnd: window name
        frame_len: length of the frame
    """

    def sqrthann(frame_len, periodic=True):
        return th.hann_window(frame_len, periodic=periodic) ** 0.5

    if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]:
        raise RuntimeError(f"Unknown window type: {wnd}")

    wnd_tpl = {
        "sqrthann": sqrthann,
        "hann": th.hann_window,
        "hamm": th.hamming_window,
        "blackman": th.blackman_window,
        "bartlett": th.bartlett_window,
        "rect": th.ones,
    }
    if wnd != "rect":
        # match with librosa
        c = wnd_tpl[wnd](frame_len, periodic=True)
    else:
        c = wnd_tpl[wnd](frame_len)
    return c.to(device)


def init_kernel(
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    round_pow_of_two: bool = True,
    normalized: bool = False,
    inverse: bool = False,
    mode: str = "librosa",
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Return STFT kernels
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window tensor
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        normalized: return normalized DFT matrix
        inverse: return iDFT matrix
        mode: framing mode (librosa or kaldi)
    """
    if mode not in ["librosa", "kaldi"]:
        raise ValueError(f"Unsupported mode: {mode}")
    # FFT size: B
    if round_pow_of_two or mode == "kaldi":
        fft_size = 2 ** math.ceil(math.log2(frame_len))
    else:
        fft_size = frame_len
    # center padding window if needed
    if mode == "librosa" and fft_size != frame_len:
        lpad = (fft_size - frame_len) // 2
        window = tf.pad(window, (lpad, fft_size - frame_len - lpad))
    if normalized:
        # make K^H * K = I
        S = fft_size ** 0.5
    else:
        S = 1
    # W x B x 2
    if TORCH_VERSION >= LooseVersion("1.7"):
        K = fft_func(th.eye(fft_size) / S, dim=-1)
        K = th.stack([K.real, K.imag], dim=-1)
    else:
        I = th.stack([th.eye(fft_size), th.zeros(fft_size, fft_size)], dim=-1)
        K = th.fft(I / S, 1)
    if mode == "kaldi":
        K = K[:frame_len]
    if inverse and not normalized:
        # to make K^H * K = I
        K = K / fft_size
    # 2 x B x W
    K = th.transpose(K, 0, 2)
    # 2B x 1 x W
    K = th.reshape(K, (fft_size * 2, 1, K.shape[-1]))
    return K.to(window.device), window


def mel_filter(
    frame_len: int,
    round_pow_of_two: bool = True,
    num_bins: Optional[int] = None,
    sr: int = 16000,
    num_mels: int = 80,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
    norm: bool = False,
) -> th.Tensor:
    """
    Return mel filter coefficients
    Args:
        frame_len: length of the frame
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        num_bins: number of the frequency bins produced by STFT
        num_mels: number of the mel bands
        fmin: lowest frequency (in Hz)
        fmax: highest frequency (in Hz)
        norm: normalize the mel filter coefficients
    """
    # FFT points
    if num_bins is None:
        N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
    else:
        N = (num_bins - 1) * 2
    # fmin & fmax
    freq_upper = sr // 2
    if fmax is None:
        fmax = freq_upper
    else:
        fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
    fmin = max(0, fmin)
    # mel filter coefficients
    mel = filters.mel(
        sr,
        N,
        n_mels=num_mels,
        fmax=fmax,
        fmin=fmin,
        htk=True,
        norm="slaney" if norm else None,
    )
    # num_mels x (N // 2 + 1)
    return th.tensor(mel, dtype=th.float32)


def speed_perturb_filter(
    src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
) -> th.Tensor:
    """
    Return speed perturb filters, reference:
        https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
    Args:
        src_sr: sample rate of the source signal
        dst_sr: sample rate of the target signal
    Return:
        weight (Tensor): coefficients of the filter
    """
    if src_sr == dst_sr:
        raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
    gcd = math.gcd(src_sr, dst_sr)
    src_sr = src_sr // gcd
    dst_sr = dst_sr // gcd
    if src_sr == 1 or dst_sr == 1:
        raise ValueError("do not support integer downsample/upsample")
    zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
    padding = 1 + int(num_zeros / zeros_per_block)
    # dst_sr x src_sr x K
    times = (
        np.arange(dst_sr)[:, None, None] / float(dst_sr)
        - np.arange(src_sr)[None, :, None] / float(src_sr)
        - np.arange(2 * padding + 1)[None, None, :]
        + padding
    )
    window = np.heaviside(1 - np.abs(times / padding), 0.0) * (
        0.5 + 0.5 * np.cos(times / padding * math.pi)
    )
    weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
    return th.tensor(weight, dtype=th.float32)


def splice_feature(
    feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
) -> th.Tensor:
    """
    Splice feature
    Args:
        feats (Tensor): N x ... x T x F, original feature
        lctx: left context
        rctx: right context
        op: operator on feature context
    Return:
        splice (Tensor): feature with context padded
    """
    if lctx + rctx == 0:
        return feats
    if op not in ["cat", "stack"]:
        raise ValueError(f"Unknown op for feature splicing: {op}")
    # [N x ... x T x F, ...]
    ctx = []
    T = feats.shape[-2]
    for c in range(-lctx, rctx + 1):
        idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
        idx = th.clamp(idx, min=0, max=T - 1)
        ctx.append(th.index_select(feats, -2, idx))
    if op == "cat":
        # N x ... x T x FD
        splice = th.cat(ctx, -1)
    else:
        # N x ... x T x F x D
        splice = th.stack(ctx, -1)
    return splice


def _forward_stft(
    wav: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT function implemented by conv1d (not efficient, but we don't care during training)
    Args:
        wav (Tensor): N x (C) x S
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        pre_emphasis: factor of preemphasis
        frame_hop: frame hop size in number samples
        onesided: return half FFT bins
        center: if true, we assumed to have centered frames
    Return:
        transform (Tensor): STFT transform results
    """
    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
    # if N x S, reshape N x 1 x S
    # else: reshape NC x 1 x S
    N, S = wav.shape[0], wav.shape[-1]
    wav = wav.view(-1, 1, S)
    # NC x 1 x S+2P
    if center:
        pad = kernel.shape[-1] // 2
        # NOTE: match with librosa
        wav = tf.pad(wav, (pad, pad), mode="reflect")
    # STFT
    kernel = kernel * window
    if pre_emphasis > 0:
        # NC x W x T
        frames = tf.unfold(
            wav[:, None], (1, kernel.shape[-1]), stride=frame_hop, padding=0
        )
        # follow Kaldi's Preemphasize
        frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
        frames[:, 0] *= 1 - pre_emphasis
        # 1 x 2B x W, NC x W x T, NC x 2B x T
        packed = th.matmul(kernel[:, 0][None, ...], frames)
    else:
        packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
    # NC x 2B x T => N x C x 2B x T
    if wav_dim == 3:
        packed = packed.view(N, -1, packed.shape[-2], packed.shape[-1])
    # N x (C) x B x T
    real, imag = th.chunk(packed, 2, dim=-2)
    # N x (C) x B/2+1 x T
    if onesided:
        num_bins = kernel.shape[0] // 4 + 1
        real = real[..., :num_bins, :]
        imag = imag[..., :num_bins, :]
    if return_polar:
        mag = (real ** 2 + imag ** 2 + eps) ** 0.5
        pha = th.atan2(imag, real)
        return th.stack([mag, pha], dim=-1)
    else:
        return th.stack([real, imag], dim=-1)


def _inverse_stft(
    transform: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implemented by conv1d
    Args:
        transform (Tensor): STFT transform results
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        return_polar (bool): keep same with the one in _forward_stft
        frame_hop: frame hop size in number samples
        onesided: return half FFT bins
        center: used in _forward_stft
    Return:
        wav (Tensor), N x S
    """
    # (N) x F x T x 2
    transform_dim = transform.dim()
    # if F x T x 2, reshape 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    if transform_dim != 4:
        raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")

    if return_polar:
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
    else:
        real, imag = transform[..., 0], transform[..., 1]

    if onesided:
        # [self.num_bins - 2, ..., 1]
        reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
        # extend matrix: N x B x T
        real = th.cat([real, real[:, reverse]], 1)
        imag = th.cat([imag, -imag[:, reverse]], 1)
    # pack: N x 2B x T
    packed = th.cat([real, imag], dim=1)
    # N x 1 x T
    wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)
    # normalized audio samples
    # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
    num_frames = packed.shape[-1]
    win_length = window.shape[0]
    # W x T
    win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
    # Do OLA on windows
    # v1)
    I = th.eye(win_length, device=win.device)[:, None]
    denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)
    # v2)
    # num_samples = (num_frames - 1) * frame_hop + win_length
    # denorm = tf.fold(win[None, ...], (num_samples, 1), (win_length, 1),
    #                  stride=frame_hop)[..., 0]
    if center:
        pad = kernel.shape[-1] // 2
        wav = wav[..., pad:-pad]
        denorm = denorm[..., pad:-pad]
    wav = wav / (denorm + eps)
    # N x S
    return wav.squeeze(1)


def _pytorch_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    n_fft: int = 512,
    return_polar: bool = False,
    window: str = "sqrthann",
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of PyTorch STFT function
    Args:
        wav (Tensor): source audio signal
        frame_len: length of the frame
        frame_hop: hop size between frames
        n_fft: number of the FFT size
        return_polar: return the results in polar coordinate
        window: window tensor
        center: same definition with the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
    Return:
        transform (Tensor), STFT transform results
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
    # if N x C x S, reshape NC x S
    wav = wav.view(-1, wav.shape[-1])
    # STFT: N x F x T x 2
    stft = th.stft(
        wav,
        n_fft,
        hop_length=frame_hop,
        win_length=window.shape[-1],
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        return_complex=False,
    )
    if wav_dim == 3:
        N, F, T, _ = stft.shape
        stft = stft.view(N, -1, F, T, 2)
    # N x (C) x F x T x 2
    if not return_polar:
        return stft
    # N x (C) x F x T
    real, imag = stft[..., 0], stft[..., 1]
    mag = (real ** 2 + imag ** 2 + eps) ** 0.5
    pha = th.atan2(imag, real)
    return th.stack([mag, pha], dim=-1)


def _pytorch_istft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    n_fft: int = 512,
    return_polar: bool = False,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of PyTorch iSTFT function
    Args:
        transform (Tensor): results of STFT
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window tensor
        n_fft: number of the FFT size
        return_polar: keep same with _pytorch_stft
        center: same definition with the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
    Return:
        wav (Tensor): synthetic audio
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")

    transform_dim = transform.dim()
    # if F x T x 2, reshape 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    if transform_dim != 4:
        raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")

    if return_polar:
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
        transform = th.stack([real, imag], -1)
    # stft is a complex tensor of PyTorch
    stft = th.view_as_complex(transform)
    # (N) x S
    wav = th.istft(
        stft,
        n_fft,
        hop_length=frame_hop,
        win_length=window.shape[-1],
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        return_complex=False,
    )
    return wav


def forward_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT function implementation, equals to STFT layer
    Args:
        wav: source audio signal
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        pre_emphasis: factor of preemphasis
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        inverse: using iDFT kernel (for iSTFT)
        mode: STFT mode, "kaldi" or "librosa" or "torch"
    Return:
        transform: results of STFT
    """
    window = init_window(window, frame_len, device=wav.device)
    if mode == "torch":
        n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
        return _pytorch_stft(
            wav,
            frame_len,
            frame_hop,
            n_fft=n_fft,
            return_polar=return_polar,
            window=window,
            normalized=normalized,
            onesided=onesided,
            center=center,
            eps=eps,
        )
    else:
        kernel, window = init_kernel(
            frame_len,
            frame_hop,
            window=window,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=False,
            mode=mode,
        )
        return _forward_stft(
            wav,
            kernel,
            window,
            return_polar=return_polar,
            frame_hop=frame_hop,
            pre_emphasis=pre_emphasis,
            onesided=onesided,
            center=center,
            eps=eps,
        )


def inverse_stft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    return_polar: bool = False,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implementation, equals to iSTFT layer
    Args:
        transform: results of STFT
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: keep same with function forward_stft(...)
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        mode: STFT mode, "kaldi" or "librosa" or "torch"
    Return:
        wav: synthetic signals
    """
    window = init_window(window, frame_len, device=transform.device)
    if mode == "torch":
        n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
        return _pytorch_istft(
            transform,
            frame_len,
            frame_hop,
            n_fft=n_fft,
            return_polar=return_polar,
            window=window,
            normalized=normalized,
            onesided=onesided,
            center=center,
            eps=eps,
        )
    else:
        kernel, window = init_kernel(
            frame_len,
            frame_hop,
            window,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=True,
            mode=mode,
        )
        return _inverse_stft(
            transform,
            kernel,
            window,
            return_polar=return_polar,
            frame_hop=frame_hop,
            onesided=onesided,
            center=center,
            eps=eps,
        )


class STFTBase(nn.Module):
    """
    Base layer for (i)STFT
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        normalized: use normalized DFT kernel
        pre_emphasis: factor of preemphasis
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        onesided: output onesided STFT
        inverse: using iDFT kernel (for iSTFT)
    """

    def __init__(
        self,
        frame_len: int,
        frame_hop: int,
        window: str = "sqrthann",
        round_pow_of_two: bool = True,
        normalized: bool = False,
        pre_emphasis: float = 0,
        onesided: bool = True,
        inverse: bool = False,
        center: bool = False,
        mode: str = "librosa",
    ) -> None:
        super(STFTBase, self).__init__()
        if mode != "torch":
            K, w = init_kernel(
                frame_len,
                frame_hop,
                init_window(window, frame_len),
                round_pow_of_two=round_pow_of_two,
                normalized=normalized,
                inverse=inverse,
                mode=mode,
            )
            self.K = nn.Parameter(K, requires_grad=False)
            self.w = nn.Parameter(w, requires_grad=False)
            self.num_bins = self.K.shape[0] // 4 + 1
            self.pre_emphasis = pre_emphasis
            self.win_length = self.K.shape[2]
        else:
            self.K = None
            w = init_window(window, frame_len)
            self.w = nn.Parameter(w, requires_grad=False)
            fft_size = (
                2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
            )
            self.num_bins = fft_size // 2 + 1
            self.pre_emphasis = 0
            self.win_length = fft_size
        self.frame_len = frame_len
        self.frame_hop = frame_hop
        self.window = window
        self.normalized = normalized
        self.onesided = onesided
        self.center = center
        self.mode = mode

    def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
        """
        Compute number of the frames
        """
        assert th.sum(wav_len <= self.win_length) == 0
        if self.center:
            wav_len += self.win_length
        return (
            th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
        )

    def extra_repr(self) -> str:
        str_repr = (
            f"num_bins={self.num_bins}, win_length={self.win_length}, "
            + f"stride={self.frame_hop}, window={self.window}, "
            + f"center={self.center}, mode={self.mode}"
        )
        if not self.onesided:
            str_repr += f", onesided={self.onesided}"
        if self.pre_emphasis > 0:
            str_repr += f", pre_emphasis={self.pre_emphasis}"
        if self.normalized:
            str_repr += f", normalized={self.normalized}"
        return str_repr


class STFT(STFTBase):
    """
    Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        super(STFT, self).__init__(*args, inverse=False, **kwargs)

    def forward(
        self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept (single or multiple channel) raw waveform and output magnitude and phase
        Args
            wav (Tensor) input signal, N x (C) x S
        Return
            transform (Tensor), N x (C) x F x T x 2
        """
        if self.mode == "torch":
            return _pytorch_stft(
                wav,
                self.frame_len,
                self.frame_hop,
                n_fft=(self.num_bins - 1) * 2,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _forward_stft(
                wav,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                pre_emphasis=self.pre_emphasis,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )


class iSTFT(STFTBase):
    """
    Inverse Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        super(iSTFT, self).__init__(*args, inverse=True, **kwargs)

    def forward(
        self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept phase & magnitude and output raw waveform
        Args
            transform (Tensor): STFT output, N x F x T x 2
        Return
            s (Tensor): N x S
        """
        if self.mode == "torch":
            return _pytorch_istft(
                transform,
                self.frame_len,
                self.frame_hop,
                n_fft=(self.num_bins - 1) * 2,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _inverse_stft(
                transform,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
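A short analysis/synthesis round-trip sketch of the STFT/iSTFT layers above (window and frame parameters are illustrative, not the values used by the Apollo config):

import torch as th
from look2hear.utils.stft import STFT, iSTFT

stft = STFT(frame_len=1024, frame_hop=256, window="hann", center=True)
istft = iSTFT(frame_len=1024, frame_hop=256, window="hann", center=True)

wav = th.randn(2, 16000)               # N x S batch of waveforms
spec = stft(wav, return_polar=False)   # N x F x T x 2 (real/imag)
rec = istft(spec, return_polar=False)  # N x S' reconstruction via overlap-add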
look2hear/utils/torch_utils.py
ADDED
@@ -0,0 +1,49 @@
###
# Author: Kai Li
# Date: 2021-06-18 17:29:21
# LastEditors: Kai Li
# LastEditTime: 2021-06-21 23:52:52
###

import torch
import torch.nn as nn


def pad_x_to_y(x, y, axis: int = -1):
    if axis != -1:
        raise NotImplementedError
    inp_len = y.shape[axis]
    output_len = x.shape[axis]
    return nn.functional.pad(x, [0, inp_len - output_len])


def shape_reconstructed(reconstructed, size):
    if len(size) == 1:
        return reconstructed.squeeze(0)
    return reconstructed


def tensors_to_device(tensors, device):
    """Transfer tensor, dict or list of tensors to device.

    Args:
        tensors (:class:`torch.Tensor`): May be a single, a list or a
            dictionary of tensors.
        device (:class: `torch.device`): the device where to place the tensors.

    Returns:
        Union [:class:`torch.Tensor`, list, tuple, dict]:
            Same as input but transferred to device.
            Goes through lists and dicts and transfers the torch.Tensor to
            device. Leaves the rest untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    elif isinstance(tensors, (list, tuple)):
        return [tensors_to_device(tens, device) for tens in tensors]
    elif isinstance(tensors, dict):
        for key in tensors.keys():
            tensors[key] = tensors_to_device(tensors[key], device)
        return tensors
    else:
        return tensors
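A brief sketch of the two most commonly used helpers above (shapes and device string are arbitrary):

import torch
from look2hear.utils.torch_utils import pad_x_to_y, tensors_to_device

est = torch.randn(1, 15872)   # model output, slightly shorter than the reference
ref = torch.randn(1, 16000)
est = pad_x_to_y(est, ref)    # zero-pad est along the last axis to match ref

batch = {"mixture": torch.randn(1, 16000), "lengths": [16000]}
batch = tensors_to_device(batch, "cpu")  # moves only the tensors, leaves the list alone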
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torchaudio==2.2.0
torch==2.2.0
huggingface
huggingface_hub
numpy<2.0
omegaconf
ml_collections
librosa
gradio
tqdm
spaces
weights/apollo.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99d9af7f1ff20e63c393035513a655392818d66b4d7fc23d658175c1f15e8d76
size 66541845