Spaces:
Paused
Paused
# Version tag of this backend; announced on stdout when the module is loaded.
backend_version = "2.2.3 240316"
print("Backend version: " + backend_version)
# Make the launch directory and its GPT_SoVITS subfolder importable.
# The entries are appended, so they come after any paths already on sys.path.
import os
import sys

now_dir = os.getcwd()
sys.path.extend([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
| import soundfile as sf | |
| from flask import Flask, request, Response, jsonify, stream_with_context,send_file | |
| from flask_httpauth import HTTPBasicAuth | |
| from flask_cors import CORS | |
| import io | |
| import urllib.parse | |
| import tempfile | |
| import hashlib, json | |
# Put the directory containing this file on the import path so that modules
# sitting next to it can be imported regardless of the launch directory.
sys.path.append(os.path.split(os.path.abspath(__file__))[0])
# Read settings from config.json, located one directory above this file.
# abspath() makes the lookup independent of how __file__ was spelled
# (a bare relative filename would otherwise resolve against the CWD).
config_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.json"
)

# Bind every setting up front so the rest of the module works even when
# config.json is missing or omits a key.
tts_port = 5000
default_batch_size = 1
default_word_count = 50
enable_auth = False
is_classic = False
USERS = {}

if os.path.exists(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        _config = json.load(f)
    tts_port = _config.get("tts_port", tts_port)
    default_batch_size = _config.get("batch_size", default_batch_size)
    default_word_count = _config.get("max_word_count", default_word_count)
    # str() first so both JSON booleans (true) and strings ("true") are accepted;
    # calling .lower() on a real boolean would raise AttributeError.
    enable_auth = str(_config.get("enable_auth", "false")).lower() == "true"
    is_classic = str(_config.get("classic_inference", "false")).lower() == "true"
    if enable_auth:
        print("启用了身份验证")
        # Username -> password mapping from the "user" section of the config.
        USERS = _config.get("user", {})
# Probe for the newer TTS_infer_pack pipeline; if it cannot be imported,
# force the classic inference backend regardless of the config setting.
try:
    from TTS_infer_pack.TTS import TTS  # noqa: F401  (availability probe)
except ImportError:
    is_classic = True

# Both backends provide the same five names, so the code below is backend-agnostic.
if is_classic:
    from classic_inference.classic_load_infer_info import (
        character_name,
        get_wav_from_text_api,
        load_character,
        models_path,
        update_character_info,
    )
else:
    from load_infer_info import (
        character_name,
        get_wav_from_text_api,
        load_character,
        models_path,
        update_character_info,
    )
app = Flask(__name__)
# Every route answers cross-origin requests from any origin.
CORS(app, resources={r"/*": dict(origins="*")})

# Audio files written for save_temp requests: request hash -> temp file path.
temp_files = dict()
def generate_file_hash(*args):
    """Return a hex MD5 digest identifying the given sequence of arguments.

    Used as a cache key to recognise repeated requests. Each argument is
    hashed as ``str(arg)`` encoded to UTF-8 and preceded by its byte length,
    so argument boundaries are part of the key: ``(1, 10.5)`` and
    ``(11, 0.5)`` produce different digests. MD5 serves only as a fast
    fingerprint here, not as a security primitive.
    """
    hash_object = hashlib.md5()
    for arg in args:
        encoded = str(arg).encode("utf-8")
        # Length prefix: without it, plain concatenation makes distinct
        # argument tuples hash identically.
        hash_object.update(len(encoded).to_bytes(8, "big"))
        hash_object.update(encoded)
    return hash_object.hexdigest()
# HTTP Basic authentication helper for the Flask app.
# CORS is configured once, directly after `app` is created; registering it a
# second time here only added redundant handlers.
auth = HTTPBasicAuth()
def verify_password(username, password):
    """Check HTTP Basic credentials.

    Returns True for every request when ``enable_auth`` is false. Otherwise
    returns True only if ``USERS`` maps *username* to exactly *password*.

    NOTE(review): no ``@auth.verify_password`` decorator is visible on this
    function in this copy of the file; confirm it is registered.
    """
    import hmac  # local import keeps this block self-contained

    if not enable_auth:
        return True  # authentication disabled: allow every request
    expected = USERS.get(username)
    if not isinstance(expected, str) or not isinstance(password, str):
        return False
    # Constant-time comparison, so response timing does not reveal how much
    # of the supplied password matched. Compare bytes because compare_digest
    # only accepts ASCII-only str arguments.
    return hmac.compare_digest(expected.encode("utf-8"), password.encode("utf-8"))
def character_list():
    """Serialise the ``'characters_and_emotions'`` entry of ``update_character_info()`` as JSON.

    NOTE(review): no route/auth decorator is visible on this function in this
    copy of the file; confirm it is registered with ``app``.
    """
    info = update_character_info()
    return jsonify(info['characters_and_emotions'])
def tts():
    """Synthesize speech for the current request and return it as audio.

    Parameters come from the JSON body when the request is JSON, otherwise
    from the query string: ``text``, ``cha_name``, ``text_language``,
    ``batch_size``, ``speed``, ``top_k``, ``top_p``, ``temperature``,
    ``seed``, ``stream``, ``save_temp``, ``cut_method``,
    ``character_emotion`` and ``format`` (one of ``wav``/``mp3``/``ogg``).

    Returns:
        - ``(json_error, 400)`` for an invalid or non-existent character
          directory, non-numeric numeric parameters, or an unsupported format;
        - a streamed ``Response`` when ``stream`` is truthy;
        - otherwise the complete audio file. When ``save_temp`` is truthy,
          the file is written to disk and reused for identical requests.

    NOTE(review): no route/auth decorator is visible on this function in this
    copy of the file; confirm it is registered with ``app``.
    """
    global character_name

    # Accept both JSON bodies and plain query-string parameters.
    data = request.json if request.is_json else request.args
    text = urllib.parse.unquote(data.get('text', ''))

    # --- character selection ---------------------------------------------
    cha_name = data.get('cha_name', None)
    expected_path = os.path.join(models_path, cha_name) if cha_name else None
    if expected_path and not _path_is_inside(models_path, expected_path):
        # cha_name is client-controlled: refuse "..", absolute paths, etc.
        # that would point load_character outside models_path.
        return jsonify({"error": "Invalid character name."}), 400
    if cha_name and cha_name != character_name and os.path.exists(expected_path):
        print(f"Loading character {cha_name}")
        load_character(cha_name)
        # Switch only after a successful load, so a failing load does not
        # leave this module pointing at a character that was never loaded.
        character_name = cha_name
    elif expected_path and not os.path.exists(expected_path):
        return jsonify({"error": f"Directory {expected_path} does not exist. Using the current character."}), 400

    # --- request parameters ------------------------------------------------
    text_language = str(data.get('text_language', '多语种混合')).lower()
    try:
        batch_size = int(data.get('batch_size', default_batch_size))
        speed_factor = float(data.get('speed', 1.0))
        top_k = int(data.get('top_k', 6))
        top_p = float(data.get('top_p', 0.8))
        temperature = float(data.get('temperature', 0.8))
        seed = int(data.get('seed', -1))
    except (TypeError, ValueError):
        # TypeError covers JSON values such as null or lists.
        return jsonify({"error": "Invalid parameters. They must be numbers."}), 400

    truthy = ('true', '1', 't', 'y', 'yes')
    stream = str(data.get('stream', 'False')).lower() in truthy
    save_temp = str(data.get('save_temp', 'False')).lower() in truthy
    cut_method = str(data.get('cut_method', 'auto_cut')).lower()
    character_emotion = data.get('character_emotion', 'default')
    if cut_method == "auto_cut":
        # Expand the bare "auto_cut" with the configured max_word_count.
        cut_method = f"auto_cut_{default_word_count}"

    params = {
        "text": text,
        "text_language": text_language,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "character_emotion": character_emotion,
        "cut_method": cut_method,
        "stream": stream,
    }
    if not is_classic:
        # The extra arguments are passed only in non-classic mode.
        params["batch_size"] = batch_size
        params["speed_factor"] = speed_factor
        params["seed"] = seed

    audio_format = data.get('format', 'wav')
    if audio_format not in ('wav', 'mp3', 'ogg'):
        return jsonify({"error": "Invalid format. It must be one of 'wav', 'mp3', or 'ogg'."}), 400
    mimetype = f'audio/{audio_format}'

    # --- synthesis -----------------------------------------------------------
    if stream:
        # NOTE(review): the streamed mimetype is always audio/wav regardless of
        # `format`; presumably the generator yields WAV bytes - confirm.
        return Response(stream_with_context(get_wav_from_text_api(**params)), mimetype='audio/wav')

    if not save_temp:
        sampling_rate, audio_data = _synthesize_once(params)
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sampling_rate, format=audio_format)
        return Response(buffer.getvalue(), mimetype=mimetype)

    # The on-disk cache key includes everything that affects the produced file
    # (cut method, speed, batch size and output format). This way, for example,
    # a cached wav is never served for an mp3 request.
    request_hash = generate_file_hash(
        text, text_language, top_k, top_p, temperature, character_emotion,
        character_name, seed, cut_method, batch_size, speed_factor, audio_format,
    )
    cached_path = temp_files.get(request_hash)
    if cached_path and os.path.exists(cached_path):
        return send_file(cached_path, mimetype=mimetype)

    sampling_rate, audio_data = _synthesize_once(params)
    # mkstemp creates the file atomically; tempfile.mktemp is race-prone.
    fd, temp_file_path = tempfile.mkstemp(suffix=f'.{audio_format}')
    try:
        with os.fdopen(fd, 'wb') as temp_file:
            sf.write(temp_file, audio_data, sampling_rate, format=audio_format)
    except Exception:
        os.remove(temp_file_path)  # do not leave a half-written file behind
        raise
    # NOTE(review): nothing in this file deletes cached temp files, so disk
    # usage grows with the number of distinct cached requests.
    temp_files[request_hash] = temp_file_path
    return send_file(temp_file_path, mimetype=mimetype)


def _path_is_inside(base_dir, candidate):
    """Return True if *candidate* lies strictly below *base_dir* after normalisation.

    ``os.path.abspath`` collapses ``..`` segments. Symlinks are deliberately
    not resolved, so symlinked model folders inside *base_dir* keep working.
    """
    base_dir = os.path.abspath(base_dir)
    candidate = os.path.abspath(candidate)
    if candidate == base_dir:
        return False
    try:
        return os.path.commonpath([base_dir, candidate]) == base_dir
    except ValueError:
        # e.g. the two paths are on different drives on Windows.
        return False


def _synthesize_once(params):
    """Run a non-streaming synthesis and return its first ``(sampling_rate, audio_data)`` item."""
    return next(get_wav_from_text_api(**params))
if __name__ == '__main__':
    # Serve on all network interfaces (0.0.0.0) at the configured tts_port.
    app.run(port=tts_port, host='0.0.0.0')