File size: 2,887 Bytes
5dd43d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import json
import gradio as gr
import requests
from sseclient import SSEClient
from dotenv import load_dotenv

# Load environment variables via python-dotenv before reading the settings below.
load_dotenv()

# API configuration: both values are read from the process environment and
# fall back to placeholder strings when the variable is not set.
API_URL = os.environ.get("API_URL", "https://your-api-endpoint/v1/chat/completions")
API_KEY = os.environ.get("API_KEY", "your-api-key")

def format_message(role, content):
    """Build one chat message mapping with ``role`` and ``content`` keys.

    Args:
        role: Value stored under the ``"role"`` key (this file passes
            ``"user"`` and ``"assistant"``).
        content: Value stored under the ``"content"`` key.

    Returns:
        A new ``dict`` of the form ``{"role": role, "content": content}``.
    """
    return dict(role=role, content=content)

class ChatBot:
    """Minimal client that POSTs streaming chat-completion requests to ``API_URL``.

    Every request carries a bearer ``Authorization`` header built from the
    module-level ``API_KEY`` and a JSON content type.
    """

    def __init__(self):
        # Headers are built once and reused for every request.
        self.headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }

    def generate_stream(
        self,
        messages,
        *,
        model="gpt-3.5-turbo",
        temperature=0.7,
        timeout=None,
    ):
        """Start a streaming completion and return an SSE client over the response.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` mappings sent
                as the ``messages`` field of the request body.
            model: Model identifier sent in the request body (or any other
                model the endpoint accepts).
            temperature: Sampling temperature sent in the request body.
            timeout: Forwarded to ``requests.post``; ``None`` (the default)
                keeps the previous behaviour of waiting indefinitely.

        Returns:
            An ``SSEClient`` wrapping the streamed HTTP response.

        Raises:
            requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        """
        payload = {
            "model": model,
            "messages": messages,
            "stream": True,
            "temperature": temperature,
        }

        response = requests.post(
            API_URL,
            headers=self.headers,
            json=payload,
            stream=True,
            timeout=timeout,
        )

        # Without this check an error body (e.g. 401/429 JSON) would be handed
        # to the SSE parser and the caller would just see an empty stream.
        try:
            response.raise_for_status()
        except requests.HTTPError:
            # stream=True keeps the connection open until it is released.
            response.close()
            raise

        return SSEClient(response)

def _delta_text(raw):
    """Return the text fragment carried by one SSE ``data`` payload, or ``None``.

    ``None`` is returned for anything that is not a JSON object with a
    non-empty ``choices`` list whose first entry has a string
    ``delta.content``.
    """
    try:
        chunk = json.loads(raw)
    except json.JSONDecodeError:
        return None
    if not isinstance(chunk, dict):
        return None

    choices = chunk.get("choices")
    if not choices:
        # Some chunks carry an empty ``choices`` list; indexing [0] would raise.
        return None

    first = choices[0]
    delta = first.get("delta") if isinstance(first, dict) else None
    if not isinstance(delta, dict):
        return None

    text = delta.get("content")
    # ``content`` may be absent or null; only real strings are appended.
    return text if isinstance(text, str) else None


def chat_stream(message, history):
    """Stream the assistant's reply to ``message`` as a growing string.

    Args:
        message: The new user message.
        history: Iterable of ``(user, assistant)`` pairs from earlier turns;
            ``None`` or empty means no prior turns. ``None`` members of a
            pair are skipped instead of being sent to the API.

    Yields:
        The reply accumulated so far, once per received content fragment.

    Returns:
        The complete reply (only observable via ``yield from`` /
        ``StopIteration.value``).
    """
    bot = ChatBot()

    # Rebuild the conversation in the request's message format.
    messages = []
    for human, assistant in history or []:
        if human is not None:
            messages.append(format_message("user", human))
        if assistant is not None:
            messages.append(format_message("assistant", assistant))
    messages.append(format_message("user", message))

    response_stream = bot.generate_stream(messages)

    # NOTE(review): the import alone does not tell which ``sseclient``
    # distribution is installed. The one that accepts a response object
    # (sseclient-py) exposes events through ``.events()``; fall back to
    # iterating the object itself when that method is absent.
    events_fn = getattr(response_stream, "events", None)
    events = events_fn() if callable(events_fn) else response_stream

    partial_message = ""
    try:
        for event in events:
            if event.data == "[DONE]":
                # End-of-stream sentinel: nothing useful follows it.
                break
            text = _delta_text(event.data)
            if text is None:
                continue
            partial_message += text
            yield partial_message
    finally:
        # Release the HTTP connection even when the consumer stops early.
        closer = getattr(response_stream, "close", None)
        if callable(closer):
            closer()

    return partial_message

# Gradio UI configuration
def _stream_into_history(message, history):
    """Adapt ``chat_stream`` to the value format a ``gr.Chatbot`` output needs.

    ``chat_stream`` yields the accumulated reply as a plain string, but the
    handlers below write to a ``gr.Chatbot``, whose value is the whole
    conversation. Each partial reply is therefore re-emitted as the previous
    turns plus one ``[message, partial]`` pair.

    Assumes the pair-based history format, which is also what
    ``chat_stream`` unpacks from ``history``.
    """
    previous = list(history or [])
    for partial in chat_stream(message, previous):
        yield previous + [[message, partial]]


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        avatar_images=["assets/user.png", "assets/assistant.png"],
    )
    msg = gr.Textbox(
        placeholder="在这里输入您的问题...",
        container=False,
        scale=7,
    )
    with gr.Row():
        submit = gr.Button("发送", scale=2, variant="primary")
        clear = gr.Button("清除对话", scale=1)

    # Event wiring: pressing Enter in the textbox and clicking the send button
    # both stream the reply into the chat window, then empty the textbox.
    msg.submit(
        _stream_into_history,
        [msg, chatbot],
        [chatbot],
        api_name="chat",
    ).then(
        lambda: "",
        None,
        [msg],
        api_name="clear_input",
    )

    submit.click(
        _stream_into_history,
        [msg, chatbot],
        [chatbot],
    ).then(
        lambda: "",
        None,
        [msg],
    )

    # Reset the conversation shown in the chat window.
    clear.click(lambda: None, None, chatbot)

# Launch the app
if __name__ == "__main__":
    demo.queue().launch()