|
import os |
|
import sys |
|
import torch |
|
import numpy as np |
|
import traceback |
|
from scipy.io import wavfile |
|
import librosa |
|
from pathlib import Path |
|
from time import time as ttime |
|
import shutil |
|
from tools.my_utils import load_audio, clean_path |
|
from feature_extractor import cnhubert |
|
|
|
def my_save(fea, path, i_part):
    """Save ``fea`` to ``path`` by way of a temporary file.

    Works around ``torch.save`` not supporting Chinese paths: the object is
    first written under a temporary name built only from a timestamp and
    ``i_part`` (relative to the current working directory) and then moved to
    its final location.

    Args:
        fea: Object to serialize with ``torch.save``.
        path: Final location of the saved file.
        i_part: Suffix appended to the temporary file name, so that callers
            using different values do not collide on the same temporary file.

    Raises:
        Any exception raised by ``torch.save`` or ``shutil.move``. The
        temporary file is removed before the exception propagates.
    """
    tmp_path = f"{ttime()}{i_part}.pth"
    try:
        torch.save(fea, tmp_path)
        # Move to ``path`` itself. Rebuilding it as f"{dirname}/{basename}"
        # turned a bare file name (empty dirname) into "/name", i.e. a file
        # in the filesystem root instead of the working directory.
        shutil.move(tmp_path, path)
    finally:
        # After a successful move the temporary file no longer exists; this
        # only cleans up when saving or moving failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
def extract_hubert_features(data_dir="data8", exp_dir="logs/s2"):
    """Extract Hubert features for stage 2 training.

    Reads ``<exp_dir>/2-name2text.txt`` (one tab-separated record of exactly
    four fields per line, the first being the wav name), loads each wav from
    ``<exp_dir>/5-wav32k``, runs it through the cnhubert model and stores the
    resulting hidden states in ``<exp_dir>/4-cnhubert/<wav_name>.pt``. The
    loudness-adjusted 32 kHz audio is written to ``<exp_dir>/5-wav32k``.

    When running in half precision, files whose features contain NaN are
    retried once in float32.

    Args:
        data_dir: Accepted for backward compatibility; no path used by this
            function is derived from it.
        exp_dir: Experiment directory, relative to the project root (the
            parent of the directory that contains this file).

    Raises:
        OSError: If the name-to-text file cannot be opened. Errors on
            individual records are printed with a traceback and skipped.
    """
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    exp_dir = os.path.join(root_dir, exp_dir)

    inp_text = os.path.join(exp_dir, "2-name2text.txt")
    inp_wav_dir = os.path.join(exp_dir, "5-wav32k")
    # Partition index / partition count: only every ``all_parts``-th record
    # starting at ``i_part`` is processed. Fixed to "everything" here.
    i_part = "0"
    all_parts = "1"
    opt_dir = exp_dir
    cnhubert.cnhubert_base_path = os.path.join(root_dir, "pretrained_models", "chinese-hubert-base")
    # Half precision is used exactly when CUDA is available.
    is_half = torch.cuda.is_available()

    print("Starting Hubert feature extraction...")
    print(f"Input text file: {inp_text}")
    print(f"Input wav directory: {inp_wav_dir}")
    print(f"Output directory: {opt_dir}")

    hubert_dir = f"{opt_dir}/4-cnhubert"
    # NOTE(review): this is the same directory as ``inp_wav_dir``, so the
    # adjusted wav written below replaces the input file it was read from.
    wav32dir = f"{opt_dir}/5-wav32k"
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(hubert_dir, exist_ok=True)
    os.makedirs(wav32dir, exist_ok=True)

    # Blend parameters: ``alpha`` weights the peak-normalised signal (scaled
    # to ``maxx``) against the unmodified signal.
    maxx = 0.95
    alpha = 0.5
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    print(f"Loading Hubert model from: {cnhubert.cnhubert_base_path}")
    model = cnhubert.get_model()
    if is_half:
        model = model.half().to(device)
    else:
        model = model.to(device)

    # (wav_name, wav_path) pairs whose features contained NaN.
    nan_fails = []

    def name2go(wav_name, wav_path):
        """Extract and save the features of one wav file.

        Reads ``is_half`` and ``model`` from the enclosing scope at call
        time, so the float32 retry below affects later calls. Appends to
        ``nan_fails`` when the extracted features contain NaN.
        """
        print(f"Processing: {wav_name} from {wav_path}")
        hubert_path = f"{hubert_dir}/{wav_name}.pt"
        if os.path.exists(hubert_path):
            print(f"Skipping {wav_name} - already processed")
            return

        if not os.path.exists(wav_path):
            print(f"Error: WAV file not found: {wav_path}")
            return

        # Presumably returns the audio as a float array at 32 kHz; the
        # resample call below treats it as such. Verify against load_audio.
        tmp_audio = load_audio(wav_path, 32000)
        if tmp_audio is None:
            print(f"Error: Failed to load audio: {wav_path}")
            return

        tmp_max = np.abs(tmp_audio).max()
        if tmp_max > 2.2:
            print(f"{wav_name}-filtered,{tmp_max}")
            return
        if tmp_max == 0:
            # Peak normalisation would divide by zero and feed an all-NaN
            # signal to the model; skip silent files explicitly instead.
            print(f"Error: Silent audio, cannot normalize: {wav_path}")
            return

        # Same blend at two amplitudes: int16 range for the wav that is
        # written out, and a much smaller scale for the model input.
        tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
        tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
        tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)

        tensor_wav16 = torch.from_numpy(tmp_audio)
        if is_half:
            tensor_wav16 = tensor_wav16.half().to(device)
        else:
            tensor_wav16 = tensor_wav16.to(device)

        ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()

        if np.isnan(ssl.detach().numpy()).sum() != 0:
            nan_fails.append((wav_name, wav_path))
            print(f"nan filtered:{wav_name}")
            return

        wavfile.write(
            f"{wav32dir}/{wav_name}",
            32000,
            tmp_audio32.astype("int16"),
        )
        my_save(ssl, hubert_path, i_part)
        print(f"Successfully processed {wav_name}")

    print(f"Reading text file: {inp_text}")
    with open(inp_text, "r", encoding="utf8") as f:
        lines = f.read().strip("\n").split("\n")
    print(f"Found {len(lines)} lines in text file")

    for line in lines[int(i_part)::int(all_parts)]:
        try:
            print(f"Processing line: {line}")
            # Exactly four tab-separated fields are required; only the wav
            # name is used here.
            wav_name, _, _, _ = line.split("\t")
            # ``inp_wav_dir`` is always set, so the wav is always looked up
            # there by its base name.
            wav_name = os.path.basename(clean_path(wav_name))
            wav_path = f"{inp_wav_dir}/{wav_name}"
            name2go(wav_name, wav_path)
        except Exception:
            # Best effort: report the record and continue with the next one.
            print(f"Error processing line: {line}")
            print(traceback.format_exc())

    if nan_fails and is_half:
        print("Retrying failed files in float32 mode...")
        is_half = False
        model = model.float()
        # Snapshot and reset the list first: name2go() appends to
        # ``nan_fails`` again when a file still yields NaN, and iterating the
        # list being appended to would retry such a file forever.
        retry_items = list(nan_fails)
        nan_fails.clear()
        for retry_name, retry_path in retry_items:
            try:
                name2go(retry_name, retry_path)
            except Exception:
                print(f"Error retrying {retry_name}")
                print(traceback.format_exc())

    print("Hubert feature extraction complete.")
|
|
|
# Script entry point: run the extraction with its default directories when
# this file is executed directly rather than imported.
if __name__ == "__main__":
    extract_hubert_features()