Spaces:
Sleeping
Sleeping
import os | |
import json | |
import gradio as gr | |
import requests | |
from dotenv import load_dotenv | |
import logging | |
# Module-level logger; its records propagate to the root logger set up below.
logger = logging.getLogger(name=__name__)
# Route INFO-and-above records to a default root handler (stderr).
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level="INFO")
# Load environment variables (python-dotenv also reads a local .env file, if present).
load_dotenv()

# API configuration.
# Values are stripped because secrets pasted into a settings UI often carry
# stray whitespace or a trailing newline. `requests` refuses header values
# containing CR/LF, which would break the Authorization header.
BASE_URL = os.getenv("API_URL", "").strip().rstrip("/")  # drop trailing slash(es)
API_KEY = os.getenv("API_KEY", "").strip()

# Fail fast on missing settings, before deriving anything from them.
if not BASE_URL or not API_KEY:
    raise ValueError("""
请确保设置了必要的环境变量:
- API_URL: API基础地址 (例如: https://api.example.com)
- API_KEY: API密钥
可以在Hugging Face Space的Settings -> Repository Secrets中设置这些变量
""")

# Full chat-completions endpoint derived from the validated base URL.
API_URL = f"{BASE_URL}/v1/chat/completions"
# NOTE: ChatBot has a single definition, located just after format_message().
# A duplicate `class ChatBot` here was unreachable: the later class statement
# rebinds the name before any code instantiates it.
def format_message(role, content):
    """Build one chat-completions message entry.

    Args:
        role: Speaker role, e.g. "user" or "assistant".
        content: Message text.

    Returns:
        A dict with exactly the keys "role" and "content", in that order.
    """
    entry = dict(role=role)
    entry["content"] = content
    return entry
class ChatBot:
    """Holds the HTTP headers used for chat-completions requests.

    Construction performs no network I/O unless ``verify=True`` is passed.
    In that case ``verify_api_config`` probes ``API_URL`` once.
    """

    def __init__(self, verify=False):
        """Build the request headers and optionally probe the endpoint.

        Args:
            verify: When true, call ``verify_api_config`` right after building
                the headers. Defaults to False so that creating an instance
                (done once per chat message) stays request-free.
        """
        self.headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            # Ask the server for a Server-Sent Events (streaming) response.
            "Accept": "text/event-stream",
        }
        if verify:
            self.verify_api_config()

    @staticmethod
    def _redact_headers(headers):
        """Return a copy of *headers* with the Authorization value masked."""
        return {
            name: "***" if name.lower() == "authorization" else value
            for name, value in headers.items()
        }

    def verify_api_config(self, timeout=5):
        """Probe ``API_URL`` once and log the outcome (best effort, never raises).

        Sends an OPTIONS request carrying this instance's headers, so an
        auth-protected endpoint is not misreported as misconfigured. A status of
        400 or above, or a failed request, is logged as an error.

        Args:
            timeout: Seconds to wait for the probe (default 5).
        """
        logger.info("API endpoint: %s", API_URL)
        # Mask the bearer token so the API key never reaches the logs.
        logger.info("API headers: %s", self._redact_headers(self.headers))
        try:
            response = requests.options(API_URL, headers=self.headers, timeout=timeout)
        except Exception as e:
            logger.error("API连接测试失败: %s", e)
            return
        if response.status_code >= 400:
            logger.error("API配置可能有误: %s", response.status_code)
            logger.error("API响应: %s", response.text[:200])
# (connect, read) timeouts in seconds for the completion request. With
# stream=True the read timeout bounds the wait between received bytes, not
# the whole reply. Without any timeout `requests` may wait indefinitely.
_CHAT_REQUEST_TIMEOUT = (10, 300)

# Returned by _parse_stream_line for the terminating "[DONE]" event.
_STREAM_DONE = object()

# SSE lines that never carry completion data (comments / non-data fields).
_SSE_IGNORED_PREFIXES = (":", "event:", "id:", "retry:")


def _redact_headers(headers):
    """Return a copy of *headers* with the Authorization value masked for logging."""
    return {
        name: "***" if name.lower() == "authorization" else value
        for name, value in headers.items()
    }


def _build_messages(message, history):
    """Turn tuple-style chat *history* plus the new *message* into API messages.

    ``None`` halves of a history pair (e.g. a turn without a reply) are
    skipped instead of being sent as null content.
    """
    messages = []
    for human, assistant in history:
        if human is not None:
            messages.append(format_message("user", human))
        if assistant is not None:
            messages.append(format_message("assistant", assistant))
    messages.append(format_message("user", message))
    return messages


def _response_error(response):
    """Return a user-facing error text if *response* cannot be streamed, else None."""
    content_type = response.headers.get("content-type", "").lower()
    if content_type.startswith("text/html"):
        return "API返回了HTML而不是预期的流式响应。请检查API配置。"
    if response.status_code != 200:
        return f"API返回错误状态码: {response.status_code}\n错误信息: {response.text}"
    return None


def _parse_stream_line(line):
    """Extract the text fragment carried by one decoded line of the response stream.

    Args:
        line: One non-empty line of the response body, as ``str``.

    Returns:
        ``_STREAM_DONE`` for the ``[DONE]`` terminator. Otherwise the
        ``choices[0].delta.content`` text, or "" when the line carries none
        (SSE comments/fields, empty ``choices``, null or missing content).

    Raises:
        json.JSONDecodeError: if the payload is not valid JSON.
    """
    if line.startswith(_SSE_IGNORED_PREFIXES):
        return ""
    if line.startswith("data:"):
        line = line[len("data:"):].strip()
    if line == "[DONE]":
        return _STREAM_DONE
    chunk = json.loads(line)
    if not isinstance(chunk, dict):
        return ""
    choices = chunk.get("choices") or []
    if not choices:
        return ""
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


def chat_stream(message, history):
    """Stream an assistant reply to *message* into the Gradio chat history.

    This is a generator used as a Gradio event handler. Each yielded value is
    the complete tuple-style history (list of ``(user, assistant)`` pairs) to
    display. Errors are yielded as the assistant's reply. A generator's
    ``return value`` is never shown by the UI.

    Args:
        message: The new user message. Blank input is not sent to the API.
        history: Previous ``(user, assistant)`` pairs; ``None`` is treated as
            empty. The caller's list is never mutated.

    Yields:
        The history with the current exchange appended. It is yielded
        repeatedly with the growing partial reply while streaming, then once
        more with the final reply or an error message.
    """
    history = list(history or [])
    if not message or not message.strip():
        # Nothing to send; re-emit the unchanged history.
        yield history
        return

    chatbot = ChatBot()
    logger.info("Sending message: %s", message)
    logger.info("API URL: %s", API_URL)
    # Never log the raw Authorization header: it contains the API key.
    logger.info("Headers: %s", _redact_headers(chatbot.headers))

    payload = {
        "model": "gpt-4o",
        "messages": _build_messages(message, history),
        "stream": True,
        "temperature": 0.7,
    }
    # The payload holds the whole conversation; keep it out of INFO logs.
    logger.debug("发送请求数据: %s", json.dumps(payload, ensure_ascii=False))

    partial_message = ""
    try:
        # `with` releases the streamed connection on every exit path,
        # including the early error return and the generator being closed.
        with requests.post(
            API_URL,
            headers=chatbot.headers,
            json=payload,
            stream=True,
            timeout=_CHAT_REQUEST_TIMEOUT,
        ) as response:
            error_msg = _response_error(response)
            if error_msg is not None:
                logger.error(error_msg)
                yield history + [(message, error_msg)]
                return

            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate SSE events
                try:
                    line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
                    logger.debug("收到数据: %s", line)
                    content = _parse_stream_line(line)
                except json.JSONDecodeError as e:
                    logger.error("JSON解析错误: %s", e)
                    continue
                except Exception as e:
                    logger.error("处理响应时出错: %s", e)
                    continue
                if content is _STREAM_DONE:
                    break
                if content:
                    partial_message += content
                    # Yield a fresh list so a yielded value is never mutated afterwards.
                    yield history + [(message, partial_message)]
    except Exception as e:
        error_msg = f"请求发生错误: {str(e)}"
        logger.error(error_msg)
        # Keep any text already streamed so the user does not lose it.
        reply = f"{partial_message}\n\n{error_msg}" if partial_message else error_msg
        yield history + [(message, reply)]
        return

    if partial_message:
        yield history + [(message, partial_message)]
    else:
        error_msg = "未能获取到有效的响应内容"
        logger.error(error_msg)
        yield history + [(message, error_msg)]
# Gradio UI definition: components render in the order they are created.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Conversation display.
    chatbot = gr.Chatbot(
        avatar_images=["assets/user.png", "assets/assistant.png"],
        show_copy_button=True,
        height=600,
    )
    # User input box.
    msg = gr.Textbox(
        scale=7,
        container=False,
        placeholder="在这里输入您的问题...",
    )
    with gr.Row():
        submit = gr.Button("发送", variant="primary", scale=2)
        clear = gr.Button("清除对话", scale=1)

    # Pressing Enter or clicking the send button streams a reply into the
    # chatbot, then empties the input box. Only the Enter path sets explicit
    # api_name values.
    msg.submit(
        fn=chat_stream,
        inputs=[msg, chatbot],
        outputs=[chatbot],
        api_name="chat",
    ).then(
        fn=lambda: "",
        inputs=None,
        outputs=[msg],
        api_name="clear_input",
    )
    submit.click(
        fn=chat_stream,
        inputs=[msg, chatbot],
        outputs=[chatbot],
    ).then(
        fn=lambda: "",
        inputs=None,
        outputs=[msg],
    )
    # Reset the conversation to an empty history.
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)
# Script entry point: start the web app.
if __name__ == "__main__":
    # Enable request queuing, then launch the server with debug output.
    app = demo.queue()
    app.launch(debug=True)