import os
import re
import glob
import json
import torch
import logging
import argparse
import requests
import subprocess
import modelscope
import huggingface_hub
import numpy as np
from tqdm import tqdm
from scipy.io.wavfile import read

# Set once the non-interactive Matplotlib backend has been selected.
MATPLOTLIB_FLAG = False

logger = logging.getLogger(__name__)

# Any locale other than Simplified Chinese gets the English UI strings.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Chinese UI string -> English UI string. The keys are looked up verbatim by
# `_L`, so their exact text (including leading whitespace) is significant.
ZH2EN = {
    "输入模式": "Input Mode",
    "请输入简体中文文案": "Please input the Simplified Chinese text",
    "首次推理需耗时下载模型,还请耐心等待。": "The first inference takes time to download the model, so be patient.",
    "角色": "Role",
    "状态栏": "Status",
    "语调调节": "Modulation of intonation",
    "感情调节": "Emotional adjustment",
    "音素长度": "Phoneme length",
    "生成时长": "Output duration",
    "输出音频": "Output Audio",
    "上传模式": "Upload Mode",
    "请上传简体中文 TXT 文案": "Please upload a simplified Chinese TXT",
    "文案提取结果": "Result of TXT extraction",
    " 欢迎使用此创空间,此创空间基于 Bert-vits2 开源项目制作,移至最底端有原理浅讲。使用此创空间必须遵守当地相关法律法规,禁止用其从事任何违法犯罪活动。": (
        " Welcome to the Space, which is based on the open source project Bert-vits2, and moved to the bottom for an explanation of the principle. \n"
        "This Space must be used in accordance with local laws and regulations, prohibiting the use of it for any criminal activities."
    ),
}

# Import-time side effect: fetch (or reuse the cached copy of) the model
# snapshot, from Hugging Face for English locales and ModelScope otherwise.
MODEL_DIR = (
    huggingface_hub.snapshot_download(
        "Genius-Society/hoyoTTS",
        cache_dir="./__pycache__",
    )
    if EN_US
    else modelscope.snapshot_download(
        "Genius-Society/hoyoTTS",
        cache_dir="./__pycache__",
    )
)


def _L(zh_txt: str):
    """Return the UI text for the active locale.

    For the Chinese locale the input is returned unchanged. Otherwise the
    English translation from ``ZH2EN`` is returned; a string without a
    translation falls back to the Chinese text instead of raising KeyError.
    """
    if not EN_US:
        return zh_txt
    return ZH2EN.get(zh_txt, zh_txt)


def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Arguments:
    checkpoint_path -- Path to a file written by ``save_checkpoint``
    model -- Module to load into; if it has a ``module`` attribute, that
             inner module is the one loaded
    optimizer -- Optimizer to restore, or None to skip optimizer state
    skip_optimizer -- True -> never touch the optimizer

    Returns ``(model, optimizer, learning_rate, iteration)``.

    Parameters that are missing from the checkpoint, or whose shape differs
    from the model's, keep the model's current values and are logged.
    """
    assert os.path.isfile(checkpoint_path), checkpoint_path
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]

    # Restore the optimizer only when there is both an optimizer to restore
    # into and a saved state to restore from. (The previous fallback branch
    # dereferenced `optimizer` while it was None and could only crash.)
    if optimizer is not None and not skip_optimizer:
        saved_optimizer_state = checkpoint_dict.get("optimizer")
        if saved_optimizer_state is not None:
            optimizer.load_state_dict(saved_optimizer_state)

    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        saved = saved_state_dict.get(k)
        if saved is None:
            logger.error("%s is not in the checkpoint", k)
            new_state_dict[k] = v
        elif getattr(saved, "shape", None) != v.shape:
            # Explicit check rather than `assert`, which is stripped under -O.
            logger.error(
                "%s has shape %s in the checkpoint but the model expects %s",
                k,
                getattr(saved, "shape", None),
                v.shape,
            )
            new_state_dict[k] = v
        else:
            new_state_dict[k] = saved

    target.load_state_dict(new_state_dict, strict=False)

    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model state, optimizer state, learning rate and iteration to disk.

    If ``model`` has a ``module`` attribute, the inner module's state is saved.
    """
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    target = model.module if hasattr(model, "module") else model
    torch.save(
        {
            "model": target.state_dict(),
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Forward the given tag -> value mappings to a summary writer.

    Each mapping is optional; omitted ones are treated as empty. Images are
    written with ``dataformats="HWC"`` and audio with ``audio_sampling_rate``.
    """
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint path with the highest number in its name.

    ``regex`` is a glob pattern (not a regular expression) joined onto
    ``dir_path``. Files are ordered by the digits of their base name only, so
    digits in directory names cannot affect the ordering; a name without
    digits sorts first.

    Raises IndexError when nothing matches.
    """
    pattern = os.path.join(dir_path, regex)
    f_list = glob.glob(pattern)
    if not f_list:
        raise IndexError("no checkpoint matches {!r}".format(pattern))

    def step_of(path):
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list.sort(key=step_of)
    x = f_list[-1]
    print(x)
    return x


def _headless_pyplot():
    """Select the Agg backend on first use and return the pyplot interface."""
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        # Quieten Matplotlib's own logger.
        logging.getLogger("matplotlib").setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    return plt


def _figure_to_rgb_array(fig):
    """Render ``fig`` and return its pixels as a writable (H, W, 3) uint8 array."""
    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb()/np.fromstring pair;
    # the alpha channel is dropped so the result stays RGB.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return np.ascontiguousarray(rgba[..., :3])


def plot_spectrogram_to_numpy(spectrogram):
    """Draw a spectrogram as an image and return it as an RGB uint8 array."""
    plt = _headless_pyplot()

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    data = _figure_to_rgb_array(fig)
    plt.close(fig)
    return data


def plot_alignment_to_numpy(alignment, info=None):
    """Draw the transposed alignment matrix and return it as an RGB uint8 array.

    ``info``, when given, is appended below the x-axis label.
    """
    plt = _headless_pyplot()

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    data = _figure_to_rgb_array(fig)
    plt.close(fig)
    return data


def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(float32 tensor of samples, sampling rate)``."""
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and split every stripped line on ``split``."""
    with open(filename, encoding="utf-8") as f:
        return [line.strip().split(split) for line in f]


def get_hparams(init=True):
    """Build hyper-parameters from command-line arguments.

    Parses ``-c/--config`` and ``-m/--model`` from ``sys.argv``. The model
    directory is ``./logs/<model>`` and is created if missing. With
    ``init=True`` the given config file is copied to ``config.json`` in that
    directory; otherwise the previously saved ``config.json`` is read.

    Returns an ``HParams`` whose ``model_dir`` attribute is set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/base.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
    args = parser.parse_args()

    model_dir = os.path.join("./logs", args.model)
    os.makedirs(model_dir, exist_ok=True)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    # Read and write as UTF-8 explicitly, matching get_hparams_from_dir, so the
    # result does not depend on the platform's default encoding.
    if init:
        with open(config_path, "r", encoding="utf-8") as f:
            data = f.read()
        with open(config_save_path, "w", encoding="utf-8") as f:
            f.write(data)
    else:
        with open(config_save_path, "r", encoding="utf-8") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Freeing up space by deleting saved ckpts

    Arguments:
    path_to_models    --  Path to the model directory
    n_ckpts_to_keep   --  Number of ckpts to keep per prefix ("G" and "D"),
                          not counting files ending in "_0.pth"; values <= 0
                          delete nothing
    sort_by_time      --  True -> delete the oldest files by modification time
                          False -> delete the lowest-numbered files; names that
                          do not look like "X_<number>.pth" are left alone
    """
    if n_ckpts_to_keep <= 0:
        return

    step_pattern = re.compile(r"._(\d+)\.pth")

    ckpts_files = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
    ]

    def modified_time(name):
        return os.path.getmtime(os.path.join(path_to_models, name))

    def step_number(name):
        return int(step_pattern.match(name).group(1))

    def oldest_first(prefix):
        candidates = [
            f
            for f in ckpts_files
            if f.startswith(prefix) and not f.endswith("_0.pth")
        ]
        if sort_by_time:
            return sorted(candidates, key=modified_time)
        # Only names carrying a step number can be ordered by name.
        numbered = [f for f in candidates if step_pattern.match(f)]
        return sorted(numbered, key=step_number)

    for prefix in ("G", "D"):
        for name in oldest_first(prefix)[:-n_ckpts_to_keep]:
            full_name = os.path.join(path_to_models, name)
            os.remove(full_name)
            logger.info(f".. Free up space by deleting ckpt {full_name}")


def get_hparams_from_dir(model_dir):
    """Load ``config.json`` from ``model_dir`` into an ``HParams``.

    The returned object has its ``model_dir`` attribute set.
    """
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def download_file(file_url: str):
    """Download ``file_url`` unless the target file already exists.

    The local file name is whatever follows the last ``&FilePath=`` in the
    URL. The name is returned in every case, including when the server
    answers with a non-200 status (a message is printed and no file is
    created).
    """
    # NOTE(review): the file name comes straight from the URL, so it may
    # contain path separators; only pass trusted URLs.
    filename = file_url.split("&FilePath=")[-1]
    if os.path.exists(filename):
        return filename

    with requests.get(file_url, stream=True) as response:
        # Check that the request succeeded.
        if response.status_code != 200:
            print(f"下载失败,状态码:{response.status_code}")
            return filename

        # Total size for the progress bar (0 when the header is absent).
        file_size = int(response.headers.get("Content-Length", 0))

        # Download to a temporary name and rename at the end, so an
        # interrupted transfer is never mistaken for a complete file by the
        # os.path.exists() check above.
        partial_name = filename + ".part"
        try:
            with open(partial_name, "wb") as file, tqdm(
                total=file_size,
                unit="B",
                unit_scale=True,
                desc=f"Downloading {filename}...",
            ) as progress_bar:
                # Stream the body in chunks.
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        file.write(chunk)
                        progress_bar.update(len(chunk))
        except BaseException:
            if os.path.exists(partial_name):
                os.remove(partial_name)
            raise
        os.replace(partial_name, filename)

    print(f"模型文件 '{file_url}' 下载成功。")
    return filename


def get_hparams_from_url(config_url):
    """Fetch a JSON config over HTTP and return it as an ``HParams``.

    Raises ``requests.HTTPError`` for an error status instead of trying to
    interpret the error body as a configuration.
    """
    response = requests.get(config_url)
    response.raise_for_status()
    return HParams(**response.json())


def check_git_hash(model_dir):
    """Compare the source tree's git HEAD with the hash stored in ``model_dir``.

    On first use the current hash is written to ``<model_dir>/githash``; on
    later runs a warning is logged when the stored hash differs. Nothing is
    compared or written when the directory containing this file is not a git
    repository or git cannot report a hash.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    # Run git inside the source tree (not the current working directory) and
    # without a shell.
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=source_dir,
            capture_output=True,
            text=True,
        )
    except OSError as err:
        logger.warning(
            "Unable to run git (%s), therefore hash value comparison will be ignored.",
            err,
        )
        return
    if result.returncode != 0:
        # Never store git's error output as if it were a hash.
        logger.warning(
            "git rev-parse failed (%s), therefore hash value comparison will be ignored.",
            result.stderr.strip(),
        )
        return
    cur_hash = result.stdout.strip()

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    """Replace the module-level ``logger`` with one that logs to a file.

    The logger is named after the last component of ``model_dir`` and writes
    DEBUG-and-above records to ``<model_dir>/<filename>``; the directory is
    created if missing. Calling this again for the same file reuses the
    existing handler rather than adding a duplicate.
    """
    global logger
    # normpath drops a trailing separator, which would otherwise yield an
    # empty name and therefore the root logger.
    logger = logging.getLogger(os.path.basename(os.path.normpath(model_dir)))
    logger.setLevel(logging.DEBUG)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))

    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter(
            "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"
        )
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger


class HParams:
    """Attribute-style container for (possibly nested) configuration values.

    Keyword arguments become attributes; nested dicts are converted to nested
    ``HParams``. Values can also be read and written with ``obj[key]``.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()