Spaces:
Running
Running
import os | |
import glob | |
import json | |
import torch | |
import logging | |
import argparse | |
import requests | |
import subprocess | |
import modelscope | |
import huggingface_hub | |
import numpy as np | |
from tqdm import tqdm | |
from scipy.io.wavfile import read | |
# Set to True by the plotting helpers once they have switched matplotlib to
# the non-interactive "Agg" backend, so that switch happens only once.
MATPLOTLIB_FLAG = False
# Module-wide logger used by the helpers in this file; get_logger() rebinds it
# to a per-model logger.
logger = logging.getLogger(__name__)
# English UI unless the $LANG environment variable is exactly "zh_CN.UTF-8"
# (an unset LANG also selects English).
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
# Simplified-Chinese UI string -> English translation, looked up by _L() when
# EN_US is True. Keys must match the Chinese literals used by callers exactly.
ZH2EN = {
    "输入模式": "Input Mode",
    "请输入简体中文文案": "Please input the Simplified Chinese text",
    "首次推理需耗时下载模型,还请耐心等待。": "The first inference takes time to download the model, so be patient.",
    "角色": "Role",
    "状态栏": "Status",
    "语调调节": "Modulation of intonation",
    "感情调节": "Emotional adjustment",
    "音素长度": "Phoneme length",
    "生成时长": "Output duration",
    "输出音频": "Output Audio",
    "上传模式": "Upload Mode",
    "请上传简体中文 TXT 文案": "Please upload a simplified Chinese TXT",
    "文案提取结果": "Result of TXT extraction",
    # Welcome banner (contains an HTML link). Both the key and the value are
    # triple-quoted and begin with a newline; the key must match byte-for-byte.
    """
欢迎使用此创空间,此创空间基于 <a href="https://github.com/fishaudio/Bert-VITS2">Bert-vits2</a> 开源项目制作,移至最底端有原理浅讲。使用此创空间必须遵守当地相关法律法规,禁止用其从事任何违法犯罪活动。""": """
Welcome to the Space, which is based on the open source project <a href="https://github.com/fishaudio/Bert-VITS2">Bert-vits2</a>, and moved to the bottom for an explanation of the principle. This Space must be used in accordance with local laws and regulations, prohibiting the use of it for any criminal activities.""",
}
# Fetch (or reuse the locally cached copy of) the hoyoTTS model repository at
# import time. The English UI pulls from the Hugging Face Hub; the Chinese UI
# pulls the same repository id from ModelScope. Both cache under ./__pycache__.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/hoyoTTS",
        cache_dir="./__pycache__",
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/hoyoTTS",
        cache_dir="./__pycache__",
    )
def _L(zh_txt: str):
    """Return the UI label for the Chinese string ``zh_txt`` in the active language.

    When ``EN_US`` is False the input is returned unchanged. When it is True the
    English translation from ``ZH2EN`` is returned. A string that has no entry
    falls back to the original text rather than raising ``KeyError``, so one
    missing translation cannot crash the UI.
    """
    if not EN_US:
        return zh_txt
    return ZH2EN.get(zh_txt, zh_txt)
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint_path: Path to a ``torch.save``-d dict with the keys
            ``"model"``, ``"iteration"``, ``"learning_rate"`` and (when an
            optimizer is restored) ``"optimizer"``, as written by
            ``save_checkpoint``.
        model: Module to populate. If it has a ``.module`` attribute
            (presumably a DataParallel-style wrapper), the wrapped module is
            loaded instead.
        optimizer: Optional optimizer. Its state is restored only when it is
            not None, ``skip_optimizer`` is False and the checkpoint's
            ``"optimizer"`` entry is not None.
        skip_optimizer: If True, never touch the optimizer.

    Returns:
        ``(model, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: if ``checkpoint_path`` is not an existing file.

    Model parameters that are missing from the checkpoint, or stored with a
    different shape, are logged and keep the model's current values. The final
    ``load_state_dict`` is non-strict.
    """
    assert os.path.isfile(checkpoint_path), checkpoint_path
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    # Only restore optimizer state when there is an optimizer to restore into.
    # The former ``elif optimizer is None and not skip_optimizer`` branch called
    # ``optimizer.state_dict()`` on None, so the default call
    # ``load_checkpoint(path, model)`` always raised AttributeError.
    if (
        optimizer is not None
        and not skip_optimizer
        and checkpoint_dict["optimizer"] is not None
    ):
        optimizer.load_state_dict(checkpoint_dict["optimizer"])

    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    new_state_dict = {}
    for k, v in target.state_dict().items():
        if k not in saved_state_dict:
            logger.error("%s is not in the checkpoint", k)
            new_state_dict[k] = v
            continue
        saved = saved_state_dict[k]
        saved_shape = getattr(saved, "shape", None)
        model_shape = getattr(v, "shape", None)
        if saved_shape != model_shape:
            # Explicit check instead of ``assert`` so it also works under -O.
            logger.error(
                "%s has shape %s in the checkpoint but %s in the model; "
                "keeping the model's value",
                k,
                saved_shape,
                model_shape,
            )
            new_state_dict[k] = v
        else:
            new_state_dict[k] = saved
    target.load_state_dict(new_state_dict, strict=False)
    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )
    return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Serialize model/optimizer state plus training progress with ``torch.save``.

    The file holds a dict with the keys ``"model"`` (state dict of ``model``, or
    of ``model.module`` when that attribute exists), ``"iteration"``,
    ``"optimizer"`` and ``"learning_rate"``. ``optimizer`` may be None, in which
    case ``None`` is stored under ``"optimizer"``. ``load_checkpoint`` already
    skips a None optimizer entry.
    """
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    target = model.module if hasattr(model, "module") else model
    torch.save(
        {
            "model": target.state_dict(),
            "iteration": iteration,
            "optimizer": None if optimizer is None else optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Write one batch of summaries to ``writer`` at ``global_step``.

    ``writer`` is assumed to be a TensorBoard ``SummaryWriter``-like object. It
    only needs ``add_scalar``, ``add_histogram``, ``add_image`` and
    ``add_audio``. Each mapping is ``{tag: value}``. Images are logged with
    ``dataformats="HWC"`` and audio with ``audio_sampling_rate``. Omitted
    mappings log nothing. They default to None rather than a shared mutable
    ``{}``.
    """
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return (and print) the newest checkpoint in ``dir_path``.

    ``regex`` is actually a glob pattern. Matches are ordered by the integer
    formed from every digit in the full path, and the largest one wins. Raises
    IndexError when nothing matches.
    """

    def _digits_as_int(path):
        return int("".join(ch for ch in path if ch.isdigit()))

    candidates = sorted(glob.glob(os.path.join(dir_path, regex)), key=_digits_as_int)
    newest = candidates[-1]
    print(newest)
    return newest
def plot_spectrogram_to_numpy(spectrogram):
    """Render a 2-D spectrogram with matplotlib and return it as an RGB image.

    Args:
        spectrogram: 2-D array-like passed to ``imshow``. It is drawn with
            ``origin="lower"``; rows are labelled "Channels" and columns
            "Frames".

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(height, width, 3)``.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # Switch to the headless Agg backend once, before pyplot is imported.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()
    fig.canvas.draw()
    # canvas.tostring_rgb() (plus the deprecated binary np.fromstring) was
    # removed in Matplotlib 3.10. buffer_rgba() is the supported API. Drop the
    # alpha channel and copy so the result no longer aliases the canvas buffer.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return it as an RGB image.

    Args:
        alignment: 2-D array-like. It is transposed before plotting, so its
            first axis becomes the x-axis ("Decoder timestep"). Assumes the
            object has a ``.transpose()`` method (NumPy array or similar).
        info: Optional extra text appended under the x-axis label.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(height, width, 3)``.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # Switch to the headless Agg backend once, before pyplot is imported.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()
    fig.canvas.draw()
    # canvas.tostring_rgb() (plus the deprecated binary np.fromstring) was
    # removed in Matplotlib 3.10. buffer_rgba() is the supported API. Drop the
    # alpha channel and copy so the result no longer aliases the canvas buffer.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return data
def load_wav_to_torch(full_path):
    """Read a WAV file with scipy and return ``(float32 tensor, sampling_rate)``.

    Samples are converted to float32 without rescaling, so int16 PCM keeps its
    integer amplitude range.
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    return torch.FloatTensor(as_float), rate
def load_filepaths_and_text(filename, split="|"):
    """Parse a UTF-8 filelist into rows of fields.

    Each line is stripped of surrounding whitespace and split on ``split``.
    Returns one list of strings per line, in file order.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split))
    return rows
def get_hparams(init=True):
    """Build HParams from the command line (``-c/--config``, ``-m/--model``).

    The run directory is ``./logs/<model>`` and is created if missing. With
    ``init=True`` the config file given by ``--config`` is copied verbatim to
    ``<run dir>/config.json`` and used. Otherwise the previously saved
    ``<run dir>/config.json`` is read. The returned object gets a
    ``model_dir`` attribute pointing at the run directory.

    Config files are read and written as UTF-8, matching
    ``get_hparams_from_dir``. This keeps non-ASCII content such as Chinese
    speaker names intact on platforms whose default encoding is not UTF-8.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/base.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_dir, exist_ok=True)
    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        # Snapshot the source config into the run directory.
        with open(config_path, "r", encoding="utf-8") as f:
            data = f.read()
        with open(config_save_path, "w", encoding="utf-8") as f:
            f.write(data)
    else:
        with open(config_save_path, "r", encoding="utf-8") as f:
            data = f.read()
    config = json.loads(data)
    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
def clean_checkpoints(
    path_to_models="logs/44k/",
    n_ckpts_to_keep=2,
    sort_by_time=True,
    prefixes=("G", "D"),
):
    """Free up space by deleting old saved checkpoints.

    Only files named exactly ``<prefix>_<step>.pth`` (e.g. ``G_1000.pth``) are
    considered, grouped by their exact prefix, so ``DUR_*.pth`` is not mixed
    into the ``D`` group. ``<prefix>_0.pth`` files are always kept. In each
    group the newest ``n_ckpts_to_keep`` files survive and the rest are removed.

    Arguments:
    path_to_models -- Path to the model directory
    n_ckpts_to_keep -- Number of ckpts to keep per prefix, excluding the
                       ``_0.pth`` files (0 deletes all others)
    sort_by_time -- True -> order by modification time (oldest deleted first)
                    False -> order by the step number in the file name
    prefixes -- Checkpoint name prefixes to clean (default: G and D)
    """
    import re

    ckpt_re = re.compile(r"([A-Za-z]+)_(\d+)\.pth")

    def _step(fname):
        return int(ckpt_re.fullmatch(fname).group(2))

    def _mtime(fname):
        return os.path.getmtime(os.path.join(path_to_models, fname))

    sort_key = _mtime if sort_by_time else _step

    # dict (insertion-ordered) also collapses duplicate prefixes.
    by_prefix = {prefix: [] for prefix in prefixes}
    for fname in os.listdir(path_to_models):
        match = ckpt_re.fullmatch(fname)
        if match is None or match.group(1) not in by_prefix:
            continue  # not a checkpoint of a prefix we manage
        if fname.endswith("_0.pth"):
            continue  # the step-0 checkpoints are always kept
        if not os.path.isfile(os.path.join(path_to_models, fname)):
            continue
        by_prefix[match.group(1)].append(fname)

    to_del = []
    for names in by_prefix.values():
        ordered = sorted(names, key=sort_key)
        # Explicit count: a ``[:-n]`` slice would delete nothing when n == 0.
        n_delete = max(len(ordered) - n_ckpts_to_keep, 0)
        to_del.extend(os.path.join(path_to_models, fn) for fn in ordered[:n_delete])

    for fn in to_del:
        os.remove(fn)
        logger.info(f".. Free up space by deleting ckpt {fn}")
def get_hparams_from_dir(model_dir):
    """Load ``<model_dir>/config.json`` (UTF-8) into HParams.

    The returned object also carries ``model_dir`` as an attribute.
    """
    config_file = os.path.join(model_dir, "config.json")
    with open(config_file, "r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    params = HParams(**parsed)
    params.model_dir = model_dir
    return params
def download_file(file_url: str):
    """Download ``file_url`` into the current directory unless already present.

    The local file name is the text after the last ``&FilePath=`` in the URL
    (the whole URL if the marker is absent). If that path already exists it is
    returned immediately. The body is streamed with a tqdm progress bar into
    ``<filename>.part`` and renamed to ``<filename>`` only after the transfer
    completes. An interrupted download therefore never passes the existence
    check above as a complete file.

    Returns the local filename. On a non-200 response a message is printed and
    the filename is still returned, but no file is created.
    """
    filename = file_url.split("&FilePath=")[-1]
    if os.path.exists(filename):
        return filename
    tmp_path = filename + ".part"
    # Context manager releases the streamed connection back to the pool.
    with requests.get(file_url, stream=True) as response:
        # Check whether the request succeeded
        if response.status_code != 200:
            print(f"下载失败,状态码:{response.status_code}")
            return filename
        # Total size from the headers (0 when the server does not send it)
        file_size = int(response.headers.get("Content-Length", 0))
        try:
            with open(tmp_path, "wb") as file, tqdm(
                total=file_size,
                unit="B",
                unit_scale=True,
                desc=f"Downloading {filename}...",
            ) as progress_bar:
                # Download in chunks
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        file.write(chunk)
                        progress_bar.update(len(chunk))
        except BaseException:
            # Do not leave a truncated temporary file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic publish: the final name only ever refers to a complete download.
    os.replace(tmp_path, filename)
    print(f"模型文件 '{file_url}' 下载成功。")
    return filename
def get_hparams_from_url(config_url):
    """Fetch a JSON configuration over HTTP and wrap it in HParams.

    Raises ``requests.HTTPError`` for a non-success status instead of parsing
    an error page (or an error JSON body) as if it were a model config.
    """
    response = requests.get(config_url)
    response.raise_for_status()
    config = response.json()
    return HParams(**config)
def check_git_hash(model_dir):
    """Record or compare the git commit of this source tree for a model run.

    ``source_dir`` is the directory containing this file. If it has no ``.git``
    entry, a warning is logged and nothing else happens. Otherwise the current
    ``HEAD`` commit of ``source_dir`` is read. On the first call for
    ``model_dir`` it is written to ``<model_dir>/githash``. On later calls a
    warning is logged when it differs from the saved value. If git cannot be
    run or fails, a warning is logged and nothing is written.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    # Ask git about source_dir explicitly (``git -C``), not the process's
    # current working directory, which may be another repository or none.
    try:
        result = subprocess.run(
            ["git", "-C", source_dir, "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
        )
    except OSError as err:  # e.g. the git executable is not installed
        logger.warning("could not run git to read the commit hash: %s", err)
        return
    if result.returncode != 0:
        # Never save or compare git's error message as if it were a hash.
        logger.warning("git rev-parse HEAD failed: %s", result.stderr.strip())
        return
    cur_hash = result.stdout.strip()

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read().strip()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
def get_logger(model_dir, filename="train.log"):
    """Return a DEBUG-level logger that writes to ``<model_dir>/<filename>``.

    The logger is named after the last path component of ``model_dir`` and is
    also stored in this module's global ``logger``, so the helpers in this
    module log through it afterwards. ``model_dir`` is created if needed.
    Calling this again for the same log file reuses the existing file handler
    instead of attaching a duplicate, which would write every record twice.
    """
    global logger
    # normpath drops a trailing separator; otherwise basename("logs/run/") is
    # "" and getLogger("") would return (and reconfigure) the root logger.
    logger = logging.getLogger(os.path.basename(os.path.normpath(model_dir)))
    logger.setLevel(logging.DEBUG)
    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))
    for existing in logger.handlers:
        # FileHandler stores the absolute path of its file in baseFilename.
        if (
            isinstance(existing, logging.FileHandler)
            and existing.baseFilename == log_path
        ):
            return logger
    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    h = logging.FileHandler(log_path)
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger
class HParams:
    """Attribute-style container for (nested) hyper-parameter mappings.

    Nested dicts are wrapped recursively, so ``hps.train.batch_size`` and
    ``hps["train"]["batch_size"]`` both work. The mapping helpers (keys, items,
    values, len, ``in``) operate on the instance ``__dict__``.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance rather than ``type(v) == dict`` so dict subclasses
            # (e.g. OrderedDict) are wrapped as well.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()