Serhiy Stetskovych committed
Commit 78e32cc
0 Parent(s)

Initial code

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ *.pyc
2
+ __pycache__
3
+ .venv
4
+ .DS_Store
README.md ADDED
@@ -0,0 +1,10 @@
1
+ ---
2
+ title: Apollo
3
+ emoji: 💻
4
+ colorFrom: green
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.5.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
app.py ADDED
@@ -0,0 +1,142 @@
1
+ import os
2
+ import torchaudio
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import yaml
7
+ import librosa
8
+ import spaces
+ from tqdm.auto import tqdm
9
+
10
+ import look2hear.models
11
+ from ml_collections import ConfigDict
12
+
13
+ def load_audio(file_path):
14
+ audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
15
+ print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
16
+ #audio = dBgain(audio, -6)
17
+ return torch.from_numpy(audio), samplerate
18
+
19
+
20
+ def get_config(config_path):
21
+ with open(config_path) as f:
22
+ #config = OmegaConf.load(config_path)
23
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
24
+ return config
25
+
26
+
27
+
28
+
29
+ def _getWindowingArray(window_size, fade_size):
30
+ # IMPORTANT NOTE:
31
+ # no real fade-in/fade-out; the window only drops the unreliable tail of each chunk
32
+ fadein = torch.linspace(1, 1, fade_size)
33
+ fadeout = torch.linspace(0, 0, fade_size)
34
+ window = torch.ones(window_size)
35
+ window[-fade_size:] *= fadeout
36
+ window[:fade_size] *= fadein
37
+ return window
38
+
39
+
40
+
41
+ description = f'''
42
+ texts
43
+ '''
44
+
45
+
46
+ apollo_config = get_config('configs/apollo.yaml')
47
+ apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).cuda()
48
+
49
+ models = [
50
+ ('MP3 restore', apollo_model)
51
+ ]
52
+
53
+ @spaces.GPU
54
+ def enhance(model, audio):
55
+ test_data, samplerate = load_audio(audio)
56
+ C = 10 * samplerate # chunk_size seconds to samples
57
+ N = 2
58
+ step = C // N
59
+ fade_size = 3 * 44100 # 3 seconds
60
+ print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
61
+
62
+ border = C - step
63
+
64
+ # handle mono inputs correctly
65
+ if len(test_data.shape) == 1:
66
+ test_data = test_data.unsqueeze(0)
67
+
68
+ # Pad the input if necessary
69
+ if test_data.shape[1] > 2 * border and (border > 0):
70
+ test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
71
+
72
+ windowingArray = _getWindowingArray(C, fade_size)
73
+
74
+ result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
75
+ counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
76
+
77
+ i = 0
78
+ progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)
79
+
80
+ while i < test_data.shape[1]:
81
+ part = test_data[:, i:i + C]
82
+ length = part.shape[-1]
83
+ if length < C:
84
+ if length > C // 2 + 1:
85
+ part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
86
+ else:
87
+ part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
88
+
89
+
90
+ chunk = part.unsqueeze(0).cuda()
91
+ with torch.no_grad():
92
+ out = model(chunk).squeeze(0).squeeze(0).cpu()
93
+
94
+ window = windowingArray
95
+ if i == 0: # First audio chunk, no fadein
96
+ window[:fade_size] = 1
97
+ elif i + C >= test_data.shape[1]: # Last audio chunk, no fadeout
98
+ window[-fade_size:] = 1
99
+
100
+ result[..., i:i+length] += out[..., :length] * window[..., :length]
101
+ counter[..., i:i+length] += window[..., :length]
102
+
103
+ i += step
104
+ progress_bar.update(step)
105
+
106
+ progress_bar.close()
107
+
108
+ final_output = result / counter
109
+ final_output = final_output.squeeze(0).numpy()
110
+ np.nan_to_num(final_output, copy=False, nan=0.0)
111
+
112
+ # Remove padding if added earlier
113
+ if test_data.shape[1] > 2 * border and (border > 0):
114
+ final_output = final_output[..., border:-border]
115
+
116
+ return samplerate, final_output.T
117
+
118
+
119
+ if __name__ == "__main__":
120
+ i = gr.Interface(
121
+ fn=enhance,
122
+ description=description,
123
+ inputs=[
124
+ gr.Dropdown(label="Model", choices=models, value=models[0]),
125
+ gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'}),
126
+ ],
127
+ outputs=[
128
+ gr.Audio(
129
+ label="Output Audio",
130
+ autoplay=False,
131
+ streaming=False,
132
+ type="numpy",
133
+ ),
134
+
135
+ ],
136
+ allow_flagging='never',
137
+ cache_examples=False,
138
+ title='Enhancer',
139
+
140
+ )
141
+ i.queue(max_size=20, default_concurrency_limit=4)
142
+ i.launch(share=False, server_name="0.0.0.0")
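
Note on the chunked overlap-add in enhance(): with the app's defaults (10 s chunks at 44.1 kHz, N = 2, a 3 s window tail), the bookkeeping works out to C = 441000, step = 220500, border = 220500. Below is a minimal sketch of that loop with an identity model in place of Apollo, showing that the result / counter normalization reconstructs the input exactly; constants mirror the code above, the 30 s test signal is made up.

import torch

samplerate = 44100
C = 10 * samplerate                 # 441000 samples per chunk
N = 2
step = C // N                       # 220500-sample hop
fade_size = 3 * 44100               # 132300 samples dropped from each chunk tail
border = C - step                   # 220500 samples of reflect padding

x = torch.randn(2, 30 * samplerate)                       # 30 s stereo test signal
x_pad = torch.nn.functional.pad(x, (border, border), mode='reflect')

window = torch.ones(C)
window[-fade_size:] = 0             # same effect as _getWindowingArray()

result = torch.zeros_like(x_pad)
counter = torch.zeros_like(x_pad)
i = 0
while i < x_pad.shape[1]:
    part = x_pad[:, i:i + C]
    length = part.shape[-1]
    w = window.clone()
    if i == 0:                      # first chunk: keep its head
        w[:fade_size] = 1
    if i + C >= x_pad.shape[1]:     # last chunk(s): keep their tail
        w[-fade_size:] = 1
    result[:, i:i + length] += part * w[:length]
    counter[:, i:i + length] += w[:length]
    i += step

recon = (result / counter)[:, border:-border]
print(torch.allclose(recon, x))     # True: the weighting/normalization is lossless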
configs/apollo.yaml ADDED
@@ -0,0 +1,106 @@
1
+ exp:
2
+ dir: ./Exps
3
+ name: Apollo
4
+
5
+ # seed: 614020
6
+
7
+ datas:
8
+ _target_: look2hear.datas.MusdbMoisesdbDataModule
9
+ train_dir: ./hdf5_datas
10
+ eval_dir: ./eval
11
+ codec_type: mp3
12
+ codec_options:
13
+ bitrate: random
14
+ compression: random
15
+ complexity: random
16
+ vbr: random
17
+ sr: 44100
18
+ segments: 3
19
+ num_stems: 8
20
+ snr_range: [-10, 10]
21
+ num_samples: 40000
22
+ batch_size: 1
23
+ num_workers: 8
24
+
25
+ model:
26
+
27
+ sr: 44100
28
+ win: 20 # ms
29
+ feature_dim: 256
30
+ layer: 6
31
+
32
+ discriminator:
33
+ _target_: look2hear.discriminators.frequencydis.MultiFrequencyDiscriminator
34
+ nch: 2
35
+ window: [32, 64, 128, 256, 512, 1024, 2048]
36
+
37
+ optimizer_g:
38
+ _target_: torch.optim.AdamW
39
+ lr: 0.001
40
+ weight_decay: 0.01
41
+
42
+ optimizer_d:
43
+ _target_: torch.optim.AdamW
44
+ lr: 0.0001
45
+ weight_decay: 0.01
46
+ betas: [0.5, 0.99]
47
+
48
+ scheduler_g:
49
+ _target_: torch.optim.lr_scheduler.StepLR
50
+ step_size: 2
51
+ gamma: 0.98
52
+
53
+ scheduler_d:
54
+ _target_: torch.optim.lr_scheduler.StepLR
55
+ step_size: 2
56
+ gamma: 0.98
57
+
58
+ loss_g:
59
+ _target_: look2hear.losses.gan_losses.MultiFrequencyGenLoss
60
+ eps: 1e-8
61
+
62
+ loss_d:
63
+ _target_: look2hear.losses.gan_losses.MultiFrequencyDisLoss
64
+ eps: 1e-8
65
+
66
+ metrics:
67
+ _target_: look2hear.losses.MultiSrcNegSDR
68
+ sdr_type: sisdr
69
+
70
+ system:
71
+ _target_: look2hear.system.audio_litmodule.AudioLightningModule
72
+
73
+ early_stopping:
74
+ _target_: pytorch_lightning.callbacks.EarlyStopping
75
+ monitor: val_loss
76
+ patience: 20
77
+ mode: min
78
+ verbose: true
79
+
80
+ checkpoint:
81
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
82
+ dirpath: ${exp.dir}/${exp.name}/checkpoints
83
+ monitor: val_loss
84
+ mode: min
85
+ verbose: true
86
+ save_top_k: 5
87
+ save_last: true
88
+ filename: '{epoch}-{val_loss:.4f}'
89
+
90
+ logger:
91
+ _target_: pytorch_lightning.loggers.WandbLogger
92
+ name: ${exp.name}
93
+ save_dir: ${exp.dir}/${exp.name}/logs
94
+ offline: false
95
+ project: Audio-Restoration
96
+
97
+ trainer:
98
+ _target_: pytorch_lightning.Trainer
99
+ devices: [0,1,2,3,4,5,6,7]
100
+ max_epochs: 500
101
+ sync_batchnorm: true
102
+ default_root_dir: ${exp.dir}/${exp.name}/
103
+ accelerator: cuda
104
+ limit_train_batches: 1.0
105
+ fast_dev_run: false
106
+
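
The `model` block above (sr, win in ms, feature_dim, layer) is exactly what app.py and inference.py forward to the Apollo constructor; the other sections (`datas`, `discriminator`, optimizers, `trainer`) carry Hydra-style `_target_` paths and are only needed for training. A minimal sketch of the inference-side usage, assuming this file lives at configs/apollo.yaml and the look2hear package from this commit is importable:

import yaml
from ml_collections import ConfigDict

import look2hear.models

with open("configs/apollo.yaml") as f:
    config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

m = config["model"]                 # sr=44100, win=20 (ms), feature_dim=256, layer=6
model = look2hear.models.Apollo(
    sr=m["sr"], win=m["win"], feature_dim=m["feature_dim"], layer=m["layer"]
)
print(sum(p.numel() for p in model.parameters()), "parameters")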
inference.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import look2hear.models
5
+ import soundfile as sf
6
+ from tqdm.auto import tqdm
7
+ import argparse
8
+ import numpy as np
9
+ import yaml
10
+ from ml_collections import ConfigDict
11
+ #from omegaconf import OmegaConf
12
+
13
+ import warnings
14
+ warnings.filterwarnings("ignore")
15
+
16
+ def get_config(config_path):
17
+ with open(config_path) as f:
18
+ #config = OmegaConf.load(config_path)
19
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
20
+ return config
21
+
22
+ def load_audio(file_path):
23
+ audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
24
+ print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
25
+ #audio = dBgain(audio, -6)
26
+ return torch.from_numpy(audio), samplerate
27
+
28
+ def save_audio(file_path, audio, samplerate=44100):
29
+ #audio = dBgain(audio, +6)
30
+ sf.write(file_path, audio.T, samplerate, subtype="PCM_16")
31
+
32
+ def process_chunk(chunk):
33
+ chunk = chunk.unsqueeze(0).cpu()
34
+ with torch.no_grad():
35
+ return model(chunk).squeeze(0).squeeze(0).cpu()
36
+
37
+ def _getWindowingArray(window_size, fade_size):
38
+ # IMPORTANT NOTE :
39
+ # no fades here in the end, only removing the failed ending of the chunk
40
+ fadein = torch.linspace(1, 1, fade_size)
41
+ fadeout = torch.linspace(0, 0, fade_size)
42
+ window = torch.ones(window_size)
43
+ window[-fade_size:] *= fadeout
44
+ window[:fade_size] *= fadein
45
+ return window
46
+
47
+ def dBgain(audio, volume_gain_dB):
48
+ gain = 10 ** (volume_gain_dB / 20)
49
+ gained_audio = audio * gain
50
+ return gained_audio
51
+
52
+
53
+ def main(input_wav, output_wav, ckpt_path):
54
+ os.environ['CUDA_VISIBLE_DEVICES'] = "0"
55
+
56
+ global model
57
+ feature_dim = config['model']['feature_dim']
58
+ sr = config['model']['sr']
59
+ win = config['model']['win']
60
+ layer = config['model']['layer']
61
+ model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cpu()
62
+
63
+ test_data, samplerate = load_audio(input_wav)
64
+
65
+ C = chunk_size * samplerate # chunk_size seconds to samples
66
+ N = overlap
67
+ step = C // N
68
+ fade_size = 3 * 44100 # 3 seconds
69
+ print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
70
+
71
+ border = C - step
72
+
73
+ # handle mono inputs correctly
74
+ if len(test_data.shape) == 1:
75
+ test_data = test_data.unsqueeze(0)
76
+
77
+ # Pad the input if necessary
78
+ if test_data.shape[1] > 2 * border and (border > 0):
79
+ test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
80
+
81
+ windowingArray = _getWindowingArray(C, fade_size)
82
+
83
+ result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
84
+ counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
85
+
86
+ i = 0
87
+ progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)
88
+
89
+ while i < test_data.shape[1]:
90
+ part = test_data[:, i:i + C]
91
+ length = part.shape[-1]
92
+ if length < C:
93
+ if length > C // 2 + 1:
94
+ part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
95
+ else:
96
+ part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
97
+
98
+ out = process_chunk(part)
99
+
100
+ window = windowingArray
101
+ if i == 0: # First audio chunk, no fadein
102
+ window[:fade_size] = 1
103
+ elif i + C >= test_data.shape[1]: # Last audio chunk, no fadeout
104
+ window[-fade_size:] = 1
105
+
106
+ result[..., i:i+length] += out[..., :length] * window[..., :length]
107
+ counter[..., i:i+length] += window[..., :length]
108
+
109
+ i += step
110
+ progress_bar.update(step)
111
+
112
+ progress_bar.close()
113
+
114
+ final_output = result / counter
115
+ final_output = final_output.squeeze(0).numpy()
116
+ np.nan_to_num(final_output, copy=False, nan=0.0)
117
+
118
+ # Remove padding if added earlier
119
+ if test_data.shape[1] > 2 * border and (border > 0):
120
+ final_output = final_output[..., border:-border]
121
+
122
+ save_audio(output_wav, final_output, samplerate)
123
+ print(f'Success! Output file saved as {output_wav}')
124
+
125
+ # Memory clearing
126
+ model.cpu()
127
+ del model
128
+ torch.cuda.empty_cache()
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser(description="Audio Inference Script")
132
+ parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
133
+ parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
134
+ parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file", default="model/pytorch_model.bin")
135
+ parser.add_argument("--config", type=str, help="Path to model config file", default="configs/apollo.yaml")
136
+ parser.add_argument("--chunk_size", type=int, help="chunk size value in seconds", default=10)
137
+ parser.add_argument("--overlap", type=int, help="Overlap", default=2)
138
+ args = parser.parse_args()
139
+
140
+ ckpt_path = args.ckpt
141
+ chunk_size = args.chunk_size
142
+ overlap = args.overlap
143
+ config = get_config(args.config)
144
+ print(config['model'])
145
+ print(f'ckpt_path = {ckpt_path}')
146
+ #print(f'config = {config}')
147
+ print(f'chunk_size = {chunk_size}, overlap = {overlap}')
148
+
149
+
150
+ main(args.in_wav, args.out_wav, ckpt_path)
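
Example invocation with the flags defined above; the wav paths are placeholders, and weights/apollo.bin is the checkpoint path app.py uses (the file itself is not part of this commit):

import subprocess

subprocess.run([
    "python", "inference.py",
    "--in_wav", "input.wav",            # placeholder input path
    "--out_wav", "restored.wav",        # placeholder output path
    "--ckpt", "weights/apollo.bin",     # checkpoint path referenced by app.py
    "--config", "configs/apollo.yaml",
    "--chunk_size", "10",               # seconds per chunk
    "--overlap", "2",                   # N in the overlap-add loop (step = C // N)
], check=True)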
look2hear/__init__.py ADDED
File without changes
look2hear/datas/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-03 18:29:46
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2022-07-29 06:23:03
6
+ ###
7
+ from .musdb_moisesdb_datamodule import MusdbMoisesdbDataModule
8
+
9
+ __all__ = [
10
+ "MusdbMoisesdbDataModule"
11
+ ]
look2hear/datas/musdb_moisesdb_datamodule.py ADDED
@@ -0,0 +1,215 @@
1
+ import os
2
+ import h5py
3
+ import numpy as np
4
+ from typing import Any, Tuple
5
+ import torch
6
+ import random
7
+ from pytorch_lightning import LightningDataModule
8
+ import torchaudio
9
+ from torchaudio.functional import apply_codec
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from typing import Any, Dict, Optional, Tuple
12
+
13
+ def compute_mch_rms_dB(mch_wav, fs=16000, energy_thresh=-50):
14
+ """Return the wav RMS calculated only in the active portions"""
15
+ mean_square = max(1e-20, torch.mean(mch_wav ** 2))
16
+ return 10 * np.log10(mean_square)
17
+
18
+ def match2(x, d):
19
+ assert x.dim()==2, x.shape
20
+ assert d.dim()==2, d.shape
21
+ minlen = min(x.shape[-1], d.shape[-1])
22
+ x, d = x[:,0:minlen], d[:,0:minlen]
23
+ Fx = torch.fft.rfft(x, dim=-1)
24
+ Fd = torch.fft.rfft(d, dim=-1)
25
+ Phi = Fd*Fx.conj()
26
+ Phi = Phi / (Phi.abs() + 1e-3)
27
+ Phi[:,0] = 0
28
+ tmp = torch.fft.irfft(Phi, dim=-1)
29
+ tau = torch.argmax(tmp.abs(),dim=-1).tolist()
30
+ return tau
31
+
32
+ def codec_simu(wav, sr=16000, options={'bitrate':'random','compression':'random', 'complexity':'random', 'vbr':'random'}):
33
+
34
+ if options['bitrate'] == 'random':
35
+ options['bitrate'] = random.choice([24000, 32000, 48000, 64000, 96000, 128000])
36
+ compression = int(options['bitrate']//1000)
37
+ param = {'format': "mp3", "compression": compression}
38
+ wav_encdec = apply_codec(wav, sr, **param)
39
+ if wav_encdec.shape[-1] >= wav.shape[-1]:
40
+ wav_encdec = wav_encdec[...,:wav.shape[-1]]
41
+ else:
42
+ wav_encdec = torch.cat([wav_encdec, wav[..., wav_encdec.shape[-1]:]], -1)
43
+ tau = match2(wav, wav_encdec)
44
+ wav_encdec = torch.roll(wav_encdec, -tau[0], -1)
45
+
46
+ return wav_encdec
47
+
48
+ def get_wav_files(root_dir):
49
+ wav_files = []
50
+ for dirpath, dirnames, filenames in os.walk(root_dir):
51
+ for filename in filenames:
52
+ if filename.endswith('.wav'):
53
+ if "musdb18hq" in dirpath and "mixture" not in filename:
54
+ wav_files.append(os.path.join(dirpath, filename))
55
+ elif "moisesdb" in dirpath:
56
+ wav_files.append(os.path.join(dirpath, filename))
57
+ return wav_files
58
+
59
+ class MusdbMoisesdbDataset(Dataset):
60
+ def __init__(
61
+ self,
62
+ data_dir: str,
63
+ codec_type: str,
64
+ codec_options: dict,
65
+ sr: int = 16000,
66
+ segments: int = 10,
67
+ num_stems: int = 4,
68
+ snr_range: Tuple[int, int] = (-10, 10),
69
+ num_samples: int = 1000,
70
+ ) -> None:
71
+
72
+ self.data_dir = data_dir
73
+ self.codec_type = codec_type
74
+ self.codec_options = codec_options
75
+ self.segments = int(segments * sr)
76
+ self.sr = sr
77
+ self.num_stems = num_stems
78
+ self.snr_range = snr_range
79
+ self.num_samples = num_samples
80
+
81
+ self.instruments = [
82
+ "bass",
83
+ "bowed_strings",
84
+ "drums",
85
+ "guitar",
86
+ "other",
87
+ "other_keys",
88
+ "other_plucked",
89
+ "percussion",
90
+ "piano",
91
+ "vocals",
92
+ "wind"
93
+ ]
94
+
95
+ def __len__(self) -> int:
96
+ return self.num_samples
97
+
98
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
99
+ if random.random() > 0.5:
100
+ select_stems = random.randint(1, self.num_stems)
101
+ select_stems = random.choices(self.instruments, k=select_stems)
102
+ ori_wav = []
103
+ for stem in select_stems:
104
+ h5path = random.choice(os.listdir(os.path.join(self.data_dir, stem)))
105
+ datas = h5py.File(os.path.join(self.data_dir, stem, h5path), 'r')['data']
106
+ random_index = random.randint(0, datas.shape[0]-1)
107
+ music_wav = torch.FloatTensor(datas[random_index])
108
+ start = random.randint(0, music_wav.shape[-1] - self.segments)
109
+ music_wav = music_wav[:, start:start+self.segments]
110
+
111
+ rescale_snr = random.randint(self.snr_range[0], self.snr_range[1])
112
+ music_wav = music_wav * np.sqrt(10**(rescale_snr/10))
113
+ ori_wav.append(music_wav)
114
+ ori_wav = torch.stack(ori_wav).sum(0)
115
+ else:
116
+ h5path = random.choice(os.listdir(os.path.join(self.data_dir, "mixture")))
117
+ datas = h5py.File(os.path.join(self.data_dir, "mixture", h5path), 'r')['data']
118
+ random_index = random.randint(0, datas.shape[0]-1)
119
+ music_wav = torch.FloatTensor(datas[random_index])
120
+ start = random.randint(0, music_wav.shape[-1] - self.segments)
121
+ ori_wav = music_wav[:, start:start+self.segments]
122
+
123
+ codec_wav = codec_simu(ori_wav, sr=self.sr, options=self.codec_options)
124
+
125
+ max_scale = max(ori_wav.abs().max(), codec_wav.abs().max())
126
+
127
+ if max_scale > 0:
128
+ ori_wav = ori_wav / max_scale
129
+ codec_wav = codec_wav / max_scale
130
+
131
+ return ori_wav, codec_wav
132
+
133
+
134
+ class MusdbMoisesdbEval(Dataset):
135
+ def __init__(
136
+ self,
137
+ data_dir: str
138
+ ) -> None:
139
+ self.data_path = os.listdir(data_dir)
140
+ self.data_path = [os.path.join(data_dir, i) for i in self.data_path]
141
+
142
+ def __len__(self) -> int:
143
+ return len(self.data_path)
144
+
145
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ ori_wav = torchaudio.load(self.data_path[idx]+"/ori_wav.wav")[0]
147
+ codec_wav = torchaudio.load(self.data_path[idx]+"/codec_wav.wav")[0]
148
+
149
+ return ori_wav, codec_wav, self.data_path[idx]
150
+
151
+ class MusdbMoisesdbDataModule(LightningDataModule):
152
+ def __init__(
153
+ self,
154
+ train_dir: str,
155
+ eval_dir: str,
156
+ codec_type: str,
157
+ codec_options: dict,
158
+ sr: int = 16000,
159
+ segments: int = 10,
160
+ num_stems: int = 4,
161
+ snr_range: Tuple[int, int] = (-10, 10),
162
+ num_samples: int = 1000,
163
+ batch_size: int = 32,
164
+ num_workers: int = 4,
165
+ ) -> None:
166
+ super().__init__()
167
+ self.save_hyperparameters(logger=False)
168
+
169
+ self.data_train: Optional[Dataset] = None
170
+ self.data_val: Optional[Dataset] = None
171
+
172
+ def setup(self, stage: Optional[str] = None) -> None:
173
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
174
+
175
+ This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
176
+ `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
177
+ `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
178
+ `self.setup()` once the data is prepared and available for use.
179
+
180
+ :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
181
+ """
182
+ # load and split datasets only if not loaded already
183
+ if not self.data_train and not self.data_val:
184
+ self.data_train = MusdbMoisesdbDataset(
185
+ data_dir=self.hparams.train_dir,
186
+ codec_type=self.hparams.codec_type,
187
+ codec_options=self.hparams.codec_options,
188
+ sr=self.hparams.sr,
189
+ segments=self.hparams.segments,
190
+ num_stems=self.hparams.num_stems,
191
+ snr_range=self.hparams.snr_range,
192
+ num_samples=self.hparams.num_samples,
193
+ )
194
+
195
+ self.data_val = MusdbMoisesdbEval(
196
+ data_dir=self.hparams.eval_dir
197
+ )
198
+
199
+ def train_dataloader(self) -> DataLoader:
200
+ return DataLoader(
201
+ self.data_train,
202
+ batch_size=self.hparams.batch_size,
203
+ num_workers=self.hparams.num_workers,
204
+ shuffle=True,
205
+ pin_memory=True,
206
+ )
207
+
208
+ def val_dataloader(self) -> DataLoader:
209
+ return DataLoader(
210
+ self.data_val,
211
+ batch_size=self.hparams.batch_size,
212
+ num_workers=self.hparams.num_workers,
213
+ shuffle=False,
214
+ pin_memory=True,
215
+ )
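
A sketch of wiring this datamodule with the values from the `datas` block of configs/apollo.yaml. It assumes the ./hdf5_datas and ./eval directories exist with the expected layout (per-stem folders of HDF5 files exposing a `data` dataset, and per-track eval folders containing ori_wav.wav / codec_wav.wav), none of which is part of this commit:

from look2hear.datas import MusdbMoisesdbDataModule

dm = MusdbMoisesdbDataModule(
    train_dir="./hdf5_datas",
    eval_dir="./eval",
    codec_type="mp3",
    codec_options={"bitrate": "random", "compression": "random",
                   "complexity": "random", "vbr": "random"},
    sr=44100,
    segments=3,            # seconds per training crop
    num_stems=8,           # at most 8 stems mixed per synthetic sample
    snr_range=(-10, 10),
    num_samples=40000,     # dataset length per epoch
    batch_size=1,
    num_workers=8,
)
dm.setup("fit")
ori_wav, codec_wav = dm.data_train[0]   # clean / MP3-coded pair, (channels, segments * sr)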
look2hear/discriminators/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2022-02-12 15:16:35
4
+ # Email: [email protected]
5
+ # LastEditTime: 2022-10-04 16:24:53
6
+ ###
7
+ from .frequencydis import MultiFrequencyDiscriminator, FrequencyDiscriminator
8
+
9
+ __all__ = [
10
+ "MultiFrequencyDiscriminator",
11
+ "FrequencyDiscriminator"
12
+ ]
13
+
14
+
15
+ def register_model(custom_model):
16
+ """Register a custom model, gettable with `models.get`.
17
+
18
+ Args:
19
+ custom_model: Custom model to register.
20
+
21
+ """
22
+ if (
23
+ custom_model.__name__ in globals().keys()
24
+ or custom_model.__name__.lower() in globals().keys()
25
+ ):
26
+ raise ValueError(
27
+ f"Model {custom_model.__name__} already exists. Choose another name."
28
+ )
29
+ globals().update({custom_model.__name__: custom_model})
30
+
31
+
32
+ def get(identifier):
33
+ """Returns a model class from a string (case-insensitive).
34
+
35
+ Args:
36
+ identifier (str): the model name.
37
+
38
+ Returns:
39
+ :class:`torch.nn.Module`
40
+ """
41
+ if isinstance(identifier, str):
42
+ to_get = {k.lower(): v for k, v in globals().items()}
43
+ cls = to_get.get(identifier.lower())
44
+ if cls is None:
45
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
46
+ return cls
47
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
look2hear/discriminators/frequencydis.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ class MultiFrequencyDiscriminator(nn.Module):
6
+ def __init__(self, nch, window):
7
+ super(MultiFrequencyDiscriminator, self).__init__()
8
+
9
+ self.nch = nch
10
+ self.window = window
11
+ self.hidden_channels = 8
12
+ self.eps = torch.finfo(torch.float32).eps
13
+ self.discriminators = nn.ModuleList([FrequencyDiscriminator(2*nch, self.hidden_channels) for _ in range(len(self.window))])
14
+
15
+ def forward(self, est, sample_rate=44100):
16
+
17
+ B, nch, _ = est.shape
18
+ assert nch == self.nch
19
+
20
+ # normalize power
21
+ est = est / (est.pow(2).sum((1,2)) + self.eps).sqrt().reshape(B, 1, 1)
22
+ est = est.view(-1, est.shape[-1])
23
+
24
+ est_outputs = []
25
+ est_feature_maps = []
26
+
27
+ for i in range(len(self.discriminators)):
28
+ est_spec = torch.stft(est.float(), self.window[i], self.window[i]//2,
29
+ window=torch.hann_window(self.window[i]).to(est.device).float(),
30
+ return_complex=True)
31
+ est_RI = torch.stack([est_spec.real, est_spec.imag], dim=1)
32
+ est_RI = est_RI.view(B, nch*2, est_RI.shape[-2], est_RI.shape[-1]).type(est.type())
33
+
34
+ valid_enc = int(est_RI.shape[2] * sample_rate / 44100)
35
+ est_out, est_feat_map = self.discriminators[i](est_RI[:,:,:valid_enc].contiguous())
36
+ est_outputs.append(est_out)
37
+ est_feature_maps.append(est_feat_map)
38
+
39
+ return est_outputs, est_feature_maps
40
+
41
+
42
+ class FrequencyDiscriminator(nn.Module):
43
+ def __init__(self, in_channels, hidden_channels=512):
44
+ super(FrequencyDiscriminator, self).__init__()
45
+
46
+ self.eps = torch.finfo(torch.float32).eps
47
+ self.discriminator = nn.ModuleList()
48
+ self.discriminator += [
49
+ nn.Sequential(
50
+ nn.utils.spectral_norm(nn.Conv2d(in_channels, hidden_channels, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))),
51
+ nn.LeakyReLU(0.2, True)
52
+ ),
53
+ nn.Sequential(
54
+ nn.utils.spectral_norm(nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))),
55
+ nn.LeakyReLU(0.2, True)
56
+ ),
57
+ nn.Sequential(
58
+ nn.utils.spectral_norm(nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))),
59
+ nn.LeakyReLU(0.2, True)
60
+ ),
61
+ nn.Sequential(
62
+ nn.utils.spectral_norm(nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))),
63
+ nn.LeakyReLU(0.2, True)
64
+ ),
65
+ nn.Sequential(
66
+ nn.utils.spectral_norm(nn.Conv2d(hidden_channels*8, hidden_channels*16, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))),
67
+ nn.LeakyReLU(0.2, True)
68
+ ),
69
+ nn.Sequential(
70
+ nn.utils.spectral_norm(nn.Conv2d(hidden_channels*16, hidden_channels*32, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))),
71
+ nn.LeakyReLU(0.2, True)
72
+ ),
73
+ nn.Conv2d(hidden_channels*32, 1, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
74
+ ]
75
+
76
+ def forward(self, x):
77
+ hiddens = []
78
+ for layer in self.discriminator:
79
+ x = layer(x)
80
+ hiddens.append(x)
81
+ return x, hiddens[:-1]
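
Quick shape check for the discriminator as configured in configs/apollo.yaml (nch=2, seven STFT windows). The input below is random stereo audio, so this only exercises shapes, not training:

import torch

from look2hear.discriminators import MultiFrequencyDiscriminator

disc = MultiFrequencyDiscriminator(nch=2, window=[32, 64, 128, 256, 512, 1024, 2048])
wav = torch.randn(1, 2, 44100)                       # 1 s of random stereo "audio"
outputs, feature_maps = disc(wav, sample_rate=44100)
print(len(outputs))                                  # 7: one logit map per STFT window
print([tuple(o.shape) for o in outputs])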
look2hear/losses/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-09 16:34:19
4
+ # LastEditors: Kai Li
5
+ # LastEditTime: 2021-07-12 20:55:35
6
+ ###
7
+ from .gan_losses import MultiFrequencyDisLoss, MultiFrequencyGenLoss
8
+ from .matrix import MultiSrcNegSDR
9
+
10
+ __all__ = [
11
+ "MultiFrequencyDisLoss",
12
+ "MultiFrequencyGenLoss",
13
+ "MultiSrcNegSDR"
14
+ ]
look2hear/losses/gan_losses.py ADDED
@@ -0,0 +1,58 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-09 16:43:09
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2024-01-24 00:00:52
6
+ ###
7
+
8
+ import torch
9
+ from torch.nn.modules.loss import _Loss
10
+
11
+ def freq_MAE(output, target):
12
+ loss = 0.
13
+ eps = torch.finfo(torch.float32).eps
14
+ all_win = [32, 64, 128, 256, 512, 1024, 2048]
15
+ for win in all_win:
16
+ est_spec = torch.stft(output.view(-1, output.shape[-1]), n_fft=win, hop_length=win//2,
17
+ window=torch.hann_window(win).to(output.device).float(),
18
+ return_complex=True)
19
+ target_spec = torch.stft(target.view(-1, target.shape[-1]), n_fft=win, hop_length=win//2,
20
+ window=torch.hann_window(win).to(target.device).float(),
21
+ return_complex=True)
22
+
23
+ loss = loss + (est_spec.abs() - target_spec.abs()).abs().mean() / (target_spec.abs().mean() + eps)
24
+
25
+ return loss / len(all_win)
26
+
27
+ class MultiFrequencyDisLoss(_Loss):
28
+ def __init__(self, eps=1e-8):
29
+ super(MultiFrequencyDisLoss, self).__init__()
30
+
31
+ def forward(self, target_outputs, est_outputs):
32
+ D_real = 0
33
+ D_fake = 0
34
+ for i in range(len(target_outputs)):
35
+ D_real = D_real + (target_outputs[i] - 1).pow(2).mean() / len(target_outputs)
36
+ D_fake = D_fake + (est_outputs[i]).pow(2).mean() / len(est_outputs)
37
+ return D_real + D_fake
38
+
39
+ class MultiFrequencyGenLoss(_Loss):
40
+ def __init__(self, eps=1e-8):
41
+ super(MultiFrequencyGenLoss, self).__init__()
42
+ self.eps = eps
43
+
44
+ def forward(self, est_outputs, est_feature_maps, targets_feature_maps, output, ori_data):
45
+ G_fake = 0
46
+ feature_matching = 0
47
+ eps = self.eps
48
+
49
+ for i in range(len(est_outputs)):
50
+ G_fake = G_fake + (est_outputs[i] - 1).pow(2).mean() / len(est_outputs)
51
+ for j in range(len(est_feature_maps[i])):
52
+ feature_matching = feature_matching + (est_feature_maps[i][j] - targets_feature_maps[i][j].detach()).abs().mean() / (targets_feature_maps[i][j].detach().abs().mean() + eps)
53
+
54
+ feature_matching = feature_matching / (len(est_outputs) * len(est_feature_maps[0]))
55
+ freq_loss = freq_MAE(output, ori_data.unsqueeze(1))
56
+ total_loss = freq_loss + G_fake + feature_matching
57
+
58
+ return total_loss
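
Small sanity check for the multi-resolution spectral term freq_MAE used inside MultiFrequencyGenLoss: identical signals should give (near) zero loss, a mismatched pair should not. The test tensors are arbitrary:

import torch

from look2hear.losses.gan_losses import freq_MAE

x = torch.randn(1, 1, 44100)
print(freq_MAE(x, x).item())                    # ~0.0 for identical signals
print(freq_MAE(x, torch.randn_like(x)).item())  # clearly positive for a mismatch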
look2hear/losses/matrix.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ from torch.nn.modules.loss import _Loss
3
+
4
+ class MultiSrcNegSDR(_Loss):
5
+ def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
6
+ super().__init__()
7
+
8
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
9
+ self.sdr_type = sdr_type
10
+ self.zero_mean = zero_mean
11
+ self.take_log = take_log
12
+ self.EPS = 1e-8
13
+
14
+ def forward(self, ests, targets):
15
+ if targets.size() != ests.size() or targets.ndim != 3:
16
+ raise TypeError(
17
+ f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
18
+ )
19
+ # Step 1. Zero-mean norm
20
+ if self.zero_mean:
21
+ mean_source = torch.mean(targets, dim=2, keepdim=True)
22
+ mean_est = torch.mean(ests, dim=2, keepdim=True)
23
+ targets = targets - mean_source
24
+ ests = ests - mean_est
25
+ # Step 2. Pair-wise SI-SDR.
26
+ if self.sdr_type in ["sisdr", "sdsdr"]:
27
+ # [batch, n_src]
28
+ pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
29
+ # [batch, n_src]
30
+ s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
31
+ # [batch, n_src, time]
32
+ scaled_targets = pair_wise_dot * targets / s_target_energy
33
+ else:
34
+ # [batch, n_src, time]
35
+ scaled_targets = targets
36
+ if self.sdr_type in ["sdsdr", "snr"]:
37
+ e_noise = ests - targets
38
+ else:
39
+ e_noise = ests - scaled_targets
40
+ # [batch, n_src]
41
+ pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
42
+ torch.sum(e_noise ** 2, dim=2) + self.EPS
43
+ )
44
+ if self.take_log:
45
+ pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
46
+ return -torch.mean(pair_wise_sdr, dim=-1).mean(0)
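
For reference, the quantity MultiSrcNegSDR returns with sdr_type="sisdr" and zero_mean=True is the negated standard SI-SDR; written out to match the code above, with s the zero-mean target and \hat{s} the zero-mean estimate:

\[
s_{\mathrm{target}} = \frac{\langle \hat{s},\, s\rangle}{\lVert s\rVert^{2} + \epsilon}\, s,
\qquad
e = \hat{s} - s_{\mathrm{target}},
\qquad
\mathcal{L} = -\,10\log_{10}\!\left(\frac{\lVert s_{\mathrm{target}}\rVert^{2}}{\lVert e\rVert^{2} + \epsilon} + \epsilon\right),
\]

averaged over sources and over the batch, which is exactly the pair_wise_sdr computation above.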
look2hear/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-22 12:22:41
4
+ # LastEditors: Kai Li
5
+ # LastEditTime: 2021-07-14 19:15:22
6
+ ###
7
+ from .wrapper import MetricsTracker
8
+
9
+ __all__ = ["MetricsTracker"]
look2hear/metrics/wrapper.py ADDED
@@ -0,0 +1,86 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-22 12:41:36
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2022-06-05 14:48:00
6
+ ###
7
+ import csv
9
+ import torch
10
+ import numpy as np
11
+ import logging
12
+ import os
13
+ import librosa
14
+ from torch_mir_eval.separation import bss_eval_sources
15
+ import fast_bss_eval
16
+ from visqol import visqol_lib_py
17
+ from visqol.pb2 import visqol_config_pb2
18
+ from visqol.pb2 import similarity_result_pb2
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ def is_silent(wav, threshold=1e-4):
23
+ return torch.sum(wav ** 2) / wav.numel() < threshold
24
+
25
+ class MetricsTracker:
26
+ def __init__(self, save_file: str = ""):
27
+ self.all_sdrs = []
28
+ self.all_sisnrs = []
29
+ self.all_visqols = []
30
+
31
+ csv_columns = ["snt_id", "sdr", "si-snr", "visqol"]
32
+ self.visqol_config = visqol_config_pb2.VisqolConfig()
33
+ self.visqol_config.audio.sample_rate = 48000
34
+ self.visqol_config.options.use_speech_scoring = False
35
+ svr_model_path = "libsvm_nu_svr_model.txt"
36
+ self.visqol_config.options.svr_model_path = os.path.join(os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path)
37
+ self.visqol_api = visqol_lib_py.VisqolApi()
38
+ self.visqol_api.Create(self.visqol_config)
39
+
40
+ self.results_csv = open(save_file, "w")
41
+ self.writer = csv.DictWriter(self.results_csv, fieldnames=csv_columns)
42
+ self.writer.writeheader()
43
+
44
+ def __call__(self, clean, estimate, key):
45
+ sisnr = fast_bss_eval.si_sdr(clean.unsqueeze(0), estimate.unsqueeze(0), zero_mean=True).mean()
46
+ sdr = fast_bss_eval.sdr(clean.unsqueeze(0), estimate.unsqueeze(0), zero_mean=True).mean()
47
+
48
+ clean = librosa.resample(clean.squeeze(0).mean(0).cpu().numpy(), orig_sr=44100, target_sr=48000).astype(np.float64)
49
+ estimate = librosa.resample(estimate.squeeze(0).mean(0).cpu().numpy(), orig_sr=44100, target_sr=48000).astype(np.float64)
50
+
51
+ visqol = self.visqol_api.Measure(clean, estimate).moslqo
52
+ # import pdb; pdb.set_trace()
53
+ row = {
54
+ "snt_id": key,
55
+ "sdr": sdr.item(),
56
+ "si-snr": sisnr.item(),
57
+ "visqol": visqol
58
+ }
59
+
60
+ self.writer.writerow(row)
61
+ # Metric Accumulation
62
+ self.all_sdrs.append(sdr.item())
63
+ self.all_sisnrs.append(sisnr.item())
64
+ self.all_visqols.append(visqol)
65
+
66
+ def update(self, ):
67
+ return {"sdr": np.array(self.all_sdrs).mean(),
68
+ "si-snr": np.array(self.all_sisnrs).mean(),
69
+ "visqol": np.array(self.all_visqols).mean()}
70
+
71
+ def final(self,):
72
+ row = {
73
+ "snt_id": "avg",
74
+ "sdr": np.array(self.all_sdrs).mean(),
75
+ "si-snr": np.array(self.all_sisnrs).mean(),
76
+ "visqol": np.array(self.all_visqols).mean()
77
+ }
78
+ self.writer.writerow(row)
79
+ row = {
80
+ "snt_id": "std",
81
+ "sdr": np.array(self.all_sdrs).std(),
82
+ "si-snr": np.array(self.all_sisnrs).std(),
83
+ "visqol": np.array(self.all_visqols).std()
84
+ }
85
+ self.writer.writerow(row)
86
+ self.results_csv.close()
look2hear/models/__init__.py ADDED
@@ -0,0 +1,49 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2022-02-12 15:16:35
4
+ # Email: [email protected]
5
+ # LastEditTime: 2022-10-04 16:24:53
6
+ ###
7
+ from .base_model import BaseModel
8
+ from .apollo import Apollo
9
+
10
+ __all__ = [
11
+ "BaseModel",
13
+ "Apollo"
14
+ ]
15
+
16
+
17
+ def register_model(custom_model):
18
+ """Register a custom model, gettable with `models.get`.
19
+
20
+ Args:
21
+ custom_model: Custom model to register.
22
+
23
+ """
24
+ if (
25
+ custom_model.__name__ in globals().keys()
26
+ or custom_model.__name__.lower() in globals().keys()
27
+ ):
28
+ raise ValueError(
29
+ f"Model {custom_model.__name__} already exists. Choose another name."
30
+ )
31
+ globals().update({custom_model.__name__: custom_model})
32
+
33
+
34
+ def get(identifier):
35
+ """Returns a model class from a string (case-insensitive).
36
+
37
+ Args:
38
+ identifier (str): the model name.
39
+
40
+ Returns:
41
+ :class:`torch.nn.Module`
42
+ """
43
+ if isinstance(identifier, str):
44
+ to_get = {k.lower(): v for k, v in globals().items()}
45
+ cls = to_get.get(identifier.lower())
46
+ if cls is None:
47
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
48
+ return cls
49
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
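
Sketch of the registry helpers above, which BaseModel.from_pretrain also relies on (it resolves the checkpoint's model_name through get()). MyRestorer is a made-up example class:

import torch.nn as nn

import look2hear.models


class MyRestorer(nn.Module):        # made-up example model
    def forward(self, x):
        return x


look2hear.models.register_model(MyRestorer)
print(look2hear.models.get("myrestorer"))   # resolves the class just registered
print(look2hear.models.get("apollo"))       # resolves the Apollo class from apollo.py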
look2hear/models/apollo.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from .base_model import BaseModel
6
+
7
+ class RMSNorm(nn.Module):
8
+ def __init__(self, dimension, groups=1):
9
+ super().__init__()
10
+
11
+ self.weight = nn.Parameter(torch.ones(dimension))
12
+ self.groups = groups
13
+ self.eps = 1e-5
14
+
15
+ def forward(self, input):
16
+ # input size: (B, N, T)
17
+ B, N, T = input.shape
18
+ assert N % self.groups == 0
19
+
20
+ input_float = input.reshape(B, self.groups, -1, T).float()
21
+ input_norm = input_float * torch.rsqrt(input_float.pow(2).mean(-2, keepdim=True) + self.eps)
22
+
23
+ return input_norm.type_as(input).reshape(B, N, T) * self.weight.reshape(1, -1, 1)
24
+
25
+ class RMVN(nn.Module):
26
+ """
27
+ Rescaled MVN.
28
+ """
29
+ def __init__(self, dimension, groups=1):
30
+ super(RMVN, self).__init__()
31
+
32
+ self.mean = nn.Parameter(torch.zeros(dimension))
33
+ self.std = nn.Parameter(torch.ones(dimension))
34
+ self.groups = groups
35
+ self.eps = 1e-5
36
+
37
+ def forward(self, input):
38
+ # input size: (B, N, *)
39
+ B, N = input.shape[:2]
40
+ assert N % self.groups == 0
41
+ input_reshape = input.reshape(B, self.groups, N // self.groups, -1)
42
+ T = input_reshape.shape[-1]
43
+
44
+ input_norm = (input_reshape - input_reshape.mean(2).unsqueeze(2)) / (input_reshape.var(2).unsqueeze(2) + self.eps).sqrt()
45
+ input_norm = input_norm.reshape(B, N, T) * self.std.reshape(1, -1, 1) + self.mean.reshape(1, -1, 1)
46
+
47
+ return input_norm.reshape(input.shape)
48
+
49
+ class Roformer(nn.Module):
50
+ """
51
+ Transformer with rotary positional embedding.
52
+ """
53
+ def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000,
54
+ input_drop=0., attention_drop=0., causal=True):
55
+ super().__init__()
56
+
57
+ self.input_size = input_size
58
+ self.hidden_size = hidden_size // num_head
59
+ self.num_head = num_head
60
+ self.theta = theta # base frequency for RoPE
61
+ self.window = window
62
+ # pre-calculate rotary embeddings
63
+ cos_freq, sin_freq = self._calc_rotary_emb()
64
+ self.register_buffer("cos_freq", cos_freq) # win, N
65
+ self.register_buffer("sin_freq", sin_freq) # win, N
66
+
67
+ self.attention_drop = attention_drop
68
+ self.causal = causal
69
+ self.eps = 1e-5
70
+
71
+ self.input_norm = RMSNorm(self.input_size)
72
+ self.input_drop = nn.Dropout(p=input_drop)
73
+ self.weight = nn.Conv1d(self.input_size, self.hidden_size*self.num_head*3, 1, bias=False)
74
+ self.output = nn.Conv1d(self.hidden_size*self.num_head, self.input_size, 1, bias=False)
75
+
76
+ self.MLP = nn.Sequential(RMSNorm(self.input_size),
77
+ nn.Conv1d(self.input_size, self.input_size*8, 1, bias=False),
78
+ nn.SiLU()
79
+ )
80
+ self.MLP_output = nn.Conv1d(self.input_size*4, self.input_size, 1, bias=False)
81
+
82
+ def _calc_rotary_emb(self):
83
+ freq = 1. / (self.theta ** (torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size)) # theta_i
84
+ freq = freq.reshape(1, -1) # 1, N//2
85
+ pos = torch.arange(0, self.window).reshape(-1, 1) # win, 1
86
+ cos_freq = torch.cos(pos*freq) # win, N//2
87
+ sin_freq = torch.sin(pos*freq) # win, N//2
88
+ cos_freq = torch.stack([cos_freq]*2, -1).reshape(self.window, self.hidden_size) # win, N
89
+ sin_freq = torch.stack([sin_freq]*2, -1).reshape(self.window, self.hidden_size) # win, N
90
+
91
+ return cos_freq, sin_freq
92
+
93
+ def _add_rotary_emb(self, feature, pos):
94
+ # feature shape: ..., N
95
+ N = feature.shape[-1]
96
+
97
+ feature_reshape = feature.reshape(-1, N)
98
+ pos = min(pos, self.window-1)
99
+ cos_freq = self.cos_freq[pos]
100
+ sin_freq = self.sin_freq[pos]
101
+ reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
102
+ feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, N)
103
+ feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
104
+
105
+ return feature_rope.reshape(feature.shape)
106
+
107
+ def _add_rotary_sequence(self, feature):
108
+ # feature shape: ..., T, N
109
+ T, N = feature.shape[-2:]
110
+ feature_reshape = feature.reshape(-1, T, N)
111
+
112
+ cos_freq = self.cos_freq[:T]
113
+ sin_freq = self.sin_freq[:T]
114
+ reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
115
+ feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, T, N)
116
+ feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
117
+
118
+ return feature_rope.reshape(feature.shape)
119
+
120
+ def forward(self, input):
121
+ # input shape: B, N, T
122
+
123
+ B, _, T = input.shape
124
+
125
+ weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size*3, T).mT
126
+ Q, K, V = torch.split(weight, self.hidden_size, dim=-1) # B, num_head, T, N
127
+
128
+ # rotary positional embedding
129
+ Q_rot = self._add_rotary_sequence(Q)
130
+ K_rot = self._add_rotary_sequence(K)
131
+
132
+ attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(), dropout_p=self.attention_drop, is_causal=self.causal) # B, num_head, T, N
133
+ attention_output = attention_output.mT.reshape(B, -1, T)
134
+ output = self.output(attention_output) + input
135
+
136
+ gate, z = self.MLP(output).chunk(2, dim=1)
137
+ output = output + self.MLP_output(F.silu(gate) * z)
138
+
139
+ return output, (K_rot, V)
140
+
141
+ class ConvActNorm1d(nn.Module):
142
+ def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
143
+ super(ConvActNorm1d, self).__init__()
144
+
145
+ self.in_channel = in_channel
146
+ self.kernel = kernel
147
+ self.causal = causal
148
+ if not causal:
149
+ self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2, groups=in_channel),
150
+ RMSNorm(in_channel),
151
+ nn.Conv1d(in_channel, hidden_channel, 1),
152
+ nn.SiLU(),
153
+ nn.Conv1d(hidden_channel, in_channel, 1)
154
+ )
155
+ else:
156
+ self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1, groups=in_channel),
157
+ RMSNorm(in_channel),
158
+ nn.Conv1d(in_channel, hidden_channel, 1),
159
+ nn.SiLU(),
160
+ nn.Conv1d(hidden_channel, in_channel, 1)
161
+ )
162
+
163
+ def forward(self, input):
164
+
165
+ output = self.conv(input)
166
+ if self.causal:
167
+ output = output[...,:-self.kernel+1]
168
+ return input + output
169
+
170
+ class ICB(nn.Module):
171
+ def __init__(self, in_channel, kernel=7, causal=False):
172
+ super(ICB, self).__init__()
173
+
174
+ self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal),
175
+ ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal),
176
+ ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal)
177
+ )
178
+
179
+ def forward(self, input):
180
+
181
+ return self.blocks(input)
182
+
183
+ class BSNet(nn.Module):
184
+ def __init__(self, feature_dim, kernel=7):
185
+ super(BSNet, self).__init__()
186
+
187
+ self.feature_dim = feature_dim
188
+
189
+ self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False)
190
+ self.seq_net = ICB(self.feature_dim, kernel=kernel)
191
+
192
+ def forward(self, input):
193
+ # input shape: B, nband, N, T
194
+
195
+ B, nband, N, T = input.shape
196
+
197
+ # band comm
198
+ band_input = input.permute(0,3,2,1).reshape(B*T, -1, nband)
199
+ band_output, _ = self.band_net(band_input)
200
+ band_output = band_output.reshape(B, T, -1, nband).permute(0,3,2,1)
201
+
202
+ # sequence modeling
203
+ output = self.seq_net(band_output.reshape(B*nband, -1, T)).reshape(B, nband, -1, T) # B, nband, N, T
204
+
205
+ return output
206
+
207
+ class Apollo(BaseModel):
208
+ def __init__(
209
+ self,
210
+ sr: int,
211
+ win: int,
212
+ feature_dim: int,
213
+ layer: int
214
+ ):
215
+ super().__init__(sample_rate=sr)
216
+
217
+ self.sr = sr
218
+ self.win = int(sr * win // 1000)
219
+ self.stride = self.win // 2
220
+ self.enc_dim = self.win // 2 + 1
221
+ self.feature_dim = feature_dim
222
+ self.eps = torch.finfo(torch.float32).eps
223
+
224
+ # 80 bands
225
+ bandwidth = int(self.win / 160)
226
+ self.band_width = [bandwidth]*79
227
+ self.band_width.append(self.enc_dim - np.sum(self.band_width))
228
+ self.nband = len(self.band_width)
229
+ print(self.band_width, self.nband)
230
+
231
+ self.BN = nn.ModuleList([])
232
+ for i in range(self.nband):
233
+ self.BN.append(nn.Sequential(RMSNorm(self.band_width[i]*2+1),
234
+ nn.Conv1d(self.band_width[i]*2+1, self.feature_dim, 1))
235
+ )
236
+
237
+ self.net = []
238
+ for _ in range(layer):
239
+ self.net.append(BSNet(self.feature_dim))
240
+ self.net = nn.Sequential(*self.net)
241
+
242
+ self.output = nn.ModuleList([])
243
+ for i in range(self.nband):
244
+ self.output.append(nn.Sequential(RMSNorm(self.feature_dim),
245
+ nn.Conv1d(self.feature_dim, self.band_width[i]*4, 1),
246
+ nn.GLU(dim=1)
247
+ )
248
+ )
249
+
250
+ def spec_band_split(self, input):
251
+
252
+ B, nch, nsample = input.shape
253
+
254
+ spec = torch.stft(input.view(B*nch, nsample), n_fft=self.win, hop_length=self.stride,
255
+ window=torch.hann_window(self.win).to(input.device), return_complex=True)
256
+
257
+ subband_spec = []
258
+ subband_spec_norm = []
259
+ subband_power = []
260
+ band_idx = 0
261
+ for i in range(self.nband):
262
+ this_spec = spec[:,band_idx:band_idx+self.band_width[i]]
263
+ subband_spec.append(this_spec) # B, BW, T
264
+ subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T
265
+ subband_spec_norm.append(torch.complex(this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1])) # B, BW, T
266
+ band_idx += self.band_width[i]
267
+ subband_power = torch.cat(subband_power, 1) # B, nband, T
268
+
269
+ return subband_spec_norm, subband_power
270
+
271
+ def feature_extractor(self, input):
272
+
273
+ subband_spec_norm, subband_power = self.spec_band_split(input)
274
+
275
+ # normalization and bottleneck
276
+ subband_feature = []
277
+ for i in range(self.nband):
278
+ concat_spec = torch.cat([subband_spec_norm[i].real, subband_spec_norm[i].imag, torch.log(subband_power[:,i].unsqueeze(1))], 1)
279
+ subband_feature.append(self.BN[i](concat_spec))
280
+ subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
281
+
282
+ return subband_feature
283
+
284
+ def forward(self, input):
285
+
286
+ B, nch, nsample = input.shape
287
+
288
+ subband_feature = self.feature_extractor(input)
289
+ feature = self.net(subband_feature)
290
+
291
+ est_spec = []
292
+ for i in range(self.nband):
293
+ this_RI = self.output[i](feature[:,i]).view(B*nch, 2, self.band_width[i], -1)
294
+ est_spec.append(torch.complex(this_RI[:,0], this_RI[:,1]))
295
+ est_spec = torch.cat(est_spec, 1)
296
+ output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
297
+ window=torch.hann_window(self.win).to(input.device), length=nsample).view(B, nch, -1)
298
+
299
+ return output
300
+
301
+ def get_model_args(self):
302
+ model_args = {"n_sample_rate": 2}
303
+ return model_args
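
With the shipped config (sr=44100, win=20 ms), the band-split arithmetic in __init__ gives an 882-sample STFT window, 442 frequency bins, and 80 sub-bands: 79 bands of 5 bins plus a 47-bin remainder band. A one-screen check mirroring the constructor:

sr, win_ms = 44100, 20
win = int(sr * win_ms // 1000)                 # 882-sample STFT window
enc_dim = win // 2 + 1                         # 442 frequency bins
bandwidth = int(win / 160)                     # 5 bins per band
band_width = [bandwidth] * 79
band_width.append(enc_dim - sum(band_width))   # last band gets the remaining 47 bins
print(win, enc_dim, len(band_width), band_width[-1])   # 882 442 80 47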
look2hear/models/base_model.py ADDED
@@ -0,0 +1,96 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-17 23:08:32
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2022-05-26 18:06:22
6
+ ###
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from huggingface_hub import PyTorchModelHubMixin
11
+
12
+
13
+ def _unsqueeze_to_3d(x):
14
+ """Normalize shape of `x` to [batch, n_chan, time]."""
15
+ if x.ndim == 1:
16
+ return x.reshape(1, 1, -1)
17
+ elif x.ndim == 2:
18
+ return x.unsqueeze(1)
19
+ else:
20
+ return x
21
+
22
+
23
+ def pad_to_appropriate_length(x, lcm):
24
+ values_to_pad = int(x.shape[-1]) % lcm
25
+ if values_to_pad:
26
+ appropriate_shape = x.shape
27
+ padded_x = torch.zeros(
28
+ list(appropriate_shape[:-1])
29
+ + [appropriate_shape[-1] + lcm - values_to_pad],
30
+ dtype=torch.float32,
31
+ ).to(x.device)
32
+ padded_x[..., : x.shape[-1]] = x
33
+ return padded_x
34
+ return x
35
+
36
+
37
+ class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
38
+ def __init__(self, sample_rate, in_chan=1):
39
+ super().__init__()
40
+ self._sample_rate = sample_rate
41
+ self._in_chan = in_chan
42
+
43
+ def forward(self, *args, **kwargs):
44
+ raise NotImplementedError
45
+
46
+ def sample_rate(self,):
47
+ return self._sample_rate
48
+
49
+ @staticmethod
50
+ def load_state_dict_in_audio(model, pretrained_dict):
51
+ model_dict = model.state_dict()
52
+ update_dict = {}
53
+ for k, v in pretrained_dict.items():
54
+ if "audio_model" in k:
55
+ update_dict[k[12:]] = v
56
+ model_dict.update(update_dict)
57
+ model.load_state_dict(model_dict)
58
+ return model
59
+
60
+ @staticmethod
61
+ def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
62
+ from . import get
63
+
64
+ conf = torch.load(
65
+ pretrained_model_conf_or_path, map_location="cpu"
66
+ ) # Attempt to find the model and instantiate it.
67
+
68
+ model_class = get(conf["model_name"])
69
+ # model_class = get("Conv_TasNet")
70
+ model = model_class(*args, **kwargs)
71
+ model.load_state_dict(conf["state_dict"])
72
+ return model
73
+
74
+ def serialize(self):
75
+ import pytorch_lightning as pl # Not used in torch.hub
76
+
77
+ model_conf = dict(
78
+ model_name=self.__class__.__name__,
79
+ state_dict=self.get_state_dict(),
80
+ model_args=self.get_model_args(),
81
+ )
82
+ # Additional infos
83
+ infos = dict()
84
+ infos["software_versions"] = dict(
85
+ torch_version=torch.__version__, pytorch_lightning_version=pl.__version__,
86
+ )
87
+ model_conf["infos"] = infos
88
+ return model_conf
89
+
90
+ def get_state_dict(self):
91
+ """In case the state dict needs to be modified before sharing the model."""
92
+ return self.state_dict()
93
+
94
+ def get_model_args(self):
95
+ """Should return args to re-instantiate the class."""
96
+ raise NotImplementedError
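
This is how app.py loads the model through from_pretrain; the checkpoint itself (weights/apollo.bin) is not part of this commit and must contain the model_name / state_dict keys produced by serialize():

import look2hear.models

model = look2hear.models.BaseModel.from_pretrain(
    "weights/apollo.bin",                          # path referenced by app.py
    sr=44100, win=20, feature_dim=256, layer=6,    # the model block of configs/apollo.yaml
)
model.eval()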
look2hear/system/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-20 17:52:35
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2022-05-26 18:27:43
6
+ ###
7
+
8
+
9
+ from .optimizers import make_optimizer
10
+ from .audio_litmodule import AudioLightningModule
11
+ from .schedulers import DPTNetScheduler
12
+
13
+ __all__ = [
14
+ "make_optimizer",
15
+ "AudioLightningModule",
16
+ "DPTNetScheduler"
17
+ ]
look2hear/system/audio_litmodule.py ADDED
@@ -0,0 +1,245 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2022-05-26 18:09:54
4
+ # Email: [email protected]
5
+ # LastEditTime: 2024-01-24 00:00:28
6
+ ###
7
+ import gc
8
+ from omegaconf import OmegaConf
9
+ import torch
10
+ import pytorch_lightning as pl
11
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
12
+ from collections.abc import MutableMapping
13
+ from omegaconf import ListConfig
14
+
15
+ def flatten_dict(d, parent_key="", sep="_"):
16
+ """Flattens a dictionary into a single-level dictionary while preserving
17
+ parent keys. Taken from
18
+ `SO <https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys>`_
19
+
20
+ Args:
21
+ d (MutableMapping): Dictionary to be flattened.
22
+ parent_key (str): String to use as a prefix to all subsequent keys.
23
+ sep (str): String to use as a separator between two key levels.
24
+
25
+ Returns:
26
+ dict: Single-level dictionary, flattened.
27
+ """
28
+ items = []
29
+ for k, v in d.items():
30
+ new_key = parent_key + sep + k if parent_key else k
31
+ if isinstance(v, MutableMapping):
32
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
33
+ else:
34
+ items.append((new_key, v))
35
+ return dict(items)
36
+
37
+
38
+ class AudioLightningModule(pl.LightningModule):
39
+ def __init__(
40
+ self,
41
+ model=None,
42
+ discriminator=None,
43
+ optimizer=None,
44
+ loss_func=None,
45
+ metrics=None,
46
+ scheduler=None,
47
+ ):
48
+ super().__init__()
49
+ self.audio_model = model
50
+ self.discriminator = discriminator
51
+ self.optimizer = list(optimizer)
52
+ self.loss_func = loss_func
53
+ self.metrics = metrics
54
+ self.scheduler = list(scheduler)
55
+
56
+ # Save Lightning's AttributeDict under self.hparams
57
+ self.default_monitor = "val_loss"
58
+ # self.print(self.audio_model)
59
+ self.validation_step_outputs = []
60
+ self.test_step_outputs = []
61
+ self.automatic_optimization = False
62
+
63
+ def forward(self, wav):
64
+ """Applies forward pass of the model.
65
+
66
+ Returns:
67
+ :class:`torch.Tensor`
68
+ """
69
+ return self.audio_model(wav)
70
+
71
+ def training_step(self, batch, batch_nb):
72
+ ori_data, codec_data = batch
73
+ optimizer_g, optimizer_d = self.optimizers()
74
+ # multiple schedulers
75
+ scheduler_g, scheduler_d = self.lr_schedulers()
76
+
77
+ # train discriminator
78
+ optimizer_g.zero_grad()
79
+ output = self(codec_data)
80
+
81
+ optimizer_d.zero_grad()
82
+ est_outputs, _ = self.discriminator(output.detach(), sample_rate=44100)
83
+ target_outputs, _ = self.discriminator(ori_data, sample_rate=44100)
84
+
85
+ loss_d = self.loss_func["d"](target_outputs, est_outputs)
86
+ self.manual_backward(loss_d)
87
+ self.clip_gradients(optimizer_d, gradient_clip_val=5, gradient_clip_algorithm="norm")
88
+ optimizer_d.step()
89
+ # train generator
90
+ est_outputs, est_feature_maps = self.discriminator(output, sample_rate=44100)
91
+ _, targets_feature_maps = self.discriminator(ori_data, sample_rate=44100)
92
+
93
+ loss_g = self.loss_func["g"](est_outputs, est_feature_maps, targets_feature_maps, output, ori_data)
94
+ self.manual_backward(loss_g)
95
+ self.clip_gradients(optimizer_g, gradient_clip_val=5, gradient_clip_algorithm="norm")
96
+ optimizer_g.step()
97
+ # print(loss)
98
+
99
+ if self.trainer.is_last_batch:
100
+ scheduler_g.step()
101
+ scheduler_d.step()
102
+
103
+ self.log(
104
+ "train_loss_d",
105
+ loss_d,
106
+ on_epoch=True,
107
+ prog_bar=True,
108
+ sync_dist=True,
109
+ logger=True,
110
+ )
111
+
112
+ self.log(
113
+ "train_loss_g",
114
+ loss_g,
115
+ on_epoch=True,
116
+ prog_bar=True,
117
+ sync_dist=True,
118
+ logger=True,
119
+ )
120
+
121
+
122
+ def validation_step(self, batch, batch_nb):
123
+ # cal val loss
124
+ ori_data, codec_data = batch
125
+ # print(mixtures.shape)
126
+ est_sources = self(codec_data)
127
+ loss = self.metrics(est_sources, ori_data)
128
+
129
+ self.log(
130
+ "val_loss",
131
+ loss,
132
+ on_epoch=True,
133
+ prog_bar=True,
134
+ sync_dist=True,
135
+ logger=True,
136
+ )
137
+
138
+ self.validation_step_outputs.append(loss)
139
+
140
+ return {"val_loss": loss}
141
+
142
+ def on_validation_epoch_end(self):
143
+ # val
144
+ avg_loss = torch.stack(self.validation_step_outputs).mean()
145
+ val_loss = torch.mean(self.all_gather(avg_loss))
146
+ self.log(
147
+ "lr",
148
+ self.optimizer[0].param_groups[0]["lr"],
149
+ on_epoch=True,
150
+ prog_bar=True,
151
+ sync_dist=True,
152
+ )
153
+ self.logger.experiment.log(
154
+ {"learning_rate": self.optimizer[0].param_groups[0]["lr"], "epoch": self.current_epoch}
155
+ )
156
+ self.logger.experiment.log(
157
+ {"val_pit_sisnr": -val_loss, "epoch": self.current_epoch}
158
+ )
159
+
160
+ self.validation_step_outputs.clear() # free memory
161
+ torch.cuda.empty_cache()
162
+
163
+ def test_step(self, batch, batch_nb):
164
+ mixtures, targets = batch
165
+ est_sources = self(mixtures)
166
+ loss = self.metrics(est_sources, targets)
167
+ self.log(
168
+ "test_loss",
169
+ loss,
170
+ on_epoch=True,
171
+ prog_bar=True,
172
+ sync_dist=True,
173
+ logger=True,
174
+ )
175
+ self.test_step_outputs.append(loss)
176
+ return {"test_loss": loss}
177
+
178
+ def on_test_epoch_end(self):
179
+ # val
180
+ avg_loss = torch.stack(self.test_step_outputs).mean()
181
+ test_loss = torch.mean(self.all_gather(avg_loss))
182
+ self.log(
183
+ "lr",
184
+ self.optimizer.param_groups[0]["lr"],
185
+ on_epoch=True,
186
+ prog_bar=True,
187
+ sync_dist=True,
188
+ )
189
+ self.logger.experiment.log(
190
+ {"learning_rate": self.optimizer.param_groups[0]["lr"], "epoch": self.current_epoch}
191
+ )
192
+ self.logger.experiment.log(
193
+ {"test_pit_sisnr": -test_loss, "epoch": self.current_epoch}
194
+ )
195
+
196
+ self.test_step_outputs.clear()
197
+
198
+ def configure_optimizers(self):
199
+ """Initialize optimizers, batch-wise and epoch-wise schedulers."""
200
+ if self.scheduler is None:
201
+ return self.optimizer
202
+ if not isinstance(self.scheduler, (list, tuple)):
203
+ self.scheduler = [self.scheduler] # support multiple schedulers
204
+
205
+ if not isinstance(self.optimizer, (list, tuple)):
206
+ self.optimizer = [self.optimizer] # support multiple schedulers
207
+
208
+ epoch_schedulers = []
209
+ for sched in self.scheduler:
210
+ if not isinstance(sched, dict):
211
+ if isinstance(sched, ReduceLROnPlateau):
212
+ sched = {"scheduler": sched, "monitor": self.default_monitor}
213
+ epoch_schedulers.append(sched)
214
+ else:
215
+ sched.setdefault("monitor", self.default_monitor)
216
+ sched.setdefault("frequency", 1)
217
+ # Backward compat
218
+ if sched["interval"] == "batch":
219
+ sched["interval"] = "step"
220
+ assert sched["interval"] in [
221
+ "epoch",
222
+ "step",
223
+ ], "Scheduler interval should be either step or epoch"
224
+ epoch_schedulers.append(sched)
225
+ return self.optimizer, epoch_schedulers
226
+
227
+ @staticmethod
228
+ def config_to_hparams(dic):
229
+ """Sanitizes the config dict to be handled correctly by torch
230
+ SummaryWriter. It flatten the config dict, converts ``None`` to
231
+ ``"None"`` and any list and tuple into torch.Tensors.
232
+
233
+ Args:
234
+ dic (dict): Dictionary to be transformed.
235
+
236
+ Returns:
237
+ dict: Transformed dictionary.
238
+ """
239
+ dic = flatten_dict(dic)
240
+ for k, v in dic.items():
241
+ if v is None:
242
+ dic[k] = str(v)
243
+ elif isinstance(v, (list, tuple)):
244
+ dic[k] = torch.tensor(v)
245
+ return dic
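`AudioLightningModule` runs manual optimization over a generator/discriminator pair: it expects a list of two optimizers, a list of two schedulers, and a loss dict keyed `"g"` and `"d"`. Below is a hypothetical wiring sketch; the toy `nn.Conv1d` networks and the zero-loss stubs are placeholders (the real project supplies its own generator, discriminator, GAN losses and metric), so this only illustrates the constructor contract, not a working training run.

import torch
import torch.nn as nn
from look2hear.system import AudioLightningModule, make_optimizer

# Placeholders: a real setup plugs in the project's generator/discriminator and losses.
generator = nn.Conv1d(2, 2, kernel_size=1)
discriminator = nn.Conv1d(2, 2, kernel_size=1)
loss_funcs = {
    "g": lambda *args: torch.tensor(0.0, requires_grad=True),  # generator loss stub
    "d": lambda *args: torch.tensor(0.0, requires_grad=True),  # discriminator loss stub
}

opt_g = make_optimizer(generator.parameters(), optim_name="adamw", lr=1e-3)
opt_d = make_optimizer(discriminator.parameters(), optim_name="adamw", lr=1e-3)
sched_g = torch.optim.lr_scheduler.ExponentialLR(opt_g, gamma=0.98)
sched_d = torch.optim.lr_scheduler.ExponentialLR(opt_d, gamma=0.98)

litmodule = AudioLightningModule(
    model=generator,
    discriminator=discriminator,
    optimizer=[opt_g, opt_d],          # generator optimizer first, as used in training_step
    loss_func=loss_funcs,
    metrics=nn.functional.l1_loss,     # stand-in for the project's validation metric
    scheduler=[sched_g, sched_d],
)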
look2hear/system/optimizers.py ADDED
@@ -0,0 +1,113 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-20 00:21:33
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2022-05-27 11:19:51
6
+ ###
7
+
8
+ from torch.optim.optimizer import Optimizer
9
+ from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
10
+ from torch_optimizer import (
11
+ AccSGD,
12
+ AdaBound,
13
+ AdaMod,
14
+ DiffGrad,
15
+ Lamb,
16
+ NovoGrad,
17
+ PID,
18
+ QHAdam,
19
+ QHM,
20
+ RAdam,
21
+ SGDW,
22
+ Yogi,
23
+ Ranger,
24
+ RangerQH,
25
+ RangerVA,
26
+ )
27
+
28
+
29
+ __all__ = [
30
+ "AccSGD",
31
+ "AdaBound",
32
+ "AdaMod",
33
+ "DiffGrad",
34
+ "Lamb",
35
+ "NovoGrad",
36
+ "PID",
37
+ "QHAdam",
38
+ "QHM",
39
+ "RAdam",
40
+ "SGDW",
41
+ "Yogi",
42
+ "Ranger",
43
+ "RangerQH",
44
+ "RangerVA",
45
+ "Adam",
46
+ "RMSprop",
47
+ "SGD",
48
+ "Adadelta",
49
+ "Adagrad",
50
+ "Adamax",
51
+ "AdamW",
52
+ "ASGD",
53
+ "make_optimizer",
54
+ "get",
55
+ ]
56
+
57
+
58
+ def make_optimizer(params, optim_name="adam", **kwargs):
59
+ """
60
+
61
+ Args:
62
+ params (iterable): Output of `nn.Module.parameters()`.
63
+ optimizer (str or :class:`torch.optim.Optimizer`): Identifier understood
64
+ by :func:`~.get`.
65
+ **kwargs (dict): keyword arguments for the optimizer.
66
+
67
+ Returns:
68
+ torch.optim.Optimizer
69
+ Examples
70
+ >>> from torch import nn
71
+ >>> model = nn.Sequential(nn.Linear(10, 10))
72
+ >>> optimizer = make_optimizer(model.parameters(), optim_name='sgd',
73
+ >>> lr=1e-3)
74
+ """
75
+ return get(optim_name)(params, **kwargs)
76
+
77
+
78
+ def register_optimizer(custom_opt):
79
+ """Register a custom optimizer, gettable with `optimizers.get`.
80
+
81
+ Args:
82
+ custom_opt: Custom optimizer to register.
83
+
84
+ """
85
+ if (
86
+ custom_opt.__name__ in globals().keys()
87
+ or custom_opt.__name__.lower() in globals().keys()
88
+ ):
89
+ raise ValueError(
90
+ f"Optimizer {custom_opt.__name__} already exists. Choose another name."
91
+ )
92
+ globals().update({custom_opt.__name__: custom_opt})
93
+
94
+
95
+ def get(identifier):
96
+ """Returns an optimizer function from a string. Returns its input if it
97
+ is callable (already a :class:`torch.optim.Optimizer` for example).
98
+
99
+ Args:
100
+ identifier (str or Callable): the optimizer identifier.
101
+
102
+ Returns:
103
+ :class:`torch.optim.Optimizer` or None
104
+ """
105
+ if isinstance(identifier, Optimizer):
106
+ return identifier
107
+ elif isinstance(identifier, str):
108
+ to_get = {k.lower(): v for k, v in globals().items()}
109
+ cls = to_get.get(identifier.lower())
110
+ if cls is None:
111
+ raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
112
+ return cls
113
+ raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
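`make_optimizer` resolves the optimizer class by a case-insensitive name lookup over the classes imported above, and `register_optimizer` extends that registry. A small hypothetical sketch:

import torch
import torch.nn as nn
from look2hear.system.optimizers import make_optimizer, register_optimizer

model = nn.Linear(10, 10)

# "adamw", "AdamW", "ADAMW" all resolve to torch.optim.AdamW.
opt = make_optimizer(model.parameters(), optim_name="adamw", lr=1e-3, weight_decay=1e-2)

# A registered class becomes reachable by (lower-cased) name as well.
class WarmSGD(torch.optim.SGD):
    pass

register_optimizer(WarmSGD)
opt2 = make_optimizer(model.parameters(), optim_name="warmsgd", lr=1e-2)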
look2hear/system/schedulers.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ from torch.optim.optimizer import Optimizer
3
+ import pytorch_lightning as pl
4
+ from torch.optim.lr_scheduler import _LRScheduler
5
+
6
+
7
+ class BaseScheduler(object):
8
+ """Base class for the step-wise scheduler logic.
9
+
10
+ Args:
11
+ optimizer (Optimizer): Optimizer instance to apply lr schedule on.
12
+
13
+ Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
14
+ """
15
+
16
+ def __init__(self, optimizer):
17
+ self.optimizer = optimizer
18
+ self.step_num = 0
19
+
20
+ def zero_grad(self):
21
+ self.optimizer.zero_grad()
22
+
23
+ def _get_lr(self):
24
+ raise NotImplementedError
25
+
26
+ def _set_lr(self, lr):
27
+ for param_group in self.optimizer.param_groups:
28
+ param_group["lr"] = lr
29
+
30
+ def step(self, metrics=None, epoch=None):
31
+ """Update step-wise learning rate before optimizer.step."""
32
+ self.step_num += 1
33
+ lr = self._get_lr()
34
+ self._set_lr(lr)
35
+
36
+ def load_state_dict(self, state_dict):
37
+ self.__dict__.update(state_dict)
38
+
39
+ def state_dict(self):
40
+ return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
41
+
42
+ def as_tensor(self, start=0, stop=100_000):
43
+ """Returns the scheduler values from start to stop."""
44
+ lr_list = []
45
+ for _ in range(start, stop):
46
+ self.step_num += 1
47
+ lr_list.append(self._get_lr())
48
+ self.step_num = 0
49
+ return torch.tensor(lr_list)
50
+
51
+ def plot(self, start=0, stop=100_000): # noqa
52
+ """Plot the scheduler values from start to stop."""
53
+ import matplotlib.pyplot as plt
54
+
55
+ all_lr = self.as_tensor(start=start, stop=stop)
56
+ plt.plot(all_lr.numpy())
57
+ plt.show()
58
+
59
+ class DPTNetScheduler(BaseScheduler):
60
+ """Dual Path Transformer Scheduler used in [1]
61
+
62
+ Args:
63
+ optimizer (Optimizer): Optimizer instance to apply lr schedule on.
64
+ steps_per_epoch (int): Number of steps per epoch.
65
+ d_model (int): The number of units in the layer output.
66
+ warmup_steps (int): The number of steps in the warmup stage of training.
67
+ noam_scale (float): Linear increase rate in first phase.
68
+ exp_max (float): Max learning rate in second phase.
69
+ exp_base (float): Exp learning rate base in second phase.
70
+
71
+ Schedule:
72
+ This scheduler increases the learning rate linearly for the first
73
+ ``warmup_steps``, and then decay it by 0.98 for every two epochs.
74
+
75
+ References
76
+ [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
77
+ Aware Modeling for End-to-End Monaural Speech Separation" Interspeech 2020.
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ optimizer,
83
+ steps_per_epoch,
84
+ d_model,
85
+ warmup_steps=4000,
86
+ noam_scale=1.0,
87
+ exp_max=0.0004,
88
+ exp_base=0.98,
89
+ ):
90
+ super().__init__(optimizer)
91
+ self.noam_scale = noam_scale
92
+ self.d_model = d_model
93
+ self.warmup_steps = warmup_steps
94
+ self.exp_max = exp_max
95
+ self.exp_base = exp_base
96
+ self.steps_per_epoch = steps_per_epoch
97
+ self.epoch = 0
98
+
99
+ def _get_lr(self):
100
+ if self.step_num % self.steps_per_epoch == 0:
101
+ self.epoch += 1
102
+
103
+ if self.step_num > self.warmup_steps:
104
+ # exp decaying
105
+ lr = self.exp_max * (self.exp_base ** ((self.epoch - 1) // 2))
106
+ else:
107
+ # noam
108
+ lr = (
109
+ self.noam_scale
110
+ * self.d_model ** (-0.5)
111
+ * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
112
+ )
113
+ return lr
114
+
115
+ class CustomExponentialLR(_LRScheduler):
116
+ def __init__(self, optimizer, gamma, step_size, last_epoch=-1):
117
+ self.gamma = gamma
118
+ self.step_size = step_size
119
+ self.base_lrs = list(map(lambda group: group['lr'], optimizer.param_groups))
120
+ super(CustomExponentialLR, self).__init__(optimizer, last_epoch)
121
+
122
+ def get_lr(self):
123
+ if self.last_epoch == 0 or (self.last_epoch + 1) % self.step_size != 0:
124
+ return [group['lr'] for group in self.optimizer.param_groups]
125
+ return [lr * self.gamma for lr in self.base_lrs]
126
+
127
+
128
+ # Backward compat
129
+ _BaseScheduler = BaseScheduler
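`DPTNetScheduler` is stepped manually (as `training_step` above does for its schedulers), ramping the learning rate up for `warmup_steps` and then decaying it every two epochs; `BaseScheduler.as_tensor`/`plot` make the curve easy to inspect. A hypothetical sketch with a toy model (the `steps_per_epoch` value is made up):

import torch
import torch.nn as nn
from look2hear.system.schedulers import DPTNetScheduler

model = nn.Linear(64, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

scheduler = DPTNetScheduler(
    optimizer,
    steps_per_epoch=1000,   # assumed; depends on dataset size and batch size
    d_model=64,
    warmup_steps=4000,
)

# Inspect the first 20k steps of the schedule without running any training.
lrs = scheduler.as_tensor(start=0, stop=20_000)
print(lrs[0].item(), lrs.max().item(), lrs[-1].item())

# In a manual loop the scheduler is stepped alongside the optimizer.
for _ in range(10):
    scheduler.step()
    optimizer.step()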
look2hear/utils/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-18 16:53:49
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2024-01-22 01:01:02
6
+ ###
7
+ from .stft import STFT
8
+ from .torch_utils import pad_x_to_y, shape_reconstructed, tensors_to_device
9
+ from .parser_utils import (
10
+ prepare_parser_from_dict,
11
+ parse_args_as_dict,
12
+ str_int_float,
13
+ str2bool,
14
+ str2bool_arg,
15
+ isfloat,
16
+ isint,
17
+ instantiate
18
+ )
19
+ from .lightning_utils import print_only, RichProgressBarTheme, MyRichProgressBar, BatchesProcessedColumn, MyMetricsTextColumn
20
+ from .complex_utils import is_complex, is_torch_complex_tensor, new_complex_like
21
+ from .get_layer_from_string import get_layer
22
+ from .inversible_interface import InversibleInterface
23
+ from .nets_utils import make_pad_mask
24
+ from .pylogger import RankedLogger
25
+ from .separator import wav_chunk_inference
26
+
27
+ __all__ = [
28
+ "wav_chunk_inference",
29
+ "RankedLogger",
30
+ "instantiate",
31
+ "STFT",
32
+ "pad_x_to_y",
33
+ "shape_reconstructed",
34
+ "tensors_to_device",
35
+ "prepare_parser_from_dict",
36
+ "parse_args_as_dict",
37
+ "str_int_float",
38
+ "str2bool",
39
+ "str2bool_arg",
40
+ "isfloat",
41
+ "isint",
42
+ "print_only",
43
+ "RichProgressBarTheme",
44
+ "MyRichProgressBar",
45
+ "BatchesProcessedColumn",
46
+ "MyMetricsTextColumn",
47
+ "is_complex",
48
+ "is_torch_complex_tensor",
49
+ "new_complex_like",
50
+ "get_layer",
51
+ "InversibleInterface",
52
+ "make_pad_mask",
53
+ ]
look2hear/utils/complex_utils.py ADDED
@@ -0,0 +1,191 @@
1
+ """Beamformer module."""
2
+ from typing import Sequence, Tuple, Union
3
+
4
+ import torch
5
+ from packaging.version import parse as V
6
+ from torch_complex import functional as FC
7
+ from torch_complex.tensor import ComplexTensor
8
+
9
+ EPS = torch.finfo(torch.double).eps
10
+ is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0")
11
+ is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
12
+
13
+
14
+ def new_complex_like(
15
+ ref: Union[torch.Tensor, ComplexTensor],
16
+ real_imag: Tuple[torch.Tensor, torch.Tensor],
17
+ ):
18
+ if isinstance(ref, ComplexTensor):
19
+ return ComplexTensor(*real_imag)
20
+ elif is_torch_complex_tensor(ref):
21
+ return torch.complex(*real_imag)
22
+ else:
23
+ raise ValueError(
24
+ "Please update your PyTorch version to 1.9+ for complex support."
25
+ )
26
+
27
+
28
+ def is_torch_complex_tensor(c):
29
+ return (
30
+ not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
31
+ )
32
+
33
+
34
+ def is_complex(c):
35
+ return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
36
+
37
+
38
+ def to_double(c):
39
+ if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
40
+ return c.to(dtype=torch.complex128)
41
+ else:
42
+ return c.double()
43
+
44
+
45
+ def to_float(c):
46
+ if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
47
+ return c.to(dtype=torch.complex64)
48
+ else:
49
+ return c.float()
50
+
51
+
52
+ def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
53
+ if not isinstance(seq, (list, tuple)):
54
+ raise TypeError(
55
+ "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
56
+ "not Tensor"
57
+ )
58
+ if isinstance(seq[0], ComplexTensor):
59
+ return FC.cat(seq, *args, **kwargs)
60
+ else:
61
+ return torch.cat(seq, *args, **kwargs)
62
+
63
+
64
+ def complex_norm(
65
+ c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
66
+ ) -> torch.Tensor:
67
+ if not is_complex(c):
68
+ raise TypeError("Input is not a complex tensor.")
69
+ if is_torch_complex_tensor(c):
70
+ return torch.norm(c, dim=dim, keepdim=keepdim)
71
+ else:
72
+ if dim is None:
73
+ return torch.sqrt((c.real**2 + c.imag**2).sum() + EPS)
74
+ else:
75
+ return torch.sqrt(
76
+ (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
77
+ )
78
+
79
+
80
+ def einsum(equation, *operands):
81
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
82
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
83
+ # mixed input with complex and real tensors.
84
+ if len(operands) == 1:
85
+ if isinstance(operands[0], (tuple, list)):
86
+ operands = operands[0]
87
+ complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
88
+ return complex_module.einsum(equation, *operands)
89
+ elif len(operands) != 2:
90
+ op0 = operands[0]
91
+ same_type = all(op.dtype == op0.dtype for op in operands[1:])
92
+ if same_type:
93
+ _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
94
+ return _einsum(equation, *operands)
95
+ else:
96
+ raise ValueError("0 or More than 2 operands are not supported.")
97
+ a, b = operands
98
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
99
+ return FC.einsum(equation, a, b)
100
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
101
+ if not torch.is_complex(a):
102
+ o_real = torch.einsum(equation, a, b.real)
103
+ o_imag = torch.einsum(equation, a, b.imag)
104
+ return torch.complex(o_real, o_imag)
105
+ elif not torch.is_complex(b):
106
+ o_real = torch.einsum(equation, a.real, b)
107
+ o_imag = torch.einsum(equation, a.imag, b)
108
+ return torch.complex(o_real, o_imag)
109
+ else:
110
+ return torch.einsum(equation, a, b)
111
+ else:
112
+ return torch.einsum(equation, a, b)
113
+
114
+
115
+ def inverse(
116
+ c: Union[torch.Tensor, ComplexTensor]
117
+ ) -> Union[torch.Tensor, ComplexTensor]:
118
+ if isinstance(c, ComplexTensor):
119
+ return c.inverse2()
120
+ else:
121
+ return c.inverse()
122
+
123
+
124
+ def matmul(
125
+ a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
126
+ ) -> Union[torch.Tensor, ComplexTensor]:
127
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
128
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
129
+ # multiplication between complex and real tensors.
130
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
131
+ return FC.matmul(a, b)
132
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
133
+ if not torch.is_complex(a):
134
+ o_real = torch.matmul(a, b.real)
135
+ o_imag = torch.matmul(a, b.imag)
136
+ return torch.complex(o_real, o_imag)
137
+ elif not torch.is_complex(b):
138
+ o_real = torch.matmul(a.real, b)
139
+ o_imag = torch.matmul(a.imag, b)
140
+ return torch.complex(o_real, o_imag)
141
+ else:
142
+ return torch.matmul(a, b)
143
+ else:
144
+ return torch.matmul(a, b)
145
+
146
+
147
+ def trace(a: Union[torch.Tensor, ComplexTensor]):
148
+ # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
149
+ # support batch processing. Use FC.trace() as fallback.
150
+ return FC.trace(a)
151
+
152
+
153
+ def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
154
+ if isinstance(a, ComplexTensor):
155
+ return FC.reverse(a, dim=dim)
156
+ else:
157
+ return torch.flip(a, dims=(dim,))
158
+
159
+
160
+ def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
161
+ """Solve the linear equation ax = b."""
162
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
163
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
164
+ # mixed input with complex and real tensors.
165
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
166
+ if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
167
+ return FC.solve(b, a, return_LU=False)
168
+ else:
169
+ return matmul(inverse(a), b)
170
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
171
+ if torch.is_complex(a) and torch.is_complex(b):
172
+ return torch.linalg.solve(a, b)
173
+ else:
174
+ return matmul(inverse(a), b)
175
+ else:
176
+ if is_torch_1_8_plus:
177
+ return torch.linalg.solve(a, b)
178
+ else:
179
+ return torch.solve(b, a)[0]
180
+
181
+
182
+ def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
183
+ if not isinstance(seq, (list, tuple)):
184
+ raise TypeError(
185
+ "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
186
+ "not Tensor"
187
+ )
188
+ if isinstance(seq[0], ComplexTensor):
189
+ return FC.stack(seq, *args, **kwargs)
190
+ else:
191
+ return torch.stack(seq, *args, **kwargs)
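These wrappers dispatch between native `torch.complex*` tensors and `torch_complex.ComplexTensor`, so calling code does not have to care which representation it received. A hypothetical sketch:

import torch
from torch_complex.tensor import ComplexTensor
from look2hear.utils.complex_utils import einsum, matmul, complex_norm, new_complex_like

# Native complex tensors (PyTorch >= 1.9) are routed to torch.* kernels.
a = torch.randn(2, 3, 4, dtype=torch.complex64)
b = torch.randn(2, 4, 5, dtype=torch.complex64)
print(matmul(a, b).shape)                   # torch.Size([2, 3, 5])
print(einsum("bij,bjk->bik", a, b).shape)   # same contraction via einsum

# ComplexTensor inputs go through torch_complex.functional instead.
c = ComplexTensor(torch.randn(3, 4), torch.randn(3, 4))
print(complex_norm(c, dim=-1).shape)        # torch.Size([3])

# new_complex_like builds a complex tensor of the same kind as the reference.
real, imag = torch.randn(3, 4), torch.randn(3, 4)
print(new_complex_like(a, (real, imag)).dtype)  # torch.complex64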
look2hear/utils/get_layer_from_string.py ADDED
@@ -0,0 +1,43 @@
1
+ import difflib
2
+
3
+ import torch
4
+
5
+
6
+ def get_layer(l_name, library=torch.nn):
7
+ """Return layer object handler from library e.g. from torch.nn
8
+
9
+ E.g. if l_name=="elu", returns torch.nn.ELU.
10
+
11
+ Args:
12
+ l_name (string): Case-insensitive name for layer in library (e.g. 'elu').
13
+ library (module): Name of library/module where to search for object handler
14
+ with l_name e.g. "torch.nn".
15
+
16
+ Returns:
17
+ layer_handler (object): handler for the requested layer e.g. (torch.nn.ELU)
18
+
19
+ """
20
+
21
+ all_torch_layers = [x for x in dir(library)]
22
+ match = [x for x in all_torch_layers if l_name.lower() == x.lower()]
23
+ if len(match) == 0:
24
+ close_matches = difflib.get_close_matches(
25
+ l_name, [x.lower() for x in all_torch_layers]
26
+ )
27
+ raise NotImplementedError(
28
+ "Layer with name {} not found in {}.\n Closest matches: {}".format(
29
+ l_name, str(library), close_matches
30
+ )
31
+ )
32
+ elif len(match) > 1:
33
+ close_matches = difflib.get_close_matches(
34
+ l_name, [x.lower() for x in all_torch_layers]
35
+ )
36
+ raise NotImplementedError(
37
+ "Multiple matches found for layer with name {} in {}.\n "
38
+ "All matches: {}".format(l_name, str(library), close_matches)
39
+ )
40
+ else:
41
+ # valid
42
+ layer_handler = getattr(library, match[0])
43
+ return layer_handler
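`get_layer` maps a case-insensitive layer name onto the corresponding class in `torch.nn` (or another module passed as `library`). A hypothetical sketch:

import torch
from look2hear.utils.get_layer_from_string import get_layer

relu_cls = get_layer("relu")            # -> torch.nn.ReLU
act = relu_cls()
print(act(torch.tensor([-1.0, 2.0])))   # tensor([0., 2.])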
look2hear/utils/inversible_interface.py ADDED
@@ -0,0 +1,13 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+
7
+ class InversibleInterface(ABC):
8
+ @abstractmethod
9
+ def inverse(
10
+ self, input: torch.Tensor, input_lengths: torch.Tensor = None
11
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
12
+ # return output, output_lengths
13
+ raise NotImplementedError
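`InversibleInterface` only fixes the signature of `inverse`; concrete transforms (e.g. the STFT later in this commit) implement it. A trivial, hypothetical subclass just to show the contract:

import torch
from look2hear.utils.inversible_interface import InversibleInterface

class IdentityTransform(InversibleInterface):
    """Toy transform whose inverse returns the input unchanged."""
    def inverse(self, input: torch.Tensor, input_lengths: torch.Tensor = None):
        return input, input_lengths

x = torch.randn(2, 16000)
y, _ = IdentityTransform().inverse(x)
assert torch.equal(x, y)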
look2hear/utils/lightning_utils.py ADDED
@@ -0,0 +1,110 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2022-05-27 10:27:56
4
+ # Email: [email protected]
5
+ # LastEditTime: 2022-06-13 12:11:15
6
+ ###
7
+ from rich import print
8
+ from dataclasses import dataclass
9
+ from pytorch_lightning.utilities import rank_zero_only
10
+ from typing import Union
11
+ from pytorch_lightning.callbacks.progress.rich_progress import *
12
+ from rich.console import Console, RenderableType
13
+ from rich.progress_bar import ProgressBar
14
+ from rich.style import Style
15
+ from rich.text import Text
16
+ from rich.progress import (
17
+ BarColumn,
18
+ DownloadColumn,
19
+ Progress,
20
+ TaskID,
21
+ TextColumn,
22
+ TimeRemainingColumn,
23
+ TransferSpeedColumn,
24
+ ProgressColumn
25
+ )
26
+ from rich import print, reconfigure
27
+
28
+ @rank_zero_only
29
+ def print_only(message: str):
30
+ print(message)
31
+
32
+ @dataclass
33
+ class RichProgressBarTheme:
34
+ """Styles to associate to different base components.
35
+
36
+ Args:
37
+ description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
38
+ progress_bar: Style for the bar in progress.
39
+ progress_bar_finished: Style for the finished progress bar.
40
+ progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
41
+ batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
42
+ time: Style for the processed time and estimate time remaining.
43
+ processing_speed: Style for the speed of the batches being processed.
44
+ metrics: Style for the metrics
45
+
46
+ https://rich.readthedocs.io/en/stable/style.html
47
+ """
48
+
49
+ description: Union[str, Style] = "#FF4500"
50
+ progress_bar: Union[str, Style] = "#f92672"
51
+ progress_bar_finished: Union[str, Style] = "#b7cc8a"
52
+ progress_bar_pulse: Union[str, Style] = "#f92672"
53
+ batch_progress: Union[str, Style] = "#fc608a"
54
+ time: Union[str, Style] = "#45ada2"
55
+ processing_speed: Union[str, Style] = "#DC143C"
56
+ metrics: Union[str, Style] = "#228B22"
57
+
58
+ class BatchesProcessedColumn(ProgressColumn):
59
+ def __init__(self, style: Union[str, Style]):
60
+ self.style = style
61
+ super().__init__()
62
+
63
+ def render(self, task) -> RenderableType:
64
+ total = task.total if task.total != float("inf") else "--"
65
+ return Text(f"{int(task.completed)}/{int(total)}", style=self.style)
66
+
67
+ class MyMetricsTextColumn(ProgressColumn):
68
+ """A column containing text."""
69
+
70
+ def __init__(self, style):
71
+ self._tasks = {}
72
+ self._current_task_id = 0
73
+ self._metrics = {}
74
+ self._style = style
75
+ super().__init__()
76
+
77
+ def update(self, metrics):
78
+ # Called when metrics are ready to be rendered.
79
+ # This is to prevent render from causing deadlock issues by requesting metrics
80
+ # in separate threads.
81
+ self._metrics = metrics
82
+
83
+ def render(self, task) -> Text:
84
+ text = ""
85
+ for k, v in self._metrics.items():
86
+ text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
87
+ return Text(text, justify="left", style=self._style)
88
+
89
+ class MyRichProgressBar(RichProgressBar):
90
+ """A progress bar that prints metrics at the end of each epoch.
91
+ """
92
+
93
+ def _init_progress(self, trainer):
94
+ if self.is_enabled and (self.progress is None or self._progress_stopped):
95
+ self._reset_progress_bar_ids()
96
+ reconfigure(**self._console_kwargs)
97
+ # file = open("/home/likai/data/Look2Hear/Experiments/run_logs/EdgeFRCNN-Noncausal.log", 'w')
98
+ self._console: Console = Console(force_terminal=True)
99
+ self._console.clear_live()
100
+ self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
101
+ self.progress = CustomProgress(
102
+ *self.configure_columns(trainer),
103
+ self._metric_component,
104
+ auto_refresh=False,
105
+ disable=self.is_disabled,
106
+ console=self._console,
107
+ )
108
+ self.progress.start()
109
+ # progress has started
110
+ self._progress_stopped = False
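`MyRichProgressBar` is a drop-in replacement for Lightning's Rich progress bar and is attached as an ordinary callback; the styles above only change its colours. A hypothetical sketch (the model and datamodule are assumed to come from the rest of the project):

import pytorch_lightning as pl
from look2hear.utils.lightning_utils import MyRichProgressBar

bar = MyRichProgressBar()                         # default Rich theme; pass theme=... to restyle
trainer = pl.Trainer(max_epochs=1, callbacks=[bar])
# trainer.fit(litmodule, datamodule=datamodule)   # supplied elsewhere in the project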
look2hear/utils/nets_utils.py ADDED
@@ -0,0 +1,503 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """Network related utility tools."""
4
+
5
+ import logging
6
+ from typing import Dict
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ def to_device(m, x):
13
+ """Send tensor into the device of the module.
14
+
15
+ Args:
16
+ m (torch.nn.Module): Torch module.
17
+ x (Tensor): Torch tensor.
18
+
19
+ Returns:
20
+ Tensor: Torch tensor located in the same place as torch module.
21
+
22
+ """
23
+ if isinstance(m, torch.nn.Module):
24
+ device = next(m.parameters()).device
25
+ elif isinstance(m, torch.Tensor):
26
+ device = m.device
27
+ else:
28
+ raise TypeError(
29
+ "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
30
+ )
31
+ return x.to(device)
32
+
33
+
34
+ def pad_list(xs, pad_value):
35
+ """Perform padding for the list of tensors.
36
+
37
+ Args:
38
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
39
+ pad_value (float): Value for padding.
40
+
41
+ Returns:
42
+ Tensor: Padded tensor (B, Tmax, `*`).
43
+
44
+ Examples:
45
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
46
+ >>> x
47
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
48
+ >>> pad_list(x, 0)
49
+ tensor([[1., 1., 1., 1.],
50
+ [1., 1., 0., 0.],
51
+ [1., 0., 0., 0.]])
52
+
53
+ """
54
+ n_batch = len(xs)
55
+ max_len = max(x.size(0) for x in xs)
56
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
57
+
58
+ for i in range(n_batch):
59
+ pad[i, : xs[i].size(0)] = xs[i]
60
+
61
+ return pad
62
+
63
+
64
+ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
65
+ """Make mask tensor containing indices of padded part.
66
+
67
+ Args:
68
+ lengths (LongTensor or List): Batch of lengths (B,).
69
+ xs (Tensor, optional): The reference tensor.
70
+ If set, masks will be the same shape as this tensor.
71
+ length_dim (int, optional): Dimension indicator of the above tensor.
72
+ See the example.
73
+
74
+ Returns:
75
+ Tensor: Mask tensor containing indices of padded part.
76
+ dtype=torch.uint8 in PyTorch 1.2-
77
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
78
+
79
+ Examples:
80
+ With only lengths.
81
+
82
+ >>> lengths = [5, 3, 2]
83
+ >>> make_pad_mask(lengths)
84
+ masks = [[0, 0, 0, 0 ,0],
85
+ [0, 0, 0, 1, 1],
86
+ [0, 0, 1, 1, 1]]
87
+
88
+ With the reference tensor.
89
+
90
+ >>> xs = torch.zeros((3, 2, 4))
91
+ >>> make_pad_mask(lengths, xs)
92
+ tensor([[[0, 0, 0, 0],
93
+ [0, 0, 0, 0]],
94
+ [[0, 0, 0, 1],
95
+ [0, 0, 0, 1]],
96
+ [[0, 0, 1, 1],
97
+ [0, 0, 1, 1]]], dtype=torch.uint8)
98
+ >>> xs = torch.zeros((3, 2, 6))
99
+ >>> make_pad_mask(lengths, xs)
100
+ tensor([[[0, 0, 0, 0, 0, 1],
101
+ [0, 0, 0, 0, 0, 1]],
102
+ [[0, 0, 0, 1, 1, 1],
103
+ [0, 0, 0, 1, 1, 1]],
104
+ [[0, 0, 1, 1, 1, 1],
105
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
106
+
107
+ With the reference tensor and dimension indicator.
108
+
109
+ >>> xs = torch.zeros((3, 6, 6))
110
+ >>> make_pad_mask(lengths, xs, 1)
111
+ tensor([[[0, 0, 0, 0, 0, 0],
112
+ [0, 0, 0, 0, 0, 0],
113
+ [0, 0, 0, 0, 0, 0],
114
+ [0, 0, 0, 0, 0, 0],
115
+ [0, 0, 0, 0, 0, 0],
116
+ [1, 1, 1, 1, 1, 1]],
117
+ [[0, 0, 0, 0, 0, 0],
118
+ [0, 0, 0, 0, 0, 0],
119
+ [0, 0, 0, 0, 0, 0],
120
+ [1, 1, 1, 1, 1, 1],
121
+ [1, 1, 1, 1, 1, 1],
122
+ [1, 1, 1, 1, 1, 1]],
123
+ [[0, 0, 0, 0, 0, 0],
124
+ [0, 0, 0, 0, 0, 0],
125
+ [1, 1, 1, 1, 1, 1],
126
+ [1, 1, 1, 1, 1, 1],
127
+ [1, 1, 1, 1, 1, 1],
128
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
129
+ >>> make_pad_mask(lengths, xs, 2)
130
+ tensor([[[0, 0, 0, 0, 0, 1],
131
+ [0, 0, 0, 0, 0, 1],
132
+ [0, 0, 0, 0, 0, 1],
133
+ [0, 0, 0, 0, 0, 1],
134
+ [0, 0, 0, 0, 0, 1],
135
+ [0, 0, 0, 0, 0, 1]],
136
+ [[0, 0, 0, 1, 1, 1],
137
+ [0, 0, 0, 1, 1, 1],
138
+ [0, 0, 0, 1, 1, 1],
139
+ [0, 0, 0, 1, 1, 1],
140
+ [0, 0, 0, 1, 1, 1],
141
+ [0, 0, 0, 1, 1, 1]],
142
+ [[0, 0, 1, 1, 1, 1],
143
+ [0, 0, 1, 1, 1, 1],
144
+ [0, 0, 1, 1, 1, 1],
145
+ [0, 0, 1, 1, 1, 1],
146
+ [0, 0, 1, 1, 1, 1],
147
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
148
+
149
+ """
150
+ if length_dim == 0:
151
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
152
+
153
+ if not isinstance(lengths, list):
154
+ lengths = lengths.long().tolist()
155
+
156
+ bs = int(len(lengths))
157
+ if maxlen is None:
158
+ if xs is None:
159
+ maxlen = int(max(lengths))
160
+ else:
161
+ maxlen = xs.size(length_dim)
162
+ else:
163
+ assert xs is None
164
+ assert maxlen >= int(max(lengths))
165
+
166
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
167
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
168
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
169
+ mask = seq_range_expand >= seq_length_expand
170
+
171
+ if xs is not None:
172
+ assert xs.size(0) == bs, (xs.size(0), bs)
173
+
174
+ if length_dim < 0:
175
+ length_dim = xs.dim() + length_dim
176
+ # ind = (:, None, ..., None, :, , None, ..., None)
177
+ ind = tuple(
178
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
179
+ )
180
+ mask = mask[ind].expand_as(xs).to(xs.device)
181
+ return mask
182
+
183
+
184
+ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
185
+ """Make mask tensor containing indices of non-padded part.
186
+
187
+ Args:
188
+ lengths (LongTensor or List): Batch of lengths (B,).
189
+ xs (Tensor, optional): The reference tensor.
190
+ If set, masks will be the same shape as this tensor.
191
+ length_dim (int, optional): Dimension indicator of the above tensor.
192
+ See the example.
193
+
194
+ Returns:
195
+ ByteTensor: mask tensor containing indices of padded part.
196
+ dtype=torch.uint8 in PyTorch 1.2-
197
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
198
+
199
+ Examples:
200
+ With only lengths.
201
+
202
+ >>> lengths = [5, 3, 2]
203
+ >>> make_non_pad_mask(lengths)
204
+ masks = [[1, 1, 1, 1 ,1],
205
+ [1, 1, 1, 0, 0],
206
+ [1, 1, 0, 0, 0]]
207
+
208
+ With the reference tensor.
209
+
210
+ >>> xs = torch.zeros((3, 2, 4))
211
+ >>> make_non_pad_mask(lengths, xs)
212
+ tensor([[[1, 1, 1, 1],
213
+ [1, 1, 1, 1]],
214
+ [[1, 1, 1, 0],
215
+ [1, 1, 1, 0]],
216
+ [[1, 1, 0, 0],
217
+ [1, 1, 0, 0]]], dtype=torch.uint8)
218
+ >>> xs = torch.zeros((3, 2, 6))
219
+ >>> make_non_pad_mask(lengths, xs)
220
+ tensor([[[1, 1, 1, 1, 1, 0],
221
+ [1, 1, 1, 1, 1, 0]],
222
+ [[1, 1, 1, 0, 0, 0],
223
+ [1, 1, 1, 0, 0, 0]],
224
+ [[1, 1, 0, 0, 0, 0],
225
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
226
+
227
+ With the reference tensor and dimension indicator.
228
+
229
+ >>> xs = torch.zeros((3, 6, 6))
230
+ >>> make_non_pad_mask(lengths, xs, 1)
231
+ tensor([[[1, 1, 1, 1, 1, 1],
232
+ [1, 1, 1, 1, 1, 1],
233
+ [1, 1, 1, 1, 1, 1],
234
+ [1, 1, 1, 1, 1, 1],
235
+ [1, 1, 1, 1, 1, 1],
236
+ [0, 0, 0, 0, 0, 0]],
237
+ [[1, 1, 1, 1, 1, 1],
238
+ [1, 1, 1, 1, 1, 1],
239
+ [1, 1, 1, 1, 1, 1],
240
+ [0, 0, 0, 0, 0, 0],
241
+ [0, 0, 0, 0, 0, 0],
242
+ [0, 0, 0, 0, 0, 0]],
243
+ [[1, 1, 1, 1, 1, 1],
244
+ [1, 1, 1, 1, 1, 1],
245
+ [0, 0, 0, 0, 0, 0],
246
+ [0, 0, 0, 0, 0, 0],
247
+ [0, 0, 0, 0, 0, 0],
248
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
249
+ >>> make_non_pad_mask(lengths, xs, 2)
250
+ tensor([[[1, 1, 1, 1, 1, 0],
251
+ [1, 1, 1, 1, 1, 0],
252
+ [1, 1, 1, 1, 1, 0],
253
+ [1, 1, 1, 1, 1, 0],
254
+ [1, 1, 1, 1, 1, 0],
255
+ [1, 1, 1, 1, 1, 0]],
256
+ [[1, 1, 1, 0, 0, 0],
257
+ [1, 1, 1, 0, 0, 0],
258
+ [1, 1, 1, 0, 0, 0],
259
+ [1, 1, 1, 0, 0, 0],
260
+ [1, 1, 1, 0, 0, 0],
261
+ [1, 1, 1, 0, 0, 0]],
262
+ [[1, 1, 0, 0, 0, 0],
263
+ [1, 1, 0, 0, 0, 0],
264
+ [1, 1, 0, 0, 0, 0],
265
+ [1, 1, 0, 0, 0, 0],
266
+ [1, 1, 0, 0, 0, 0],
267
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
268
+
269
+ """
270
+ return ~make_pad_mask(lengths, xs, length_dim)
271
+
272
+
273
+ def mask_by_length(xs, lengths, fill=0):
274
+ """Mask tensor according to length.
275
+
276
+ Args:
277
+ xs (Tensor): Batch of input tensor (B, `*`).
278
+ lengths (LongTensor or List): Batch of lengths (B,).
279
+ fill (int or float): Value to fill masked part.
280
+
281
+ Returns:
282
+ Tensor: Batch of masked input tensor (B, `*`).
283
+
284
+ Examples:
285
+ >>> x = torch.arange(5).repeat(3, 1) + 1
286
+ >>> x
287
+ tensor([[1, 2, 3, 4, 5],
288
+ [1, 2, 3, 4, 5],
289
+ [1, 2, 3, 4, 5]])
290
+ >>> lengths = [5, 3, 2]
291
+ >>> mask_by_length(x, lengths)
292
+ tensor([[1, 2, 3, 4, 5],
293
+ [1, 2, 3, 0, 0],
294
+ [1, 2, 0, 0, 0]])
295
+
296
+ """
297
+ assert xs.size(0) == len(lengths)
298
+ ret = xs.data.new(*xs.size()).fill_(fill)
299
+ for i, l in enumerate(lengths):
300
+ ret[i, :l] = xs[i, :l]
301
+ return ret
302
+
303
+
304
+ def th_accuracy(pad_outputs, pad_targets, ignore_label):
305
+ """Calculate accuracy.
306
+
307
+ Args:
308
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
309
+ pad_targets (LongTensor): Target label tensors (B, Lmax, D).
310
+ ignore_label (int): Ignore label id.
311
+
312
+ Returns:
313
+ float: Accuracy value (0.0 - 1.0).
314
+
315
+ """
316
+ pad_pred = pad_outputs.view(
317
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
318
+ ).argmax(2)
319
+ mask = pad_targets != ignore_label
320
+ numerator = torch.sum(
321
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
322
+ )
323
+ denominator = torch.sum(mask)
324
+ return float(numerator) / float(denominator)
325
+
326
+
327
+ def to_torch_tensor(x):
328
+ """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
329
+
330
+ Args:
331
+ x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
332
+
333
+ Returns:
334
+ Tensor or ComplexTensor: Type converted inputs.
335
+
336
+ Examples:
337
+ >>> xs = np.ones(3, dtype=np.float32)
338
+ >>> xs = to_torch_tensor(xs)
339
+ tensor([1., 1., 1.])
340
+ >>> xs = torch.ones(3, 4, 5)
341
+ >>> assert to_torch_tensor(xs) is xs
342
+ >>> xs = {'real': xs, 'imag': xs}
343
+ >>> to_torch_tensor(xs)
344
+ ComplexTensor(
345
+ Real:
346
+ tensor([1., 1., 1.])
347
+ Imag;
348
+ tensor([1., 1., 1.])
349
+ )
350
+
351
+ """
352
+ # If numpy, change to torch tensor
353
+ if isinstance(x, np.ndarray):
354
+ if x.dtype.kind == "c":
355
+ # Dynamically importing because torch_complex requires python3
356
+ from torch_complex.tensor import ComplexTensor
357
+
358
+ return ComplexTensor(x)
359
+ else:
360
+ return torch.from_numpy(x)
361
+
362
+ # If {'real': ..., 'imag': ...}, convert to ComplexTensor
363
+ elif isinstance(x, dict):
364
+ # Dynamically importing because torch_complex requires python3
365
+ from torch_complex.tensor import ComplexTensor
366
+
367
+ if "real" not in x or "imag" not in x:
368
+ raise ValueError("must have 'real' and 'imag' keys, but got: {}".format(list(x)))
369
+ # Relative importing because of using python3 syntax
370
+ return ComplexTensor(x["real"], x["imag"])
371
+
372
+ # If torch.Tensor, as it is
373
+ elif isinstance(x, torch.Tensor):
374
+ return x
375
+
376
+ else:
377
+ error = (
378
+ "x must be numpy.ndarray, torch.Tensor or a dict like "
379
+ "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
380
+ "but got {}".format(type(x))
381
+ )
382
+ try:
383
+ from torch_complex.tensor import ComplexTensor
384
+ except Exception:
385
+ # If PY2
386
+ raise ValueError(error)
387
+ else:
388
+ # If PY3
389
+ if isinstance(x, ComplexTensor):
390
+ return x
391
+ else:
392
+ raise ValueError(error)
393
+
394
+
395
+ def get_subsample(train_args, mode, arch):
396
+ """Parse the subsampling factors from the args for the specified `mode` and `arch`.
397
+
398
+ Args:
399
+ train_args: argument Namespace containing options.
400
+ mode: one of ('asr', 'mt', 'st')
401
+ arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
402
+
403
+ Returns:
404
+ np.ndarray / List[np.ndarray]: subsampling factors.
405
+ """
406
+ if arch == "transformer":
407
+ return np.array([1])
408
+
409
+ elif mode == "mt" and arch == "rnn":
410
+ # +1 means input (+1) and layers outputs (train_args.elayers)
411
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
412
+ logging.warning("Subsampling is not performed for machine translation.")
413
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
414
+ return subsample
415
+
416
+ elif (
417
+ (mode == "asr" and arch in ("rnn", "rnn-t"))
418
+ or (mode == "mt" and arch == "rnn")
419
+ or (mode == "st" and arch == "rnn")
420
+ ):
421
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
422
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
423
+ ss = train_args.subsample.split("_")
424
+ for j in range(min(train_args.elayers + 1, len(ss))):
425
+ subsample[j] = int(ss[j])
426
+ else:
427
+ logging.warning(
428
+ "Subsampling is not performed for vgg*. "
429
+ "It is performed in max pooling layers at CNN."
430
+ )
431
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
432
+ return subsample
433
+
434
+ elif mode == "asr" and arch == "rnn_mix":
435
+ subsample = np.ones(
436
+ train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
437
+ )
438
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
439
+ ss = train_args.subsample.split("_")
440
+ for j in range(
441
+ min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
442
+ ):
443
+ subsample[j] = int(ss[j])
444
+ else:
445
+ logging.warning(
446
+ "Subsampling is not performed for vgg*. "
447
+ "It is performed in max pooling layers at CNN."
448
+ )
449
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
450
+ return subsample
451
+
452
+ elif mode == "asr" and arch == "rnn_mulenc":
453
+ subsample_list = []
454
+ for idx in range(train_args.num_encs):
455
+ subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
456
+ if train_args.etype[idx].endswith("p") and not train_args.etype[
457
+ idx
458
+ ].startswith("vgg"):
459
+ ss = train_args.subsample[idx].split("_")
460
+ for j in range(min(train_args.elayers[idx] + 1, len(ss))):
461
+ subsample[j] = int(ss[j])
462
+ else:
463
+ logging.warning(
464
+ "Encoder %d: Subsampling is not performed for vgg*. "
465
+ "It is performed in max pooling layers at CNN.",
466
+ idx + 1,
467
+ )
468
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
469
+ subsample_list.append(subsample)
470
+ return subsample_list
471
+
472
+ else:
473
+ raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
474
+
475
+
476
+ def rename_state_dict(
477
+ old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
478
+ ):
479
+ """Replace keys of old prefix with new prefix in state dict."""
480
+ # need this list not to break the dict iterator
481
+ old_keys = [k for k in state_dict if k.startswith(old_prefix)]
482
+ if len(old_keys) > 0:
483
+ logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
484
+ for k in old_keys:
485
+ v = state_dict.pop(k)
486
+ new_k = k.replace(old_prefix, new_prefix)
487
+ state_dict[new_k] = v
488
+
489
+
490
+ def get_activation(act):
491
+ """Return activation function."""
492
+ # Lazy load to avoid unused import
493
+ from espnet.nets.pytorch_backend.conformer.swish import Swish
494
+
495
+ activation_funcs = {
496
+ "hardtanh": torch.nn.Hardtanh,
497
+ "tanh": torch.nn.Tanh,
498
+ "relu": torch.nn.ReLU,
499
+ "selu": torch.nn.SELU,
500
+ "swish": Swish,
501
+ }
502
+
503
+ return activation_funcs[act]()
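Most of this module is carried over from ESPnet; the padding/mask helpers are the parts most likely to be reused directly. A short hypothetical sketch:

import torch
from look2hear.utils.nets_utils import pad_list, make_pad_mask, make_non_pad_mask, mask_by_length

lengths = [5, 3, 2]
xs = pad_list([torch.ones(n) for n in lengths], pad_value=0.0)   # (3, 5)

pad_mask = make_pad_mask(lengths)        # True on padded positions
valid_mask = make_non_pad_mask(lengths)  # logical complement of the above
masked = mask_by_length(xs, lengths, fill=-1.0)
print(pad_mask)
print(masked)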
look2hear/utils/parser_utils.py ADDED
@@ -0,0 +1,178 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-20 00:36:46
4
+ # LastEditors: Please set LastEditors
5
+ # LastEditTime: 2024-01-22 03:02:57
6
+ ###
7
+ import sys
8
+ import argparse
9
+ import importlib
10
+ from omegaconf import DictConfig
11
+
12
+ def prepare_parser_from_dict(dic, parser=None):
13
+ """Prepare an argparser from a dictionary.
14
+
15
+ Args:
16
+ dic (dict): Two-level config dictionary with unique bottom-level keys.
17
+ parser (argparse.ArgumentParser, optional): If a parser already
18
+ exists, add the keys from the dictionary on the top of it.
19
+
20
+ Returns:
21
+ argparse.ArgumentParser:
22
+ Parser instance with groups corresponding to the first level keys
23
+ and arguments corresponding to the second level keys with default
24
+ values given by the values.
25
+ """
26
+
27
+ def standardized_entry_type(value):
28
+ """If the default value is None, replace NoneType by str_int_float.
29
+ If the default value is boolean, look for boolean strings."""
30
+ if value is None:
31
+ return str_int_float
32
+ if isinstance(str2bool(value), bool):
33
+ return str2bool_arg
34
+ return type(value)
35
+
36
+ if parser is None:
37
+ parser = argparse.ArgumentParser()
38
+ for k in dic.keys():
39
+ group = parser.add_argument_group(k)
40
+ if isinstance(dic[k], list):
41
+ entry_type = standardized_entry_type(dic[k])
42
+ group.add_argument("--" + k, default=dic[k], type=entry_type)
43
+ elif isinstance(dic[k], dict):
44
+ for kk in dic[k].keys():
45
+ entry_type = standardized_entry_type(dic[k][kk])
46
+ group.add_argument("--" + kk, default=dic[k][kk], type=entry_type)
47
+ elif isinstance(dic[k], str):
48
+ entry_type = standardized_entry_type(dic[k])
49
+ group.add_argument("--" + k, default=dic[k], type=entry_type)
50
+ return parser
51
+
52
+
53
+ def str_int_float(value):
54
+ """Type to convert strings to int, float (in this order) if possible.
55
+
56
+ Args:
57
+ value (str): Value to convert.
58
+
59
+ Returns:
60
+ int, float, str: Converted value.
61
+ """
62
+ if isint(value):
63
+ return int(value)
64
+ if isfloat(value):
65
+ return float(value)
66
+ elif isinstance(value, str):
67
+ return value
68
+
69
+
70
+ def str2bool(value):
71
+ """Type to convert strings to Boolean (returns input if not boolean)"""
72
+ if not isinstance(value, str):
73
+ return value
74
+ if value.lower() in ("yes", "true", "y", "1"):
75
+ return True
76
+ elif value.lower() in ("no", "false", "n", "0"):
77
+ return False
78
+ else:
79
+ return value
80
+
81
+
82
+ def str2bool_arg(value):
83
+ """Argparse type to convert strings to Boolean"""
84
+ value = str2bool(value)
85
+ if isinstance(value, bool):
86
+ return value
87
+ raise argparse.ArgumentTypeError("Boolean value expected.")
88
+
89
+
90
+ def isfloat(value):
91
+ """Computes whether `value` can be cast to a float.
92
+
93
+ Args:
94
+ value (str): Value to check.
95
+
96
+ Returns:
97
+ bool: Whether `value` can be cast to a float.
98
+
99
+ """
100
+ try:
101
+ float(value)
102
+ return True
103
+ except ValueError:
104
+ return False
105
+
106
+
107
+ def isint(value):
108
+ """Computes whether `value` can be cast to an int
109
+
110
+ Args:
111
+ value (str): Value to check.
112
+
113
+ Returns:
114
+ bool: Whether `value` can be cast to an int.
115
+
116
+ """
117
+ try:
118
+ int(value)
119
+ return True
120
+ except ValueError:
121
+ return False
122
+
123
+
124
+ def parse_args_as_dict(parser, return_plain_args=False, args=None):
125
+ """Get a dict of dicts out of process `parser.parse_args()`
126
+
127
+ Top-level keys corresponding to groups and bottom-level keys corresponding
128
+ to arguments. Under `'main_args'`, the arguments which don't belong to a
129
+ argparse group (i.e main arguments defined before parsing from a dict) can
130
+ be found.
131
+
132
+ Args:
133
+ parser (argparse.ArgumentParser): ArgumentParser instance containing
134
+ groups. Output of `prepare_parser_from_dict`.
135
+ return_plain_args (bool): Whether to return the output or
136
+ `parser.parse_args()`.
137
+ args (list): List of arguments as read from the command line.
138
+ Used for unit testing.
139
+
140
+ Returns:
141
+ dict:
142
+ Dictionary of dictionaries containing the arguments. Optionally the
143
+ direct output `parser.parse_args()`.
144
+ """
145
+ args = parser.parse_args(args=args)
146
+ args_dic = {}
147
+ for group in parser._action_groups:
148
+ group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
149
+ args_dic[group.title] = group_dict
150
+ if sys.version_info.minor == 10:
151
+ args_dic["main_args"] = args_dic["positional arguments"]
152
+ del args_dic["positional arguments"]
153
+ else:
154
+ args_dic["main_args"] = args_dic["optional arguments"]
155
+ del args_dic["optional arguments"]
156
+ if return_plain_args:
157
+ return args_dic, args
158
+ return args_dic
159
+
160
+ def instantiate(config, **kwargs):
161
+ if '__target__' in config:
162
+ module_path, class_name = config['__target__'].rsplit('.', 1)
163
+ module = importlib.import_module(module_path)
164
+ cls = getattr(module, class_name)
165
+ # 先处理嵌套的配置
166
+ params = {}
167
+ for key, value in config.items():
168
+ if key != '__target__':
169
+ if isinstance(value, DictConfig) and '__target__' in value:
170
+ params[key] = instantiate(value)
171
+ else:
172
+ params[key] = value
173
+ # 添加额外的关键字参数
174
+ params.update(kwargs)
175
+ return cls(**params)
176
+ else:
177
+ # 对于不包含 '__target__' 的字典,递归处理其每个值
178
+ return {k: instantiate(v, **kwargs) if isinstance(v, DictConfig) else v for k, v in config.items()}
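`instantiate` builds an object from an OmegaConf node whose `__target__` holds a dotted import path, recursing into nested nodes that carry their own `__target__`. A hypothetical sketch that targets a plain `torch.nn.Conv1d` just to stay self-contained:

from omegaconf import OmegaConf
from look2hear.utils.parser_utils import instantiate

cfg = OmegaConf.create({
    "__target__": "torch.nn.Conv1d",
    "in_channels": 2,
    "out_channels": 4,
    "kernel_size": 3,
})

conv = instantiate(cfg)                      # equivalent to torch.nn.Conv1d(2, 4, 3)
conv_padded = instantiate(cfg, padding=1)    # extra kwargs are merged on top of the config
print(type(conv).__name__, conv_padded.padding)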
look2hear/utils/pylogger.py ADDED
@@ -0,0 +1,54 @@
1
+ import logging
2
+ from typing import Mapping, Optional
3
+
4
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
5
+
6
+
7
+ class RankedLogger(logging.LoggerAdapter):
8
+ """A multi-GPU-friendly python command line logger."""
9
+
10
+ def __init__(
11
+ self,
12
+ name: str = __name__,
13
+ rank_zero_only: bool = False,
14
+ extra: Optional[Mapping[str, object]] = None,
15
+ log_file: str = "log.txt", # path of the log file to write
16
+ ) -> None:
17
+ logger = logging.getLogger(name)
18
+
19
+ # Configure the log format
20
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
21
+
22
+ # Attach a file handler
23
+ file_handler = logging.FileHandler(log_file)
24
+ file_handler.setFormatter(formatter)
25
+ logger.addHandler(file_handler)
26
+
27
+ super().__init__(logger=logger, extra=extra)
28
+ self.rank_zero_only = rank_zero_only
29
+
30
+ def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
31
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
32
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
33
+ occur on that rank/process.
34
+
35
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
36
+ :param msg: The message to log.
37
+ :param rank: The rank to log at.
38
+ :param args: Additional args to pass to the underlying logging function.
39
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
40
+ """
41
+ if self.isEnabledFor(level):
42
+ msg, kwargs = self.process(msg, kwargs)
43
+ current_rank = getattr(rank_zero_only, "rank", None)
44
+ if current_rank is None:
45
+ raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
46
+ msg = rank_prefixed_message(msg, current_rank)
47
+ if self.rank_zero_only:
48
+ if current_rank == 0:
49
+ self.logger.log(level, msg, *args, **kwargs)
50
+ else:
51
+ if rank is None:
52
+ self.logger.log(level, msg, *args, **kwargs)
53
+ elif current_rank == rank:
54
+ self.logger.log(level, msg, *args, **kwargs)
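`RankedLogger` prefixes every message with the process rank and can silence all ranks except zero; note that it also attaches a `FileHandler` (to `log.txt` by default). A hypothetical sketch; setting `rank_zero_only.rank` by hand is only needed outside a Lightning run, which normally sets it itself:

import logging
from lightning_utilities.core.rank_zero import rank_zero_only
from look2hear.utils.pylogger import RankedLogger

rank_zero_only.rank = getattr(rank_zero_only, "rank", 0)      # normally set by Lightning

log = RankedLogger(__name__, rank_zero_only=True, log_file="run.log")
log.setLevel(logging.INFO)
log.info("starting training")                         # emitted on rank 0 only
log.log(logging.WARNING, "slow dataloader", rank=0)   # target an explicit rank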
look2hear/utils/separator.py ADDED
@@ -0,0 +1,138 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-18 16:32:50
4
+ # LastEditors: Kai Li
5
+ # LastEditTime: 2021-06-19 01:02:04
6
+ ###
7
+ import os
8
+ import warnings
9
+ import torch
10
+ import numpy as np
11
+ import soundfile as sf
12
+
13
+
14
+ def get_device(tensor_or_module, default=None):
15
+ if hasattr(tensor_or_module, "device"):
16
+ return tensor_or_module.device
17
+ elif hasattr(tensor_or_module, "parameters"):
18
+ return next(tensor_or_module.parameters()).device
19
+ elif default is None:
20
+ raise TypeError(
21
+ f"Don't know how to get device of {type(tensor_or_module)} object"
22
+ )
23
+ else:
24
+ return torch.device(default)
25
+
26
+
27
+ class Separator:
28
+ def forward_wav(self, wav, **kwargs):
29
+ raise NotImplementedError
30
+
31
+ def sample_rate(self):
32
+ raise NotImplementedError
33
+
34
+
35
+ def separate(model, wav, **kwargs):
36
+ if isinstance(wav, np.ndarray):
37
+ return numpy_separate(model, wav, **kwargs)
38
+ elif isinstance(wav, torch.Tensor):
39
+ return torch_separate(model, wav, **kwargs)
40
+ else:
41
+ raise ValueError(
42
+ f"Only numpy arrays and torch tensors are supported, received {type(wav)}"
43
+ )
44
+
45
+
46
+ @torch.no_grad()
47
+ def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
48
+ """Core logic of `separate`."""
49
+ if model.in_channels is not None and wav.shape[-2] != model.in_channels:
50
+ raise RuntimeError(
51
+ f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels. "
52
+ f"Please match the number of channels."
53
+ )
54
+ # Handle device placement
55
+ input_device = get_device(wav, default="cpu")
56
+ model_device = get_device(model, default="cpu")
57
+ wav = wav.to(model_device)
58
+ # Forward
59
+ separate_func = getattr(model, "forward_wav", model)
60
+ out_wavs = separate_func(wav, **kwargs)
61
+
62
+ # FIXME: for now this is the best we can do.
63
+ out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())
64
+
65
+ # Back to input device (and numpy if necessary)
66
+ out_wavs = out_wavs.to(input_device)
67
+ return out_wavs
68
+
69
+
70
+ def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
71
+ """Numpy interface to `separate`."""
72
+ wav = torch.from_numpy(wav)
73
+ out_wavs = torch_separate(model, wav, **kwargs)
74
+ out_wavs = out_wavs.data.numpy()
75
+ return out_wavs
76
+
77
+
78
+ def wav_chunk_inference(model, mixture_tensor, sr=16000, target_length=12.0, hop_length=4.0, batch_size=10, n_tracks=3):
79
+ """
80
+ Input:
81
+ mixture_tensor: Tensor, [1, nch, input_length] (a leading batch dim of size 1 is expected)
82
+
83
+ Output:
84
+ all_target_tensor: Tensor, [n_tracks, nch, input_length]
85
+ """
86
+ batch_mixture = mixture_tensor
87
+
88
+ # split data into segments
89
+ batch_length = batch_mixture.shape[-1]
90
+
91
+ session = int(sr * target_length)
92
+ target = int(sr * target_length)
93
+ ignore = (session - target) // 2
94
+ hop = int(sr * hop_length)
95
+ tr_ratio = target_length / hop_length
96
+ if ignore > 0:
97
+ zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], ignore).type(batch_mixture.type()).to(batch_mixture.device)
98
+ batch_mixture_pad = torch.cat([zero_pad, batch_mixture, zero_pad], -1)
99
+ else:
100
+ batch_mixture_pad = batch_mixture
101
+ if target - hop > 0:
102
+ hop_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], target-hop).type(batch_mixture.type()).to(batch_mixture.device)
103
+ batch_mixture_pad = torch.cat([hop_pad, batch_mixture_pad, hop_pad], -1)
104
+
105
+ skip_idx = ignore + target - hop
106
+ zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], session).type(batch_mixture.type()).to(batch_mixture.device)
107
+ num_session = (batch_mixture_pad.shape[-1] - session) // hop + 2
108
+ all_target = torch.zeros(batch_mixture_pad.shape[0], n_tracks, batch_mixture_pad.shape[1], batch_mixture_pad.shape[2]).to(batch_mixture_pad.device)
109
+ all_input = []
110
+ all_segment_length = []
111
+
112
+ for i in range(num_session):
113
+ this_input = batch_mixture_pad[:,:,i*hop:i*hop+session]
114
+ segment_length = this_input.shape[-1]
115
+ if segment_length < session:
116
+ this_input = torch.cat([this_input, zero_pad[:,:,:session-segment_length]], -1)
117
+ all_input.append(this_input)
118
+ all_segment_length.append(segment_length)
119
+
120
+ all_input = torch.cat(all_input, 0)
121
+ num_batch = num_session // batch_size
122
+ if num_session % batch_size > 0:
123
+ num_batch += 1
124
+
125
+ for i in range(num_batch):
126
+
127
+ this_input = all_input[i*batch_size:(i+1)*batch_size]
128
+ actual_batch_size = this_input.shape[0]
129
+ with torch.no_grad():
130
+ est_target = model(this_input)
131
+ # print(est_target.shape)
132
+ for j in range(actual_batch_size):
133
+ this_est_target = est_target[j,:,:,:all_segment_length[i*batch_size+j]][:,:,ignore:ignore+target].unsqueeze(0)
134
+ all_target[:,:,:,ignore+(i*batch_size+j)*hop:ignore+(i*batch_size+j)*hop+target] += this_est_target
135
+
136
+ all_target = all_target[:,:,:,skip_idx:skip_idx+batch_length].contiguous() / tr_ratio
137
+
138
+ return all_target.squeeze(0)
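For orientation, here is a minimal, hypothetical smoke test for `wav_chunk_inference`; `DummyModel`, the shapes, and the parameter values are illustrative assumptions rather than anything defined in this commit, and the real app would pass the Apollo model instead:

```python
# Hypothetical usage sketch (assumes wav_chunk_inference from this commit is importable).
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    """Stand-in for the restoration network: echoes the input as a single track."""
    def forward(self, x):          # x: [batch, nch, samples]
        return x.unsqueeze(1)      # -> [batch, n_tracks=1, nch, samples]

model = DummyModel().eval()
mixture = torch.randn(2, 44100 * 10)   # [nch, samples]: 10 s of stereo audio at 44.1 kHz

with torch.no_grad():
    # The helper expects a leading batch dimension and squeezes it away on return.
    tracks = wav_chunk_inference(model, mixture.unsqueeze(0), sr=44100,
                                 target_length=10.0, hop_length=5.0, n_tracks=1)

print(tracks.shape)  # torch.Size([1, 2, 441000]) -> [n_track, nch, samples]
```

Internally the helper pads the signal, pushes overlapping windows through the model in mini-batches, overlap-adds the per-chunk outputs, and divides by `target_length / hop_length`, so the returned tracks stay aligned sample-for-sample with the input.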
look2hear/utils/stft.py ADDED
@@ -0,0 +1,797 @@
1
+ # Copyright 2019 Jian Wu
2
+ # License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as tf
10
+ import librosa.filters as filters
11
+
12
+ from typing import Optional, Tuple
13
+ from distutils.version import LooseVersion
14
+
15
+ EPSILON = float(np.finfo(np.float32).eps)
16
+ TORCH_VERSION = LooseVersion(th.__version__)
17
+
18
+ if TORCH_VERSION >= LooseVersion("1.7"):
19
+ from torch.fft import fft as fft_func
20
+ else:
21
+ pass
22
+
23
+
24
+ def export_jit(transform: nn.Module) -> nn.Module:
25
+ """
26
+ Export transform module for inference
27
+ """
28
+ export_out = [module for module in transform if module.exportable()]
29
+ return nn.Sequential(*export_out)
30
+
31
+
32
+ def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
33
+ """
34
+ Return window coefficient
35
+ Args:
36
+ wnd: window name
37
+ frame_len: length of the frame
38
+ """
39
+
40
+ def sqrthann(frame_len, periodic=True):
41
+ return th.hann_window(frame_len, periodic=periodic) ** 0.5
42
+
43
+ if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]:
44
+ raise RuntimeError(f"Unknown window type: {wnd}")
45
+
46
+ wnd_tpl = {
47
+ "sqrthann": sqrthann,
48
+ "hann": th.hann_window,
49
+ "hamm": th.hamming_window,
50
+ "blackman": th.blackman_window,
51
+ "bartlett": th.bartlett_window,
52
+ "rect": th.ones,
53
+ }
54
+ if wnd != "rect":
55
+ # match with librosa
56
+ c = wnd_tpl[wnd](frame_len, periodic=True)
57
+ else:
58
+ c = wnd_tpl[wnd](frame_len)
59
+ return c.to(device)
60
+
61
+
62
+ def init_kernel(
63
+ frame_len: int,
64
+ frame_hop: int,
65
+ window: th.Tensor,
66
+ round_pow_of_two: bool = True,
67
+ normalized: bool = False,
68
+ inverse: bool = False,
69
+ mode: str = "librosa",
70
+ ) -> Tuple[th.Tensor, th.Tensor]:
71
+ """
72
+ Return STFT kernels
73
+ Args:
74
+ frame_len: length of the frame
75
+ frame_hop: hop size between frames
76
+ window: window tensor
77
+ round_pow_of_two: if true, round the FFT size up to the next power of two
78
+ normalized: return normalized DFT matrix
79
+ inverse: return iDFT matrix
80
+ mode: framing mode (librosa or kaldi)
81
+ """
82
+ if mode not in ["librosa", "kaldi"]:
83
+ raise ValueError(f"Unsupported mode: {mode}")
84
+ # FFT size: B
85
+ if round_pow_of_two or mode == "kaldi":
86
+ fft_size = 2 ** math.ceil(math.log2(frame_len))
87
+ else:
88
+ fft_size = frame_len
89
+ # center padding window if needed
90
+ if mode == "librosa" and fft_size != frame_len:
91
+ lpad = (fft_size - frame_len) // 2
92
+ window = tf.pad(window, (lpad, fft_size - frame_len - lpad))
93
+ if normalized:
94
+ # make K^H * K = I
95
+ S = fft_size ** 0.5
96
+ else:
97
+ S = 1
98
+ # W x B x 2
99
+ if TORCH_VERSION >= LooseVersion("1.7"):
100
+ K = fft_func(th.eye(fft_size) / S, dim=-1)
101
+ K = th.stack([K.real, K.imag], dim=-1)
102
+ else:
103
+ I = th.stack([th.eye(fft_size), th.zeros(fft_size, fft_size)], dim=-1)
104
+ K = th.fft(I / S, 1)
105
+ if mode == "kaldi":
106
+ K = K[:frame_len]
107
+ if inverse and not normalized:
108
+ # to make K^H * K = I
109
+ K = K / fft_size
110
+ # 2 x B x W
111
+ K = th.transpose(K, 0, 2)
112
+ # 2B x 1 x W
113
+ K = th.reshape(K, (fft_size * 2, 1, K.shape[-1]))
114
+ return K.to(window.device), window
115
+
116
+
117
+ def mel_filter(
118
+ frame_len: int,
119
+ round_pow_of_two: bool = True,
120
+ num_bins: Optional[int] = None,
121
+ sr: int = 16000,
122
+ num_mels: int = 80,
123
+ fmin: float = 0.0,
124
+ fmax: Optional[float] = None,
125
+ norm: bool = False,
126
+ ) -> th.Tensor:
127
+ """
128
+ Return mel filter coefficients
129
+ Args:
130
+ frame_len: length of the frame
131
+ round_pow_of_two: if true, round the FFT size up to the next power of two
132
+ num_bins: number of the frequency bins produced by STFT
133
+ num_mels: number of the mel bands
134
+ fmin: lowest frequency (in Hz)
135
+ fmax: highest frequency (in Hz)
136
+ norm: normalize the mel filter coefficients
137
+ """
138
+ # FFT points
139
+ if num_bins is None:
140
+ N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
141
+ else:
142
+ N = (num_bins - 1) * 2
143
+ # fmin & fmax
144
+ freq_upper = sr // 2
145
+ if fmax is None:
146
+ fmax = freq_upper
147
+ else:
148
+ fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
149
+ fmin = max(0, fmin)
150
+ # mel filter coefficients
151
+ mel = filters.mel(
152
+ sr=sr,
153
+ n_fft=N,
154
+ n_mels=num_mels,
155
+ fmax=fmax,
156
+ fmin=fmin,
157
+ htk=True,
158
+ norm="slaney" if norm else None,
159
+ )
160
+ # num_mels x (N // 2 + 1)
161
+ return th.tensor(mel, dtype=th.float32)
162
+
163
+
164
+ def speed_perturb_filter(
165
+ src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
166
+ ) -> th.Tensor:
167
+ """
168
+ Return speed perturb filters, reference:
169
+ https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
170
+ Args:
171
+ src_sr: sample rate of the source signal
172
+ dst_sr: sample rate of the target signal
173
+ Return:
174
+ weight (Tensor): coefficients of the filter
175
+ """
176
+ if src_sr == dst_sr:
177
+ raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
178
+ gcd = math.gcd(src_sr, dst_sr)
179
+ src_sr = src_sr // gcd
180
+ dst_sr = dst_sr // gcd
181
+ if src_sr == 1 or dst_sr == 1:
182
+ raise ValueError("do not support integer downsample/upsample")
183
+ zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
184
+ padding = 1 + int(num_zeros / zeros_per_block)
185
+ # dst_sr x src_sr x K
186
+ times = (
187
+ np.arange(dst_sr)[:, None, None] / float(dst_sr)
188
+ - np.arange(src_sr)[None, :, None] / float(src_sr)
189
+ - np.arange(2 * padding + 1)[None, None, :]
190
+ + padding
191
+ )
192
+ window = np.heaviside(1 - np.abs(times / padding), 0.0) * (
193
+ 0.5 + 0.5 * np.cos(times / padding * math.pi)
194
+ )
195
+ weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
196
+ return th.tensor(weight, dtype=th.float32)
197
+
198
+
199
+ def splice_feature(
200
+ feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
201
+ ) -> th.Tensor:
202
+ """
203
+ Splice feature
204
+ Args:
205
+ feats (Tensor): N x ... x T x F, original feature
206
+ lctx: left context
207
+ rctx: right context
208
+ op: operator on feature context
209
+ Return:
210
+ splice (Tensor): feature with context padded
211
+ """
212
+ if lctx + rctx == 0:
213
+ return feats
214
+ if op not in ["cat", "stack"]:
215
+ raise ValueError(f"Unknown op for feature splicing: {op}")
216
+ # [N x ... x T x F, ...]
217
+ ctx = []
218
+ T = feats.shape[-2]
219
+ for c in range(-lctx, rctx + 1):
220
+ idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
221
+ idx = th.clamp(idx, min=0, max=T - 1)
222
+ ctx.append(th.index_select(feats, -2, idx))
223
+ if op == "cat":
224
+ # N x ... x T x FD
225
+ splice = th.cat(ctx, -1)
226
+ else:
227
+ # N x ... x T x F x D
228
+ splice = th.stack(ctx, -1)
229
+ return splice
230
+
231
+
232
+ def _forward_stft(
233
+ wav: th.Tensor,
234
+ kernel: th.Tensor,
235
+ window: th.Tensor,
236
+ return_polar: bool = False,
237
+ pre_emphasis: float = 0,
238
+ frame_hop: int = 256,
239
+ onesided: bool = False,
240
+ center: bool = False,
241
+ eps: float = EPSILON,
242
+ ) -> th.Tensor:
243
+ """
244
+ STFT function implemented by conv1d (not efficient, but we don't care during training)
245
+ Args:
246
+ wav (Tensor): N x (C) x S
247
+ kernel (Tensor): STFT transform kernels, from init_kernel(...)
248
+ return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
249
+ pre_emphasis: factor of preemphasis
250
+ frame_hop: frame hop size in number of samples
251
+ onesided: return half FFT bins
252
+ center: if true, frames are assumed to be centered
253
+ Return:
254
+ transform (Tensor): STFT transform results
255
+ """
256
+ wav_dim = wav.dim()
257
+ if wav_dim not in [2, 3]:
258
+ raise RuntimeError(f"STFT expects a 2D/3D tensor, but got {wav_dim:d}D")
259
+ # if N x S, reshape N x 1 x S
260
+ # else: reshape NC x 1 x S
261
+ N, S = wav.shape[0], wav.shape[-1]
262
+ wav = wav.view(-1, 1, S)
263
+ # NC x 1 x S+2P
264
+ if center:
265
+ pad = kernel.shape[-1] // 2
266
+ # NOTE: match with librosa
267
+ wav = tf.pad(wav, (pad, pad), mode="reflect")
268
+ # STFT
269
+ kernel = kernel * window
270
+ if pre_emphasis > 0:
271
+ # NC x W x T
272
+ frames = tf.unfold(
273
+ wav[:, None], (1, kernel.shape[-1]), stride=frame_hop, padding=0
274
+ )
275
+ # follow Kaldi's Preemphasize
276
+ frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
277
+ frames[:, 0] *= 1 - pre_emphasis
278
+ # 1 x 2B x W, NC x W x T, NC x 2B x T
279
+ packed = th.matmul(kernel[:, 0][None, ...], frames)
280
+ else:
281
+ packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
282
+ # NC x 2B x T => N x C x 2B x T
283
+ if wav_dim == 3:
284
+ packed = packed.view(N, -1, packed.shape[-2], packed.shape[-1])
285
+ # N x (C) x B x T
286
+ real, imag = th.chunk(packed, 2, dim=-2)
287
+ # N x (C) x B/2+1 x T
288
+ if onesided:
289
+ num_bins = kernel.shape[0] // 4 + 1
290
+ real = real[..., :num_bins, :]
291
+ imag = imag[..., :num_bins, :]
292
+ if return_polar:
293
+ mag = (real ** 2 + imag ** 2 + eps) ** 0.5
294
+ pha = th.atan2(imag, real)
295
+ return th.stack([mag, pha], dim=-1)
296
+ else:
297
+ return th.stack([real, imag], dim=-1)
298
+
299
+
300
+ def _inverse_stft(
301
+ transform: th.Tensor,
302
+ kernel: th.Tensor,
303
+ window: th.Tensor,
304
+ return_polar: bool = False,
305
+ frame_hop: int = 256,
306
+ onesided: bool = False,
307
+ center: bool = False,
308
+ eps: float = EPSILON,
309
+ ) -> th.Tensor:
310
+ """
311
+ iSTFT function implemented by conv1d
312
+ Args:
313
+ transform (Tensor): STFT transform results
314
+ kernel (Tensor): STFT transform kernels, from init_kernel(...)
315
+ return_polar (bool): keep same with the one in _forward_stft
316
+ frame_hop: frame hop size in number of samples
317
+ onesided: return half FFT bins
318
+ center: used in _forward_stft
319
+ Return:
320
+ wav (Tensor), N x S
321
+ """
322
+ # (N) x F x T x 2
323
+ transform_dim = transform.dim()
324
+ # if F x T x 2, reshape 1 x F x T x 2
325
+ if transform_dim == 3:
326
+ transform = th.unsqueeze(transform, 0)
327
+ if transform_dim != 4:
328
+ raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")
329
+
330
+ if return_polar:
331
+ real = transform[..., 0] * th.cos(transform[..., 1])
332
+ imag = transform[..., 0] * th.sin(transform[..., 1])
333
+ else:
334
+ real, imag = transform[..., 0], transform[..., 1]
335
+
336
+ if onesided:
337
+ # [self.num_bins - 2, ..., 1]
338
+ reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
339
+ # extend matrix: N x B x T
340
+ real = th.cat([real, real[:, reverse]], 1)
341
+ imag = th.cat([imag, -imag[:, reverse]], 1)
342
+ # pack: N x 2B x T
343
+ packed = th.cat([real, imag], dim=1)
344
+ # N x 1 x T
345
+ wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)
346
+ # normalized audio samples
347
+ # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
348
+ num_frames = packed.shape[-1]
349
+ win_length = window.shape[0]
350
+ # W x T
351
+ win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
352
+ # Do OLA on windows
353
+ # v1)
354
+ I = th.eye(win_length, device=win.device)[:, None]
355
+ denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)
356
+ # v2)
357
+ # num_samples = (num_frames - 1) * frame_hop + win_length
358
+ # denorm = tf.fold(win[None, ...], (num_samples, 1), (win_length, 1),
359
+ # stride=frame_hop)[..., 0]
360
+ if center:
361
+ pad = kernel.shape[-1] // 2
362
+ wav = wav[..., pad:-pad]
363
+ denorm = denorm[..., pad:-pad]
364
+ wav = wav / (denorm + eps)
365
+ # N x S
366
+ return wav.squeeze(1)
367
+
368
+
369
+ def _pytorch_stft(
370
+ wav: th.Tensor,
371
+ frame_len: int,
372
+ frame_hop: int,
373
+ n_fft: int = 512,
374
+ return_polar: bool = False,
375
+ window: str = "sqrthann",
376
+ normalized: bool = False,
377
+ onesided: bool = True,
378
+ center: bool = False,
379
+ eps: float = EPSILON,
380
+ ) -> th.Tensor:
381
+ """
382
+ Wrapper of PyTorch STFT function
383
+ Args:
384
+ wav (Tensor): source audio signal
385
+ frame_len: length of the frame
386
+ frame_hop: hop size between frames
387
+ n_fft: number of the FFT size
388
+ return_polar: return the results in polar coordinate
389
+ window: window tensor
390
+ center: same definition with the parameter in librosa.stft
391
+ normalized: use normalized DFT kernel
392
+ onesided: output onesided STFT
393
+ Return:
394
+ transform (Tensor), STFT transform results
395
+ """
396
+ if TORCH_VERSION < LooseVersion("1.7"):
397
+ raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
398
+ wav_dim = wav.dim()
399
+ if wav_dim not in [2, 3]:
400
+ raise RuntimeError(f"STFT expects a 2D/3D tensor, but got {wav_dim:d}D")
401
+ # if N x C x S, reshape NC x S
402
+ wav = wav.view(-1, wav.shape[-1])
403
+ # STFT: N x F x T x 2
404
+ stft = th.stft(
405
+ wav,
406
+ n_fft,
407
+ hop_length=frame_hop,
408
+ win_length=window.shape[-1],
409
+ window=window,
410
+ center=center,
411
+ normalized=normalized,
412
+ onesided=onesided,
413
+ return_complex=False,
414
+ )
415
+ if wav_dim == 3:
416
+ N, F, T, _ = stft.shape
417
+ stft = stft.view(N, -1, F, T, 2)
418
+ # N x (C) x F x T x 2
419
+ if not return_polar:
420
+ return stft
421
+ # N x (C) x F x T
422
+ real, imag = stft[..., 0], stft[..., 1]
423
+ mag = (real ** 2 + imag ** 2 + eps) ** 0.5
424
+ pha = th.atan2(imag, real)
425
+ return th.stack([mag, pha], dim=-1)
426
+
427
+
428
+ def _pytorch_istft(
429
+ transform: th.Tensor,
430
+ frame_len: int,
431
+ frame_hop: int,
432
+ window: th.Tensor,
433
+ n_fft: int = 512,
434
+ return_polar: bool = False,
435
+ normalized: bool = False,
436
+ onesided: bool = True,
437
+ center: bool = False,
438
+ eps: float = EPSILON,
439
+ ) -> th.Tensor:
440
+ """
441
+ Wrapper of PyTorch iSTFT function
442
+ Args:
443
+ transform (Tensor): results of STFT
444
+ frame_len: length of the frame
445
+ frame_hop: hop size between frames
446
+ window: window tensor
447
+ n_fft: number of the FFT size
448
+ return_polar: keep same with _pytorch_stft
449
+ center: same definition with the parameter in librosa.stft
450
+ normalized: use normalized DFT kernel
451
+ onesided: output onesided STFT
452
+ Return:
453
+ wav (Tensor): synthetic audio
454
+ """
455
+ if TORCH_VERSION < LooseVersion("1.7"):
456
+ raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
457
+
458
+ transform_dim = transform.dim()
459
+ # if F x T x 2, reshape 1 x F x T x 2
460
+ if transform_dim == 3:
461
+ transform = th.unsqueeze(transform, 0)
462
+ if transform_dim != 4:
463
+ raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")
464
+
465
+ if return_polar:
466
+ real = transform[..., 0] * th.cos(transform[..., 1])
467
+ imag = transform[..., 0] * th.sin(transform[..., 1])
468
+ transform = th.stack([real, imag], -1)
469
+ # stft is a complex tensor of PyTorch
470
+ stft = th.view_as_complex(transform)
471
+ # (N) x S
472
+ wav = th.istft(
473
+ stft,
474
+ n_fft,
475
+ hop_length=frame_hop,
476
+ win_length=window.shape[-1],
477
+ window=window,
478
+ center=center,
479
+ normalized=normalized,
480
+ onesided=onesided,
481
+ return_complex=False,
482
+ )
483
+ return wav
484
+
485
+
486
+ def forward_stft(
487
+ wav: th.Tensor,
488
+ frame_len: int,
489
+ frame_hop: int,
490
+ window: str = "sqrthann",
491
+ round_pow_of_two: bool = True,
492
+ return_polar: bool = False,
493
+ pre_emphasis: float = 0,
494
+ normalized: bool = False,
495
+ onesided: bool = True,
496
+ center: bool = False,
497
+ mode: str = "librosa",
498
+ eps: float = EPSILON,
499
+ ) -> th.Tensor:
500
+ """
501
+ STFT function implementation, equivalent to the STFT layer
502
+ Args:
503
+ wav: source audio signal
504
+ frame_len: length of the frame
505
+ frame_hop: hop size between frames
506
+ return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
507
+ window: window name
508
+ center: center flag (similar with that in librosa.stft)
509
+ round_pow_of_two: if true, round the FFT size up to the next power of two
510
+ pre_emphasis: factor of preemphasis
511
+ normalized: use normalized DFT kernel
512
+ onesided: output onesided STFT
513
+ inverse: using iDFT kernel (for iSTFT)
514
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
515
+ Return:
516
+ transform: results of STFT
517
+ """
518
+ window = init_window(window, frame_len, device=wav.device)
519
+ if mode == "torch":
520
+ n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
521
+ return _pytorch_stft(
522
+ wav,
523
+ frame_len,
524
+ frame_hop,
525
+ n_fft=n_fft,
526
+ return_polar=return_polar,
527
+ window=window,
528
+ normalized=normalized,
529
+ onesided=onesided,
530
+ center=center,
531
+ eps=eps,
532
+ )
533
+ else:
534
+ kernel, window = init_kernel(
535
+ frame_len,
536
+ frame_hop,
537
+ window=window,
538
+ round_pow_of_two=round_pow_of_two,
539
+ normalized=normalized,
540
+ inverse=False,
541
+ mode=mode,
542
+ )
543
+ return _forward_stft(
544
+ wav,
545
+ kernel,
546
+ window,
547
+ return_polar=return_polar,
548
+ frame_hop=frame_hop,
549
+ pre_emphasis=pre_emphasis,
550
+ onesided=onesided,
551
+ center=center,
552
+ eps=eps,
553
+ )
554
+
555
+
556
+ def inverse_stft(
557
+ transform: th.Tensor,
558
+ frame_len: int,
559
+ frame_hop: int,
560
+ return_polar: bool = False,
561
+ window: str = "sqrthann",
562
+ round_pow_of_two: bool = True,
563
+ normalized: bool = False,
564
+ onesided: bool = True,
565
+ center: bool = False,
566
+ mode: str = "librosa",
567
+ eps: float = EPSILON,
568
+ ) -> th.Tensor:
569
+ """
570
+ iSTFT function implementation, equivalent to the iSTFT layer
571
+ Args:
572
+ transform: results of STFT
573
+ frame_len: length of the frame
574
+ frame_hop: hop size between frames
575
+ return_polar: keep same with function forward_stft(...)
576
+ window: window name
577
+ center: center flag (similar with that in librosa.stft)
578
+ round_pow_of_two: if true, round the FFT size up to the next power of two
579
+ normalized: use normalized DFT kernel
580
+ onesided: output onesided STFT
581
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
582
+ Return:
583
+ wav: synthetic signals
584
+ """
585
+ window = init_window(window, frame_len, device=transform.device)
586
+ if mode == "torch":
587
+ n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
588
+ return _pytorch_istft(
589
+ transform,
590
+ frame_len,
591
+ frame_hop,
592
+ n_fft=n_fft,
593
+ return_polar=return_polar,
594
+ window=window,
595
+ normalized=normalized,
596
+ onesided=onesided,
597
+ center=center,
598
+ eps=eps,
599
+ )
600
+ else:
601
+ kernel, window = init_kernel(
602
+ frame_len,
603
+ frame_hop,
604
+ window,
605
+ round_pow_of_two=round_pow_of_two,
606
+ normalized=normalized,
607
+ inverse=True,
608
+ mode=mode,
609
+ )
610
+ return _inverse_stft(
611
+ transform,
612
+ kernel,
613
+ window,
614
+ return_polar=return_polar,
615
+ frame_hop=frame_hop,
616
+ onesided=onesided,
617
+ center=center,
618
+ eps=eps,
619
+ )
620
+
621
+
622
+ class STFTBase(nn.Module):
623
+ """
624
+ Base layer for (i)STFT
625
+ Args:
626
+ frame_len: length of the frame
627
+ frame_hop: hop size between frames
628
+ window: window name
629
+ center: center flag (similar with that in librosa.stft)
630
+ round_pow_of_two: if true, round the FFT size up to the next power of two
631
+ normalized: use normalized DFT kernel
632
+ pre_emphasis: factor of preemphasis
633
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
634
+ onesided: output onesided STFT
635
+ inverse: using iDFT kernel (for iSTFT)
636
+ """
637
+
638
+ def __init__(
639
+ self,
640
+ frame_len: int,
641
+ frame_hop: int,
642
+ window: str = "sqrthann",
643
+ round_pow_of_two: bool = True,
644
+ normalized: bool = False,
645
+ pre_emphasis: float = 0,
646
+ onesided: bool = True,
647
+ inverse: bool = False,
648
+ center: bool = False,
649
+ mode: str = "librosa",
650
+ ) -> None:
651
+ super(STFTBase, self).__init__()
652
+ if mode != "torch":
653
+ K, w = init_kernel(
654
+ frame_len,
655
+ frame_hop,
656
+ init_window(window, frame_len),
657
+ round_pow_of_two=round_pow_of_two,
658
+ normalized=normalized,
659
+ inverse=inverse,
660
+ mode=mode,
661
+ )
662
+ self.K = nn.Parameter(K, requires_grad=False)
663
+ self.w = nn.Parameter(w, requires_grad=False)
664
+ self.num_bins = self.K.shape[0] // 4 + 1
665
+ self.pre_emphasis = pre_emphasis
666
+ self.win_length = self.K.shape[2]
667
+ else:
668
+ self.K = None
669
+ w = init_window(window, frame_len)
670
+ self.w = nn.Parameter(w, requires_grad=False)
671
+ fft_size = (
672
+ 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
673
+ )
674
+ self.num_bins = fft_size // 2 + 1
675
+ self.pre_emphasis = 0
676
+ self.win_length = fft_size
677
+ self.frame_len = frame_len
678
+ self.frame_hop = frame_hop
679
+ self.window = window
680
+ self.normalized = normalized
681
+ self.onesided = onesided
682
+ self.center = center
683
+ self.mode = mode
684
+
685
+ def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
686
+ """
687
+ Compute number of the frames
688
+ """
689
+ assert th.sum(wav_len <= self.win_length) == 0
690
+ if self.center:
691
+ wav_len += self.win_length
692
+ return (
693
+ th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
694
+ )
695
+
696
+ def extra_repr(self) -> str:
697
+ str_repr = (
698
+ f"num_bins={self.num_bins}, win_length={self.win_length}, "
699
+ + f"stride={self.frame_hop}, window={self.window}, "
700
+ + f"center={self.center}, mode={self.mode}"
701
+ )
702
+ if not self.onesided:
703
+ str_repr += f", onesided={self.onesided}"
704
+ if self.pre_emphasis > 0:
705
+ str_repr += f", pre_emphasis={self.pre_emphasis}"
706
+ if self.normalized:
707
+ str_repr += f", normalized={self.normalized}"
708
+ return str_repr
709
+
710
+
711
+ class STFT(STFTBase):
712
+ """
713
+ Short-time Fourier Transform as a Layer
714
+ """
715
+
716
+ def __init__(self, *args, **kwargs):
717
+ super(STFT, self).__init__(*args, inverse=False, **kwargs)
718
+
719
+ def forward(
720
+ self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
721
+ ) -> th.Tensor:
722
+ """
723
+ Accept a (single- or multi-channel) raw waveform and output its STFT (real/imag, or magnitude/phase if return_polar=True)
724
+ Args
725
+ wav (Tensor) input signal, N x (C) x S
726
+ Return
727
+ transform (Tensor), N x (C) x F x T x 2
728
+ """
729
+ if self.mode == "torch":
730
+ return _pytorch_stft(
731
+ wav,
732
+ self.frame_len,
733
+ self.frame_hop,
734
+ n_fft=(self.num_bins - 1) * 2,
735
+ return_polar=return_polar,
736
+ window=self.w,
737
+ normalized=self.normalized,
738
+ onesided=self.onesided,
739
+ center=self.center,
740
+ eps=eps,
741
+ )
742
+ else:
743
+ return _forward_stft(
744
+ wav,
745
+ self.K,
746
+ self.w,
747
+ return_polar=return_polar,
748
+ frame_hop=self.frame_hop,
749
+ pre_emphasis=self.pre_emphasis,
750
+ onesided=self.onesided,
751
+ center=self.center,
752
+ eps=eps,
753
+ )
754
+
755
+
756
+ class iSTFT(STFTBase):
757
+ """
758
+ Inverse Short-time Fourier Transform as a Layer
759
+ """
760
+
761
+ def __init__(self, *args, **kwargs):
762
+ super(iSTFT, self).__init__(*args, inverse=True, **kwargs)
763
+
764
+ def forward(
765
+ self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
766
+ ) -> th.Tensor:
767
+ """
768
+ Accept an STFT representation (real/imag, or magnitude/phase if return_polar=True) and output the raw waveform
769
+ Args
770
+ transform (Tensor): STFT output, N x F x T x 2
771
+ Return
772
+ s (Tensor): N x S
773
+ """
774
+ if self.mode == "torch":
775
+ return _pytorch_istft(
776
+ transform,
777
+ self.frame_len,
778
+ self.frame_hop,
779
+ n_fft=(self.num_bins - 1) * 2,
780
+ return_polar=return_polar,
781
+ window=self.w,
782
+ normalized=self.normalized,
783
+ onesided=self.onesided,
784
+ center=self.center,
785
+ eps=eps,
786
+ )
787
+ else:
788
+ return _inverse_stft(
789
+ transform,
790
+ self.K,
791
+ self.w,
792
+ return_polar=return_polar,
793
+ frame_hop=self.frame_hop,
794
+ onesided=self.onesided,
795
+ center=self.center,
796
+ eps=eps,
797
+ )
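A brief, hypothetical round-trip sketch for the `STFT`/`iSTFT` layers defined above; the frame length, hop, and window below are illustrative values, not settings used elsewhere in this commit:

```python
# Hypothetical usage sketch (assumes STFT and iSTFT from this file are importable).
import torch as th

stft = STFT(512, 128, window="hann", center=True)    # analysis layer, librosa-style kernels
istft = iSTFT(512, 128, window="hann", center=True)  # synthesis layer with matching settings

wav = th.randn(4, 16000)                  # N x S: a batch of 1 s signals at 16 kHz
spec = stft(wav, return_polar=False)      # N x F x T x 2 (real/imag), F = 512 // 2 + 1 = 257
recon = istft(spec, return_polar=False)   # N x S: overlap-add with window normalization

print(spec.shape, recon.shape)
```

With a Hann window and a hop of a quarter frame, the squared-window overlap-add normalization in `_inverse_stft` brings the reconstruction back to the input up to numerical error; passing `mode="torch"` to both layers routes the same calls through `th.stft`/`th.istft` instead.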
look2hear/utils/torch_utils.py ADDED
@@ -0,0 +1,49 @@
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-18 17:29:21
4
+ # LastEditors: Kai Li
5
+ # LastEditTime: 2021-06-21 23:52:52
6
+ ###
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ def pad_x_to_y(x, y, axis: int = -1):
13
+ if axis != -1:
14
+ raise NotImplementedError
15
+ inp_len = y.shape[axis]
16
+ output_len = x.shape[axis]
17
+ return nn.functional.pad(x, [0, inp_len - output_len])
18
+
19
+
20
+ def shape_reconstructed(reconstructed, size):
21
+ if len(size) == 1:
22
+ return reconstructed.squeeze(0)
23
+ return reconstructed
24
+
25
+
26
+ def tensors_to_device(tensors, device):
27
+ """Transfer tensor, dict or list of tensors to device.
28
+
29
+ Args:
30
+ tensors (:class:`torch.Tensor`): May be a single, a list or a
31
+ dictionary of tensors.
32
+ device (:class: `torch.device`): the device where to place the tensors.
33
+
34
+ Returns:
35
+ Union [:class:`torch.Tensor`, list, tuple, dict]:
36
+ Same as input but transferred to device.
37
+ Goes through lists and dicts and transfers the torch.Tensor to
38
+ device. Leaves the rest untouched.
39
+ """
40
+ if isinstance(tensors, torch.Tensor):
41
+ return tensors.to(device)
42
+ elif isinstance(tensors, (list, tuple)):
43
+ return [tensors_to_device(tens, device) for tens in tensors]
44
+ elif isinstance(tensors, dict):
45
+ for key in tensors.keys():
46
+ tensors[key] = tensors_to_device(tensors[key], device)
47
+ return tensors
48
+ else:
49
+ return tensors
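A short, hypothetical usage sketch for these helpers; the batch layout and shapes are made-up examples:

```python
# Hypothetical usage sketch (assumes the helpers above are importable).
import torch

# tensors_to_device walks lists/dicts, moves tensors, and leaves other values untouched.
batch = {"mix": torch.randn(2, 16000), "path": "song.wav"}
batch = tensors_to_device(batch, torch.device("cpu"))

# pad_x_to_y zero-pads the last axis of a slightly-short output to the reference length.
est = torch.randn(2, 15872)
ref = torch.randn(2, 16000)
est = pad_x_to_y(est, ref)
print(est.shape)  # torch.Size([2, 16000])
```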
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torchaudio==2.2.0
2
+ torch==2.2.0
3
+ huggingface
4
+ huggingface_hub
5
+ numpy<2.0
6
+ omegaconf
7
+ ml_collections
8
+ librosa
9
+ gradio
10
+ tqdm
11
+ spaces
weights/apollo.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99d9af7f1ff20e63c393035513a655392818d66b4d7fc23d658175c1f15e8d76
3
+ size 66541845