Spaces:
Runtime error
Runtime error
import argparse | |
import json | |
from collections import defaultdict | |
from pathlib import Path | |
from random import sample, shuffle | |
from typing import Optional | |
from tqdm import tqdm | |
from config import get_config | |
from style_bert_vits2.logging import logger | |
from style_bert_vits2.nlp import clean_text | |
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker | |
from style_bert_vits2.nlp.japanese.user_dict import update_dict | |
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT | |
# Start the pyopenjtalk worker from this very process, up front, because this
# process wants to use the user dictionary through that worker.
pyopenjtalk_worker.initialize_worker()

# Register the dictionary data stored under dict_data/ with pyopenjtalk.
# Runs only after the worker has been initialized on the line above.
update_dict()

# Preprocessing section of the project config; it supplies the argparse
# defaults used by the __main__ section of this script.
preprocess_text_config = get_config().preprocess_text_config
# Count lines for tqdm | |
def count_lines(file_path: Path): | |
with file_path.open("r", encoding="utf-8") as file: | |
return sum(1 for _ in file) | |
def write_error_log(error_log_path: Path, line: str, error: Exception):
    """Append a failed input line and its exception to ``error_log_path``.

    Each entry is the stripped line, the exception text, then a blank line.
    """
    entry = f"{line.strip()}\n{error}\n\n"
    with error_log_path.open("a", encoding="utf-8") as log_file:
        log_file.write(entry)
def process_line(
    line: str,
    transcription_path: Path,
    correct_path: bool,
    use_jp_extra: bool,
    yomi_error: str,
):
    """Turn one ``utt|spk|language|text`` transcription line into a cleaned line.

    Args:
        line: Raw line read from the transcription file.
        transcription_path: Path of the transcription file; its parent directory
            is used to build the audio path when ``correct_path`` is set.
        correct_path: If True, ``utt`` is replaced by
            ``<transcription dir>/wavs/<utt>``.
        use_jp_extra: Forwarded unchanged to ``clean_text``.
        yomi_error: Reading-error policy. Only ``"use"`` asks ``clean_text``
            not to raise on reading errors.

    Returns:
        ``utt|spk|language|norm_text|phones|tones|word2ph`` plus a newline,
        with phones, tones and word2ph each joined by single spaces.

    Raises:
        ValueError: If the line does not split into exactly four ``|`` fields.
    """
    stripped = line.strip()
    fields = stripped.split("|")
    if len(fields) != 4:
        raise ValueError(f"Invalid line format: {stripped}")
    utt, spk, language, text = fields

    cleaned = clean_text(
        text=text,
        language=language,  # type: ignore
        use_jp_extra=use_jp_extra,
        raise_yomi_error=yomi_error != "use",
    )
    norm_text, phones, tones, word2ph = cleaned

    audio = str(transcription_path.parent / "wavs" / utt) if correct_path else utt

    return "{}|{}|{}|{}|{}|{}|{}\n".format(
        audio,
        spk,
        language,
        norm_text,
        " ".join(phones),
        " ".join(map(str, tones)),
        " ".join(map(str, word2ph)),
    )
def _clean_transcription(
    transcription_path: Path,
    cleaned_path: Path,
    error_log_path: Path,
    use_jp_extra: bool,
    yomi_error: str,
    correct_path: bool,
) -> int:
    """Run ``process_line`` on every transcription line and write the results.

    Failing lines are logged, appended to ``error_log_path`` and left out of
    ``cleaned_path``.

    Returns:
        The number of lines whose processing raised an exception.
    """
    error_count = 0
    total_lines = count_lines(transcription_path)
    with (
        transcription_path.open("r", encoding="utf-8") as trans_file,
        cleaned_path.open("w", encoding="utf-8") as out_file,
    ):
        for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines):
            try:
                processed_line = process_line(
                    line,
                    transcription_path,
                    correct_path,
                    use_jp_extra,
                    yomi_error,
                )
            except Exception as e:
                # No extra kwargs here: the old `encoding="utf-8"` is not a logging
                # option. A stdlib logger rejects it; a loguru-style logger (presumably
                # what style_bert_vits2.logging provides — confirm) would instead run
                # str.format on the message and crash on lines containing { or }.
                logger.error(f"An error occurred at line:\n{line.strip()}\n{e}")
                write_error_log(error_log_path, line, e)
                error_count += 1
            else:
                out_file.write(processed_line)
    return error_count


def _collect_speaker_utterances(
    cleaned_path: Path,
) -> tuple[dict[str, list[str]], dict[str, int]]:
    """Group cleaned lines by speaker and assign speaker IDs.

    A line is skipped (with a warning) if its audio path was already seen or if
    that path is not an existing file. Speaker IDs are 0, 1, 2, ... in order of
    first appearance among the kept lines.

    Returns:
        ``(spk_utt_map, spk_id_map)``: speaker -> list of raw cleaned lines, and
        speaker -> integer ID.
    """
    spk_utt_map: dict[str, list[str]] = defaultdict(list)
    spk_id_map: dict[str, int] = {}
    audio_paths: set[str] = set()
    count_same = 0
    count_not_found = 0
    with cleaned_path.open("r", encoding="utf-8") as f:
        for line in f:
            utt, spk = line.strip().split("|")[:2]
            if utt in audio_paths:
                logger.warning(f"Same audio file appears multiple times: {utt}")
                count_same += 1
                continue
            if not Path(utt).is_file():
                logger.warning(f"Audio not found: {utt}")
                count_not_found += 1
                continue
            audio_paths.add(utt)
            spk_utt_map[spk].append(line)
            # A speaker seen for the first time gets the next free ID.
            if spk not in spk_id_map:
                spk_id_map[spk] = len(spk_id_map)
    if count_same > 0 or count_not_found > 0:
        logger.warning(
            f"Total repeated audios: {count_same}, Total number of audio not found: {count_not_found}"
        )
    return spk_utt_map, spk_id_map


def _split_train_val(
    spk_utt_map: dict[str, list[str]], val_per_lang: int, max_val_total: int
) -> tuple[list[str], list[str]]:
    """Randomly pick ``val_per_lang`` validation lines per speaker.

    Within each speaker the original line order is preserved. If the validation
    list ends up longer than ``max_val_total``, the surplus (its tail) is moved
    to the end of the training list.

    Raises:
        ValueError: If a speaker has fewer lines than ``val_per_lang``.
    """
    train_list: list[str] = []
    val_list: list[str] = []
    for spk, utts in spk_utt_map.items():
        if val_per_lang == 0:
            train_list.extend(utts)
            continue
        if val_per_lang > len(utts):
            # random.sample would otherwise fail with the opaque
            # "Sample larger than population" message.
            raise ValueError(
                f"Speaker {spk} has only {len(utts)} utterances, fewer than "
                f"val_per_lang={val_per_lang}. Lower --val-per-lang or add data."
            )
        val_indices = set(sample(range(len(utts)), val_per_lang))
        for index, utt in enumerate(utts):
            (val_list if index in val_indices else train_list).append(utt)
    if len(val_list) > max_val_total:
        train_list.extend(val_list[max_val_total:])
        val_list = val_list[:max_val_total]
    return train_list, val_list


def _update_speaker_config(config_path: Path, spk_id_map: dict[str, int]) -> None:
    """Write ``spk2id`` and ``n_speakers`` into the ``data`` section of the JSON config."""
    with config_path.open("r", encoding="utf-8") as f:
        json_config = json.load(f)
    json_config["data"]["spk2id"] = spk_id_map
    json_config["data"]["n_speakers"] = len(spk_id_map)
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(json_config, f, indent=2, ensure_ascii=False)


def preprocess(
    transcription_path: Path,
    cleaned_path: Optional[Path],
    train_path: Path,
    val_path: Path,
    config_path: Path,
    val_per_lang: int,
    max_val_total: int,
    use_jp_extra: bool,
    yomi_error: str,
    correct_path: bool,
):
    """Clean a transcription list and split it into train/validation lists.

    Steps:
      1. Each ``utt|spk|language|text`` line is converted by ``process_line`` and
         written to ``cleaned_path``. Failures go to ``text_error.log`` next to
         the transcription file; that log is deleted at the start of each run.
      2. Cleaned lines whose audio is duplicated or missing are dropped; the rest
         are grouped per speaker and speaker IDs are assigned.
      3. ``val_per_lang`` lines per speaker (capped at ``max_val_total`` overall)
         go to ``val_path``; the rest go to ``train_path``.
      4. ``spk2id`` / ``n_speakers`` in the JSON file at ``config_path`` are updated.

    Args:
        transcription_path: Input transcription list.
        cleaned_path: Output for cleaned lines. ``None`` or ``""`` means
            ``<transcription_path>.cleaned``.
        train_path: Output training list.
        val_path: Output validation list.
        config_path: JSON config to update in place.
        val_per_lang: Validation lines per SPEAKER (name kept for compatibility).
        max_val_total: Upper bound on the total number of validation lines.
        use_jp_extra: Forwarded to ``clean_text``.
        yomi_error: ``"raise"``, ``"skip"`` or ``"use"``.
        correct_path: Rewrite audio paths to ``<transcription dir>/wavs/<utt>``.

    Raises:
        ValueError: For an unknown ``yomi_error``, or when a speaker has fewer
            lines than ``val_per_lang``.
        Exception: When some lines failed and ``yomi_error`` is not ``"skip"``
            (raised after all output files have been written).
    """
    if yomi_error not in ("raise", "skip", "use"):
        raise ValueError(
            f"yomi_error must be 'raise', 'skip' or 'use', got {yomi_error!r}"
        )
    # Accept str paths as well as Path objects.
    transcription_path = Path(transcription_path)
    train_path = Path(train_path)
    val_path = Path(val_path)
    config_path = Path(config_path)
    if not cleaned_path:  # None or ""
        cleaned_path = transcription_path.with_name(
            transcription_path.name + ".cleaned"
        )
    cleaned_path = Path(cleaned_path)

    error_log_path = transcription_path.parent / "text_error.log"
    error_log_path.unlink(missing_ok=True)

    error_count = _clean_transcription(
        transcription_path,
        cleaned_path,
        error_log_path,
        use_jp_extra,
        yomi_error,
        correct_path,
    )

    spk_utt_map, spk_id_map = _collect_speaker_utterances(cleaned_path)
    train_list, val_list = _split_train_val(spk_utt_map, val_per_lang, max_val_total)

    with train_path.open("w", encoding="utf-8") as f:
        f.writelines(train_list)
    with val_path.open("w", encoding="utf-8") as f:
        f.writelines(val_list)

    _update_speaker_config(config_path, spk_id_map)

    if error_count == 0:
        logger.info(
            "Training set and validation set generation from texts is complete!"
        )
        return
    if yomi_error == "skip":
        logger.warning(
            f"An error occurred in {error_count} lines. Proceed with lines without errors. Please check {error_log_path} for details."
        )
        return
    # yomi_error is "raise" or "use". With "use", clean_text was called with
    # raise_yomi_error=False, so any error counted here is some other exception;
    # raise in both cases.
    logger.error(
        f"An error occurred in {error_count} lines. Please check {error_log_path} for details."
    )
    # Per the original author, putting {error_log_path} in the raised message
    # caused a character-encoding error, so the location is spelled out literally.
    raise Exception(
        f"An error occurred in {error_count} lines. Please check `Data/your_model_name/text_error.log` file for details."
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Clean a transcription list and split it into train/val lists."
    )
    parser.add_argument(
        "--transcription-path", default=preprocess_text_config.transcription_path
    )
    parser.add_argument("--cleaned-path", default=preprocess_text_config.cleaned_path)
    parser.add_argument("--train-path", default=preprocess_text_config.train_path)
    parser.add_argument("--val-path", default=preprocess_text_config.val_path)
    parser.add_argument("--config-path", default=preprocess_text_config.config_path)
    # This is the number of validation items per SPEAKER, not per language.
    # The name val_per_lang is kept because the original code and config files use it.
    parser.add_argument(
        "--val-per-lang",
        type=int,
        default=preprocess_text_config.val_per_lang,
        help="Number of validation data per SPEAKER, not per language (due to compatibility with the original code).",
    )
    parser.add_argument(
        "--max-val-total",
        type=int,
        default=preprocess_text_config.max_val_total,
        help="Maximum total number of validation data; the surplus goes to training.",
    )
    parser.add_argument(
        "--use_jp_extra",
        action="store_true",
        help="Forwarded to clean_text as use_jp_extra=True.",
    )
    # Validate here so a typo gets a clear argparse error instead of failing
    # later inside preprocess().
    parser.add_argument(
        "--yomi_error",
        default="raise",
        choices=["raise", "skip", "use"],
        help="'raise': fail after processing if any line errored; "
        "'skip': drop failing lines and continue; "
        "'use': do not raise on reading errors in clean_text.",
    )
    parser.add_argument(
        "--correct_path",
        action="store_true",
        help="Rewrite audio paths to <transcription dir>/wavs/<file>.",
    )
    args = parser.parse_args()

    transcription_path = Path(args.transcription_path)
    cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None
    train_path = Path(args.train_path)
    val_path = Path(args.val_path)
    config_path = Path(args.config_path)
    # int() kept in case the config default is not already an int.
    val_per_lang = int(args.val_per_lang)
    max_val_total = int(args.max_val_total)
    use_jp_extra: bool = args.use_jp_extra
    yomi_error: str = args.yomi_error
    correct_path: bool = args.correct_path

    preprocess(
        transcription_path=transcription_path,
        cleaned_path=cleaned_path,
        train_path=train_path,
        val_path=val_path,
        config_path=config_path,
        val_per_lang=val_per_lang,
        max_val_total=max_val_total,
        use_jp_extra=use_jp_extra,
        yomi_error=yomi_error,
        correct_path=correct_path,
    )