# asr / audiosr/pipeline.py
# Last commit by kgout: "Update audiosr/pipeline.py" (98f7b12, verified)
import os
import re
import spaces
import yaml
import torch
import torchaudio
import numpy as np
import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
from audiosr.utils import (
default_audioldm_config,
download_checkpoint,
read_audio_file,
lowpass_filtering_prepare_inference,
wav_feature_extraction,
)
import os
def seed_everything(seed):
    """Seed every random number generator used by the pipeline.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    reproducible execution.

    Args:
        seed: Seed value applied to all generators.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only
    # influences processes spawned after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
def text2phoneme(data):
    """Strip ``<...>`` spans from ``data`` and run it through the text cleaner.

    Args:
        data: Input string; every non-greedy ``<...>`` match is removed.

    Returns:
        Whatever ``text._clean_text`` returns for the stripped string when
        invoked with the ``english_cleaners2`` cleaner.
    """
    without_tags = re.sub(r"<.*?>", "", data)
    cleaner_names = ["english_cleaners2"]
    return text._clean_text(without_tags, cleaner_names)
def text_to_filename(text):
    """Return ``text`` with spaces and quote characters replaced by ``_``.

    Args:
        text: String to convert.

    Returns:
        A copy of ``text`` in which every space, single quote and double
        quote has been replaced by an underscore.
    """
    # One translation pass instead of three chained replace() calls.
    substitutions = str.maketrans({" ": "_", "'": "_", '"': "_"})
    return text.translate(substitutions)
def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    """Compute a normalized Kaldi filterbank whose length matches ``log_mel_spec``.

    The waveform is resampled to 16 kHz when needed, mean-centred, converted
    to a 128-bin filterbank, zero-padded or truncated along the first axis to
    ``log_mel_spec.size(0)`` frames, and finally normalized with fixed
    mean/std constants.

    Args:
        waveform: Audio tensor passed to torchaudio's resample / fbank.
        sampling_rate: Sampling rate of ``waveform`` in Hz.
        log_mel_spec: Tensor whose first dimension gives the target frame count.

    Returns:
        A dict with the single key ``"ta_kaldi_fbank"`` holding the feature.
    """
    target_rate = 16000
    mean, std = -4.2677393, 4.5689974

    audio = waveform
    if sampling_rate != target_rate:
        audio = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=target_rate
        )
    audio = audio - audio.mean()

    features = torchaudio.compliance.kaldi.fbank(
        audio,
        htk_compat=True,
        sample_frequency=target_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Align the frame count with the reference spectrogram: pad zeros at the
    # end of the first axis when too short, truncate when too long.
    target_len = log_mel_spec.size(0)
    missing = target_len - features.shape[0]
    if missing > 0:
        features = torch.nn.functional.pad(
            features, (0, 0, 0, missing), "constant", 0.0
        )
    elif missing < 0:
        features = features[:target_len, :]

    features = (features - mean) / (std * 2)
    return {"ta_kaldi_fbank": features}
def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
    """Build the model input batch for one audio file.

    Args:
        input_file: Path handed to ``read_audio_file``.
        waveform: Not used; the name is rebound to the waveform read from
            ``input_file``.
        fbank: Not used.

    Returns:
        A ``(batch, duration)`` tuple. ``batch`` is a dict holding the
        waveform, STFT, log-mel spectrogram, sampling rate, whatever
        ``lowpass_filtering_prepare_inference`` adds, and ``"lowpass_mel"``;
        every plain ``torch.Tensor`` entry has a leading dimension of size 1
        added. ``duration`` is the value reported by ``read_audio_file``.
    """
    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)

    batch = dict(
        waveform=torch.FloatTensor(waveform),
        stft=torch.FloatTensor(stft),
        log_mel_spec=torch.FloatTensor(log_mel_spec),
        sampling_rate=48000,
    )

    batch.update(lowpass_filtering_prepare_inference(batch))
    assert "waveform_lowpass" in batch.keys()

    lowpass_mel, _lowpass_stft = wav_feature_extraction(
        batch["waveform_lowpass"], target_frame
    )
    batch["lowpass_mel"] = lowpass_mel

    # Only values are replaced here, so the dict size is stable while iterating.
    for key, value in batch.items():
        if type(value) is torch.Tensor:
            batch[key] = torch.FloatTensor(value).unsqueeze(0)

    return batch, duration
def round_up_duration(duration):
    """Map ``duration`` onto a multiple of 2.5.

    The duration is divided by 2.5, rounded to the nearest integer with the
    built-in ``round``, incremented by one, and scaled back by 2.5.

    Args:
        duration: Duration value to convert.

    Returns:
        ``(round(duration / 2.5) + 1) * 2.5`` as a float.
    """
    segment = 2.5
    segment_count = round(duration / segment) + 1
    return int(segment_count) * segment
@spaces.GPU
def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
    """Instantiate ``LatentDiffusion`` and load its checkpoint weights.

    Args:
        ckpt_path: Path to a checkpoint file. When ``None`` the checkpoint is
            obtained via ``download_checkpoint(model_name)``.
        config: Optional path to a YAML config file. When ``None`` the config
            comes from ``default_audioldm_config(model_name)``.
        device: Target device. ``None`` or ``"auto"`` selects ``cuda:0`` if
            available, then ``mps``, then ``cpu``.
        model_name: Model identifier passed to the checkpoint/config helpers.

    Returns:
        The ``LatentDiffusion`` model in eval mode, moved to ``device``.
    """
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # Honour an explicitly supplied checkpoint; previously the argument was
    # unconditionally overwritten by the downloaded checkpoint path.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        assert type(config) is str
        # Context manager so the config file handle is always closed.
        with open(config, "r") as config_file:
            # NOTE(security): FullLoader can construct arbitrary Python
            # objects for some tags; only load config files you trust.
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    config["model"]["params"]["device"] = device

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    # strict=False: keys that do not match between checkpoint and model are
    # tolerated rather than raising.
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)
    return latent_diffusion
def super_resolution(
    latent_diffusion,
    input_file,
    seed=42,
    ddim_steps=200,
    guidance_scale=3.5,
    latent_t_per_second=12.8,
    config=None,
):
    """Run the model's ``generate_batch`` on a single audio file.

    Args:
        latent_diffusion: Model object exposing ``generate_batch``.
        input_file: Path of the audio file to process.
        seed: Value converted with ``int`` and passed to ``seed_everything``.
        ddim_steps: Forwarded to ``generate_batch`` as ``ddim_steps``.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.
        latent_t_per_second: Not used by this function.
        config: Not used by this function.

    Returns:
        The value returned by ``latent_diffusion.generate_batch``.
    """
    seed_everything(int(seed))

    batch, duration = make_batch_for_super_resolution(input_file, waveform=None)

    # Inference only: gradient tracking is disabled for the generation call.
    with torch.no_grad():
        return latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            duration=duration,
        )