Spaces:
Build error
Build error
import whisper | |
import os | |
import json | |
import torchaudio | |
import argparse | |
import torch | |
# Default table mapping a Whisper language code to its tag string.
# The keys decide which detected languages are accepted for annotation.
# The __main__ section may rebind this name according to --languages.
lang2token = dict(
    zh="[ZH]",
    ja="[JA]",
    en="[EN]",
)
def transcribe_one(audio_path, whisper_model=None, beam_size=5):
    """Detect the spoken language of an audio file and transcribe it with Whisper.

    Only Whisper's fixed window is used: the audio is padded/trimmed to fit
    30 seconds before the log-Mel spectrogram is computed.

    Args:
        audio_path: path of the audio file to transcribe.
        whisper_model: loaded Whisper model to use. Defaults to the
            module-level ``model`` global (assigned in the ``__main__``
            section), so the original single-argument call keeps working.
        beam_size: beam width passed to ``whisper.DecodingOptions``.

    Returns:
        ``(lang, text)``: the most probable language code from
        ``detect_language`` and the decoded text.
    """
    if whisper_model is None:
        # Fall back to the script-level global, as the original did implicitly.
        whisper_model = model
    # Load the audio and pad/trim it to fit 30 seconds.
    audio = whisper.pad_or_trim(whisper.load_audio(audio_path))
    # Make the log-Mel spectrogram and move it to the model's device.
    mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    # Compute the argmax once; it is used for both the log line and the result.
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")
    options = whisper.DecodingOptions(beam_size=beam_size)
    result = whisper.decode(whisper_model, mel, options)
    print(result.text)
    return lang, result.text
def _select_lang2token(languages):
    """Return the language -> tag table for a ``--languages`` preset.

    ``"CJE"`` selects zh/ja/en, ``"CJ"`` selects zh/ja and ``"C"`` selects
    zh only. Any other value falls back to the module-level ``lang2token``,
    matching the original if/elif chain, which left the global untouched.
    """
    presets = {
        "CJE": {'zh': "[ZH]", 'ja': "[JA]", "en": "[EN]"},
        "CJ": {'zh': "[ZH]", 'ja': "[JA]"},
        "C": {'zh': "[ZH]"},
    }
    return presets.get(languages, lang2token)


def _load_mono_resampled(path, target_sr):
    """Load ``path`` with torchaudio, downmix to mono and resample to ``target_sr``.

    Returns a ``(1, num_samples)`` tensor sampled at ``target_sr``.
    """
    wav, sr = torchaudio.load(path, frame_offset=0, num_frames=-1, normalize=True,
                              channels_first=True)
    # Average over the channel axis but keep it, so torchaudio.save gets (1, T).
    wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wav)
    return wav


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--languages", default="CJE")
    parser.add_argument("--whisper_size", default="medium")
    parser.add_argument("--max_duration", type=float, default=20.0,
                        help="skip clips longer than this many seconds")
    args = parser.parse_args()
    lang2token = _select_lang2token(args.languages)
    # `assert` is stripped under `python -O`; fail explicitly instead.
    if not torch.cuda.is_available():
        raise SystemExit("Please enable GPU in order to run Whisper!")
    # Kept at module scope: transcribe_one() reads the global `model`.
    model = whisper.load_model(args.whisper_size)
    parent_dir = "./custom_character_voice/"
    # One sub-directory per speaker; only the first level of the tree is needed.
    # A missing parent_dir yields no speakers (and the warning below).
    speaker_names = next(os.walk(parent_dir), (parent_dir, [], []))[1]
    speaker_annos = []
    # Progress denominator: every file except outputs of a previous run.
    total_files = sum(
        1
        for _, _, files in os.walk(parent_dir)
        for name in files
        if not name.startswith("processed_")
    )
    # 2023/4/21: get the target sampling rate from the training config.
    with open("./configs/config.json", 'r', encoding='utf-8') as f:
        hps = json.load(f)
    target_sr = hps['data']['sampling_rate']
    processed_files = 0
    for speaker in speaker_names:
        # String concatenation (not os.path.join) keeps "/" separators in the
        # annotation file on every platform, exactly as before.
        speaker_dir = parent_dir + speaker + "/"
        for i, wavfile in enumerate(next(os.walk(speaker_dir))[2]):
            if wavfile.startswith("processed_"):
                continue  # output of an earlier run
            try:
                wav = _load_mono_resampled(speaker_dir + wavfile, target_sr)
                # wav is already at target_sr, so its duration uses target_sr.
                if wav.shape[1] / target_sr > args.max_duration:
                    print(f"{wavfile} too long, ignoring\n")
                    continue
                save_path = speaker_dir + f"processed_{i}.wav"
                torchaudio.save(save_path, wav, target_sr, channels_first=True)
                lang, text = transcribe_one(save_path)
                if lang not in lang2token:
                    print(f"{lang} not supported, ignoring\n")
                    continue
                # NOTE(review): every accepted clip is tagged "ZH" regardless of
                # the detected `lang`. The upstream per-language variant was
                # lang2token[lang] + text + lang2token[lang] + "\n".
                # Confirm the fixed tag is intended for the target model.
                text = "ZH|" + text + "\n"
                speaker_annos.append(save_path + "|" + speaker + "|" + text)
                processed_files += 1
                print(f"Processed: {processed_files}/{total_files}")
            except Exception as e:
                # Best-effort: skip unreadable/non-audio files, but report why.
                # KeyboardInterrupt/SystemExit still propagate.
                print(f"{wavfile} failed ({e!r}), ignoring\n")
    # NOTE: an optional cleaning pass (text._clean_text(txt, ["cjke_cleaners2"]))
    # and generation of configs/modified_finetune_speaker.json were left here as
    # commented-out code upstream; both steps are disabled in this script.
    if not speaker_annos:
        print("Warning: no short audios found, this IS expected if you have only uploaded long audios, videos or video links.")
        print("this IS NOT expected if you have uploaded a zip file of short audios. Please check your file structure or make sure your audio language is supported.")
    # Write the annotation list; each entry already ends with "\n".
    with open("./filelists/short_character_anno.list", 'w', encoding='utf-8') as f:
        f.writelines(speaker_annos)