Update audiosr/pipeline.py
audiosr/pipeline.py  (+176, -175)
@@ -1,175 +1,176 @@
- [previous version of audiosr/pipeline.py (175 lines); the removed lines survive only as truncated fragments in the extracted diff and are not reproduced here]
+import os
+import re
+
+import yaml
+import torch
+import torchaudio
+import numpy as np
+import spaces
+
+import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
+from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
+from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
+from audiosr.utils import (
+    default_audioldm_config,
+    download_checkpoint,
+    read_audio_file,
+    lowpass_filtering_prepare_inference,
+    wav_feature_extraction,
+)
+
+
+def seed_everything(seed):
+    import random, os
+    import numpy as np
+    import torch
+
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    # cuDNN benchmark mode auto-tunes kernels for speed; note that it can
+    # reintroduce run-to-run variation despite the deterministic flag above.
+    torch.backends.cudnn.benchmark = True
+
+
+def text2phoneme(data):
+    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])
+
+
+def text_to_filename(text):
+    return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+
+def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
+    norm_mean = -4.2677393
+    norm_std = 4.5689974
+
+    if sampling_rate != 16000:
+        waveform_16k = torchaudio.functional.resample(
+            waveform, orig_freq=sampling_rate, new_freq=16000
+        )
+    else:
+        waveform_16k = waveform
+
+    waveform_16k = waveform_16k - waveform_16k.mean()
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform_16k,
+        htk_compat=True,
+        sample_frequency=16000,
+        use_energy=False,
+        window_type="hanning",
+        num_mel_bins=128,
+        dither=0.0,
+        frame_shift=10,
+    )
+
+    TARGET_LEN = log_mel_spec.size(0)
+
+    # cut and pad
+    n_frames = fbank.shape[0]
+    p = TARGET_LEN - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[:TARGET_LEN, :]
+
+    fbank = (fbank - norm_mean) / (norm_std * 2)
+
+    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
+
+
+def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
+    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
+
+    batch = {
+        "waveform": torch.FloatTensor(waveform),
+        "stft": torch.FloatTensor(stft),
+        "log_mel_spec": torch.FloatTensor(log_mel_spec),
+        "sampling_rate": 48000,
+    }
+
+    # print(batch["waveform"].size(), batch["stft"].size(), batch["log_mel_spec"].size())
+
+    batch.update(lowpass_filtering_prepare_inference(batch))
+
+    assert "waveform_lowpass" in batch.keys()
+    lowpass_mel, lowpass_stft = wav_feature_extraction(
+        batch["waveform_lowpass"], target_frame
+    )
+    batch["lowpass_mel"] = lowpass_mel
+
+    # Add a batch dimension to every tensor entry.
+    for k in batch.keys():
+        if isinstance(batch[k], torch.Tensor):
+            batch[k] = torch.FloatTensor(batch[k]).unsqueeze(0)
+
+    return batch, duration
+
+
+def round_up_duration(duration):
+    return int(round(duration / 2.5) + 1) * 2.5
+
+
+@spaces.GPU
+def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
+    if device is None or device == "auto":
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+    print("Loading AudioSR: %s" % model_name)
+    print("Loading model on %s" % device)
+
+    # Only download the checkpoint when no local path was supplied.
+    if ckpt_path is None:
+        ckpt_path = download_checkpoint(model_name)
+
+    if config is not None:
+        assert isinstance(config, str)
+        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+    else:
+        config = default_audioldm_config(model_name)
+
+    # # Use text as condition instead of using waveform during training
+    config["model"]["params"]["device"] = device
+    # config["model"]["params"]["cond_stage_key"] = "text"
+
+    # No normalization here
+    latent_diffusion = LatentDiffusion(**config["model"]["params"])
+
+    resume_from_checkpoint = ckpt_path
+
+    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
+
+    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
+
+    latent_diffusion.eval()
+    latent_diffusion = latent_diffusion.to(device)
+
+    return latent_diffusion
+
+
+def super_resolution(
+    latent_diffusion,
+    input_file,
+    seed=42,
+    ddim_steps=200,
+    guidance_scale=3.5,
+    latent_t_per_second=12.8,
+    config=None,
+):
+    seed_everything(int(seed))
+    waveform = None
+
+    batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)
+
+    with torch.no_grad():
+        waveform = latent_diffusion.generate_batch(
+            batch,
+            unconditional_guidance_scale=guidance_scale,
+            ddim_steps=ddim_steps,
+            duration=duration,
+        )
+
+    return waveform
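
For reference, a minimal usage sketch of the two entry points in the updated file. This is an illustration, not part of the commit: it assumes the package is importable as `audiosr`, that `super_resolution` returns a NumPy-convertible waveform at the 48 kHz rate hard-coded in `make_batch_for_super_resolution`, that the `soundfile` library is installed, and that the input/output paths are hypothetical.

    import numpy as np
    import soundfile as sf

    from audiosr.pipeline import build_model, super_resolution

    # Build the latent diffusion model; the "basic" checkpoint is
    # downloaded automatically when no local path is supplied.
    model = build_model(model_name="basic", device="auto")

    # Upsample a (hypothetical) input file.
    waveform = super_resolution(
        model,
        "input.wav",
        seed=42,
        ddim_steps=200,
        guidance_scale=3.5,
    )

    # The batch is prepared at 48 kHz, so write the result at that rate,
    # squeezing any leading batch/channel dimensions added during inference.
    sf.write("output_48k.wav", np.squeeze(waveform), samplerate=48000)

Passing `device="auto"` reproduces the fallback order in `build_model` (CUDA, then MPS, then CPU), which is convenient when the same script runs both on a GPU Space and locally.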