Spaces:
Running
Running
backend_version = "2.2.3 240316" | |
print(f"Backend version: {backend_version}") | |
# 在开头加入路径 | |
import os, sys | |
now_dir = os.getcwd() | |
sys.path.append(now_dir) | |
sys.path.append(os.path.join(now_dir, "GPT_SoVITS")) | |
import soundfile as sf | |
from flask import Flask, request, Response, jsonify, stream_with_context,send_file | |
from flask_httpauth import HTTPBasicAuth | |
from flask_cors import CORS | |
import io | |
import urllib.parse | |
import tempfile | |
import hashlib, json | |
# 将当前文件所在的目录添加到 sys.path | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
# 从配置文件读取配置 | |
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json") | |
enable_auth = False | |
USERS = {} | |
if os.path.exists(config_path): | |
with open(config_path, 'r', encoding='utf-8') as f: | |
_config = json.load(f) | |
tts_port = _config.get("tts_port", 5000) | |
default_batch_size = _config.get("batch_size", 1) | |
default_word_count = _config.get("max_word_count", 50) | |
enable_auth = _config.get("enable_auth", "false").lower() == "true" | |
is_classic = _config.get("classic_inference", "false").lower() == "true" | |
if enable_auth: | |
print("启用了身份验证") | |
USERS = _config.get("user", {}) | |
try: | |
from TTS_infer_pack.TTS import TTS | |
except ImportError: | |
is_classic = True | |
if not is_classic: | |
from load_infer_info import load_character, character_name, get_wav_from_text_api, models_path, update_character_info | |
else: | |
from classic_inference.classic_load_infer_info import load_character, character_name, get_wav_from_text_api, models_path, update_character_info | |
app = Flask(__name__) | |
CORS(app, resources={r"/*": {"origins": "*"}}) | |
# 存储临时文件的字典 | |
temp_files = {} | |
# 用于防止重复请求 | |
def generate_file_hash(*args): | |
"""生成基于输入参数的哈希值,用于唯一标识一个请求""" | |
hash_object = hashlib.md5() | |
for arg in args: | |
hash_object.update(str(arg).encode()) | |
return hash_object.hexdigest() | |
auth = HTTPBasicAuth() | |
CORS(app, resources={r"/*": {"origins": "*"}}) | |
def verify_password(username, password): | |
if not enable_auth: | |
return True # 如果没有启用验证,则允许访问 | |
return USERS.get(username) == password | |
def character_list(): | |
res = jsonify(update_character_info()['characters_and_emotions']) | |
return res | |
def tts(): | |
global character_name | |
global models_path | |
# 尝试从JSON中获取数据,如果不是JSON,则从查询参数中获取 | |
if request.is_json: | |
data = request.json | |
else: | |
data = request.args | |
text = urllib.parse.unquote(data.get('text', '')) | |
cha_name = data.get('cha_name', None) | |
expected_path = os.path.join(models_path, cha_name) if cha_name else None | |
# 检查cha_name和路径 | |
if cha_name and cha_name != character_name and expected_path and os.path.exists(expected_path): | |
character_name = cha_name | |
print(f"Loading character {character_name}") | |
load_character(character_name) | |
elif expected_path and not os.path.exists(expected_path): | |
return jsonify({"error": f"Directory {expected_path} does not exist. Using the current character."}), 400 | |
text_language = str(data.get('text_language', '多语种混合')).lower() | |
try: | |
batch_size = int(data.get('batch_size', default_batch_size)) | |
speed_factor = float(data.get('speed', 1.0)) | |
top_k = int(data.get('top_k', 6)) | |
top_p = float(data.get('top_p', 0.8)) | |
temperature = float(data.get('temperature', 0.8)) | |
seed = int(data.get('seed', -1)) | |
except ValueError: | |
return jsonify({"error": "Invalid parameters. They must be numbers."}), 400 | |
stream = str(data.get('stream', 'False')).lower() in ('true', '1', 't', 'y', 'yes') | |
save_temp = str(data.get('save_temp', 'False')).lower() in ('true', '1', 't', 'y', 'yes') | |
cut_method = str(data.get('cut_method', 'auto_cut')).lower() | |
character_emotion = data.get('character_emotion', 'default') | |
if cut_method == "auto_cut": | |
cut_method = f"auto_cut_{default_word_count}" | |
params = { | |
"text": text, | |
"text_language": text_language, | |
"top_k": top_k, | |
"top_p": top_p, | |
"temperature": temperature, | |
"character_emotion": character_emotion, | |
"cut_method": cut_method, | |
"stream": stream | |
} | |
# 如果不是经典模式,则添加额外的参数 | |
if not is_classic: | |
params["batch_size"] = batch_size | |
params["speed_factor"] = speed_factor | |
params["seed"] = seed | |
request_hash = generate_file_hash(text, text_language, top_k, top_p, temperature, character_emotion, character_name, seed) | |
format = data.get('format', 'wav') | |
if not format in ['wav', 'mp3', 'ogg']: | |
return jsonify({"error": "Invalid format. It must be one of 'wav', 'mp3', or 'ogg'."}), 400 | |
if stream == False: | |
if save_temp: | |
if request_hash in temp_files: | |
return send_file(temp_files[request_hash], mimetype=f'audio/{format}') | |
else: | |
gen = get_wav_from_text_api(**params) | |
sampling_rate, audio_data = next(gen) | |
temp_file_path = tempfile.mktemp(suffix=f'.{format}') | |
with open(temp_file_path, 'wb') as temp_file: | |
sf.write(temp_file, audio_data, sampling_rate, format=format) | |
temp_files[request_hash] = temp_file_path | |
return send_file(temp_file_path, mimetype=f'audio/{format}') | |
else: | |
gen = get_wav_from_text_api(**params) | |
sampling_rate, audio_data = next(gen) | |
wav = io.BytesIO() | |
sf.write(wav, audio_data, sampling_rate, format=format) | |
wav.seek(0) | |
return Response(wav, mimetype=f'audio/{format}') | |
else: | |
gen = get_wav_from_text_api(**params) | |
return Response(stream_with_context(gen), mimetype='audio/wav') | |
if __name__ == '__main__': | |
app.run( host='0.0.0.0', port=tts_port) | |