# Style-Bert-VITS2 / preprocess_text.py
import argparse
import json
from collections import defaultdict
from pathlib import Path
from random import sample
from typing import Optional
from tqdm import tqdm
from config import get_config
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import clean_text
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT
# This process launches a worker and uses the dictionary, so initialize it here
pyopenjtalk_worker.initialize_worker()
# Apply the dictionary data under dict_data/ to pyopenjtalk
update_dict()
preprocess_text_config = get_config().preprocess_text_config
# Count lines for tqdm
def count_lines(file_path: Path) -> int:
with file_path.open("r", encoding="utf-8") as file:
return sum(1 for _ in file)
def write_error_log(error_log_path: Path, line: str, error: Exception):
with error_log_path.open("a", encoding="utf-8") as error_log:
error_log.write(f"{line.strip()}\n{error}\n\n")
def process_line(
line: str,
transcription_path: Path,
correct_path: bool,
use_jp_extra: bool,
yomi_error: str,
):
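    """
    Clean a single pipe-delimited transcription line.

    Expects "{utt}|{spk}|{language}|{text}" and returns the cleaned line
    "{utt}|{spk}|{language}|{norm_text}|{phones}|{tones}|{word2ph}" with a
    trailing newline, where phones, tones, and word2ph are the space-joined
    outputs of clean_text(). If correct_path is set, utt is rewritten to
    point into the wavs/ directory next to the transcription file.
    Raises ValueError for lines that do not have exactly 4 fields.
    """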
splitted_line = line.strip().split("|")
if len(splitted_line) != 4:
raise ValueError(f"Invalid line format: {line.strip()}")
utt, spk, language, text = splitted_line
norm_text, phones, tones, word2ph = clean_text(
text=text,
language=language, # type: ignore
use_jp_extra=use_jp_extra,
raise_yomi_error=(yomi_error != "use"),
)
if correct_path:
utt = str(transcription_path.parent / "wavs" / utt)
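        # e.g. utt "0001.wav" with transcription_path "Data/foo/esd.list"
        # becomes "Data/foo/wavs/0001.wav" (names here are illustrative)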
return "{}|{}|{}|{}|{}|{}|{}\n".format(
utt,
spk,
language,
norm_text,
" ".join(phones),
" ".join([str(i) for i in tones]),
" ".join([str(i) for i in word2ph]),
)
def preprocess(
transcription_path: Path,
cleaned_path: Optional[Path],
train_path: Path,
val_path: Path,
config_path: Path,
val_per_lang: int,
max_val_total: int,
# clean: bool,
use_jp_extra: bool,
yomi_error: str,
correct_path: bool,
):
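    """
    Generate the train/validation lists and update the model config.

    1. Clean every line of transcription_path with process_line() and write
       the results to cleaned_path, logging failures to text_error.log.
    2. Verify that each referenced audio file exists, skip duplicates, and
       assign a sequential ID to every speaker encountered.
    3. For each speaker, move val_per_lang randomly chosen utterances into
       the validation list (capped overall at max_val_total) and write the
       train/val lists to train_path and val_path.
    4. Write spk2id and n_speakers into the JSON config at config_path.
    """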
assert yomi_error in ["raise", "skip", "use"]
    # An empty or missing cleaned_path defaults to "<transcription_path>.cleaned"
    if cleaned_path == "" or cleaned_path is None:
        cleaned_path = transcription_path.with_name(
            transcription_path.name + ".cleaned"
        )
error_log_path = transcription_path.parent / "text_error.log"
if error_log_path.exists():
error_log_path.unlink()
error_count = 0
total_lines = count_lines(transcription_path)
    # Read transcription_path one line at a time, clean it, and write the result to cleaned_path
with (
transcription_path.open("r", encoding="utf-8") as trans_file,
cleaned_path.open("w", encoding="utf-8") as out_file,
):
for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines):
try:
processed_line = process_line(
line,
transcription_path,
correct_path,
use_jp_extra,
yomi_error,
)
out_file.write(processed_line)
            except Exception as e:
                logger.error(f"An error occurred at line:\n{line.strip()}\n{e}")
write_error_log(error_log_path, line, e)
error_count += 1
    transcription_path = cleaned_path  # From here on, read from the cleaned file
    # Transcription lines grouped by speaker
spk_utt_map: dict[str, list[str]] = defaultdict(list)
    # Mapping from speaker name to speaker ID
spk_id_map: dict[str, int] = {}
    # Speaker ID to assign next
current_sid: int = 0
    # Check the audio files and build spk_id_map
with transcription_path.open("r", encoding="utf-8") as f:
audio_paths: set[str] = set()
count_same = 0
count_not_found = 0
        for line in f:
utt, spk = line.strip().split("|")[:2]
if utt in audio_paths:
logger.warning(f"Same audio file appears multiple times: {utt}")
count_same += 1
continue
if not Path(utt).is_file():
logger.warning(f"Audio not found: {utt}")
count_not_found += 1
continue
audio_paths.add(utt)
spk_utt_map[spk].append(line)
            # When a new speaker appears, assign it a speaker ID and increment current_sid
if spk not in spk_id_map:
spk_id_map[spk] = current_sid
current_sid += 1
    if count_same > 0 or count_not_found > 0:
        logger.warning(
            f"Total duplicated audios: {count_same}, total audios not found: {count_not_found}"
        )
train_list: list[str] = []
val_list: list[str] = []
    # Process each speaker's utterance list
for spk, utts in spk_utt_map.items():
if val_per_lang == 0:
train_list.extend(utts)
continue
        # Randomly choose val_per_lang indices for validation
        # (capped at the number of utterances so sample() cannot fail)
        val_indices = set(sample(range(len(utts)), min(val_per_lang, len(utts))))
        # Split the list while preserving the original order
for index, utt in enumerate(utts):
if index in val_indices:
val_list.append(utt)
else:
train_list.append(utt)
    # Cap the validation list at max_val_total
    if len(val_list) > max_val_total:
        extra_val = val_list[max_val_total:]
        val_list = val_list[:max_val_total]
        # Move the surplus validation utterances back into the training list
        # (their original order is preserved)
        train_list.extend(extra_val)
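    # For example, with 3 speakers and val_per_lang == 4, up to 12 lines are
    # sampled for validation; with max_val_total == 10, the last 2 of those
    # return to the training list (the counts here are illustrative)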
with train_path.open("w", encoding="utf-8") as f:
for line in train_list:
f.write(line)
with val_path.open("w", encoding="utf-8") as f:
for line in val_list:
f.write(line)
with config_path.open("r", encoding="utf-8") as f:
json_config = json.load(f)
json_config["data"]["spk2id"] = spk_id_map
json_config["data"]["n_speakers"] = len(spk_id_map)
with config_path.open("w", encoding="utf-8") as f:
json.dump(json_config, f, indent=2, ensure_ascii=False)
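    # The "data" section of the config now contains entries like
    # (speaker names are hypothetical):
    #   "spk2id": {"speaker_a": 0, "speaker_b": 1},
    #   "n_speakers": 2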
if error_count > 0:
if yomi_error == "skip":
            logger.warning(
                f"An error occurred in {error_count} lines. Proceeding with the error-free lines. Please check {error_log_path} for details."
            )
else:
            # This branch handles yomi_error == "raise" and "use".
            # With "use", clean_text is called with raise_yomi_error=False,
            # so reaching this point means some other exception occurred; raise it.
logger.error(
f"An error occurred in {error_count} lines. Please check {error_log_path} for details."
)
            raise Exception(
                f"An error occurred in {error_count} lines. Please check `Data/your_model_name/text_error.log` for details."
            )
            # For some reason, interpolating {error_log_path} into the raised message
            # causes a character-encoding error, hence the hardcoded path above
else:
logger.info(
"Training set and validation set generation from texts is complete!"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcription-path", default=preprocess_text_config.transcription_path
)
parser.add_argument("--cleaned-path", default=preprocess_text_config.cleaned_path)
parser.add_argument("--train-path", default=preprocess_text_config.train_path)
parser.add_argument("--val-path", default=preprocess_text_config.val_path)
parser.add_argument("--config-path", default=preprocess_text_config.config_path)
    # Number of validation samples PER SPEAKER, not per language!
    # The name val_per_lang is kept because the original code and config files use it
parser.add_argument(
"--val-per-lang",
default=preprocess_text_config.val_per_lang,
help="Number of validation data per SPEAKER, not per language (due to compatibility with the original code).",
)
parser.add_argument("--max-val-total", default=preprocess_text_config.max_val_total)
parser.add_argument("--use_jp_extra", action="store_true")
parser.add_argument("--yomi_error", default="raise")
parser.add_argument("--correct_path", action="store_true")
args = parser.parse_args()
transcription_path = Path(args.transcription_path)
cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None
train_path = Path(args.train_path)
val_path = Path(args.val_path)
config_path = Path(args.config_path)
val_per_lang = int(args.val_per_lang)
max_val_total = int(args.max_val_total)
use_jp_extra: bool = args.use_jp_extra
yomi_error: str = args.yomi_error
correct_path: bool = args.correct_path
preprocess(
transcription_path=transcription_path,
cleaned_path=cleaned_path,
train_path=train_path,
val_path=val_path,
config_path=config_path,
val_per_lang=val_per_lang,
max_val_total=max_val_total,
use_jp_extra=use_jp_extra,
yomi_error=yomi_error,
correct_path=correct_path,
)
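
# Example invocation (the paths below are illustrative; the path arguments
# default to the values in the preprocess_text section of the config):
#   python preprocess_text.py \
#       --transcription-path Data/your_model_name/esd.list \
#       --train-path Data/your_model_name/train.list \
#       --val-path Data/your_model_name/val.list \
#       --config-path Data/your_model_name/config.json \
#       --val-per-lang 4 --max-val-total 500 --yomi_error skip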