JohnSmith9982
committed on
Commit
·
1620ce5
1
Parent(s):
55db78d
Upload 3 files
Browse files
- app.py +19 -18
- presets.py +11 -2
- utils.py +65 -49
app.py
CHANGED
@@ -42,14 +42,6 @@ else:
|
|
42 |
gr.Chatbot.postprocess = postprocess
|
43 |
|
44 |
with gr.Blocks(css=customCSS) as demo:
|
45 |
-
gr.HTML(title)
|
46 |
-
gr.HTML('''<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="复制 Space"></a>强烈建议点击上面的按钮复制一份这个Space,在你自己的Space里运行,响应更迅速、也更安全👆</center>''')
|
47 |
-
with gr.Row():
|
48 |
-
with gr.Column(scale=4):
|
49 |
-
keyTxt = gr.Textbox(show_label=False, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY).style(container=True)
|
50 |
-
with gr.Column(scale=1):
|
51 |
-
use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
|
52 |
-
chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
|
53 |
history = gr.State([])
|
54 |
token_count = gr.State([])
|
55 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
@@ -57,6 +49,15 @@ with gr.Blocks(css=customCSS) as demo:
|
|
57 |
FALSECONSTANT = gr.State(False)
|
58 |
topic = gr.State("未命名对话历史记录")
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
with gr.Row():
|
61 |
with gr.Column(scale=12):
|
62 |
user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
|
@@ -69,8 +70,9 @@ with gr.Blocks(css=customCSS) as demo:
|
|
69 |
delLastBtn = gr.Button("🗑️ 删除最近一条对话")
|
70 |
reduceTokenBtn = gr.Button("♻️ 总结对话")
|
71 |
status_display = gr.Markdown("status: ready")
|
72 |
-
|
73 |
-
|
|
|
74 |
with gr.Accordion(label="加载Prompt模板", open=False):
|
75 |
with gr.Column():
|
76 |
with gr.Row():
|
@@ -101,28 +103,27 @@ with gr.Blocks(css=customCSS) as demo:
|
|
101 |
#inputs, top_p, temperature, top_k, repetition_penalty
|
102 |
with gr.Accordion("参数", open=False):
|
103 |
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
|
104 |
-
|
105 |
temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
|
106 |
step=0.1, interactive=True, label="Temperature",)
|
107 |
-
|
108 |
-
#repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
|
109 |
gr.Markdown(description)
|
110 |
|
111 |
|
112 |
-
user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
|
113 |
user_input.submit(reset_textbox, [], [user_input])
|
114 |
|
115 |
-
submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
|
116 |
submitBtn.click(reset_textbox, [], [user_input])
|
117 |
|
118 |
emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
|
119 |
|
120 |
-
retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
|
121 |
|
122 |
-
delLastBtn.click(delete_last_conversation, [chatbot, history, token_count
|
123 |
chatbot, history, token_count, status_display], show_progress=True)
|
124 |
|
125 |
-
reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
|
126 |
|
127 |
saveHistoryBtn.click(save_chat_history, [
|
128 |
saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
|
|
|
42 |
gr.Chatbot.postprocess = postprocess
|
43 |
|
44 |
with gr.Blocks(css=customCSS) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
history = gr.State([])
|
46 |
token_count = gr.State([])
|
47 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
|
|
49 |
FALSECONSTANT = gr.State(False)
|
50 |
topic = gr.State("未命名对话历史记录")
|
51 |
|
52 |
+
gr.HTML(title)
|
53 |
+
with gr.Row():
|
54 |
+
with gr.Column():
|
55 |
+
keyTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY, label="API-Key")
|
56 |
+
with gr.Column():
|
57 |
+
with gr.Row():
|
58 |
+
model_select_dropdown = gr.Dropdown(label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0])
|
59 |
+
use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
|
60 |
+
chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
|
61 |
with gr.Row():
|
62 |
with gr.Column(scale=12):
|
63 |
user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
|
|
|
70 |
delLastBtn = gr.Button("🗑️ 删除最近一条对话")
|
71 |
reduceTokenBtn = gr.Button("♻️ 总结对话")
|
72 |
status_display = gr.Markdown("status: ready")
|
73 |
+
|
74 |
+
systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...", label="System prompt", value=initial_prompt).style(container=True)
|
75 |
+
|
76 |
with gr.Accordion(label="加载Prompt模板", open=False):
|
77 |
with gr.Column():
|
78 |
with gr.Row():
|
|
|
103 |
#inputs, top_p, temperature, top_k, repetition_penalty
|
104 |
with gr.Accordion("参数", open=False):
|
105 |
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
|
106 |
+
interactive=True, label="Top-p (nucleus sampling)",)
|
107 |
temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
|
108 |
step=0.1, interactive=True, label="Temperature",)
|
109 |
+
|
|
|
110 |
gr.Markdown(description)
|
111 |
|
112 |
|
113 |
+
user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
|
114 |
user_input.submit(reset_textbox, [], [user_input])
|
115 |
|
116 |
+
submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
|
117 |
submitBtn.click(reset_textbox, [], [user_input])
|
118 |
|
119 |
emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
|
120 |
|
121 |
+
retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
|
122 |
|
123 |
+
delLastBtn.click(delete_last_conversation, [chatbot, history, token_count], [
|
124 |
chatbot, history, token_count, status_display], show_progress=True)
|
125 |
|
126 |
+
reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
|
127 |
|
128 |
saveHistoryBtn.click(save_chat_history, [
|
129 |
saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
|
presets.py
CHANGED
@@ -31,9 +31,18 @@ pre code {
|
|
31 |
}
|
32 |
"""
|
33 |
|
34 |
-
standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
|
35 |
-
error_retrieve_prompt = "连接超时,无法获取对话。请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
|
36 |
summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
max_token_streaming = 3500 # 流式对话时的最大 token 数
|
38 |
timeout_streaming = 15 # 流式对话时的超时时间
|
39 |
max_token_all = 3500 # 非流式对话时的最大 token 数
|
|
|
31 |
}
|
32 |
"""
|
33 |
|
|
|
|
|
34 |
summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
|
35 |
+
MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"] # 可选的模型
|
36 |
+
|
37 |
+
# 错误信息
|
38 |
+
standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
|
39 |
+
error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
|
40 |
+
connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
|
41 |
+
read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
|
42 |
+
proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
|
43 |
+
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
|
44 |
+
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
45 |
+
|
46 |
max_token_streaming = 3500 # 流式对话时的最大 token 数
|
47 |
timeout_streaming = 15 # 流式对话时的超时时间
|
48 |
max_token_all = 3500 # 非流式对话时的最大 token 数
|
utils.py
CHANGED
@@ -99,7 +99,7 @@ def construct_assistant(text):
|
|
99 |
def construct_token_message(token, stream=False):
|
100 |
return f"Token 计数: {token}"
|
101 |
|
102 |
-
def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
|
103 |
headers = {
|
104 |
"Content-Type": "application/json",
|
105 |
"Authorization": f"Bearer {openai_api_key}"
|
@@ -108,7 +108,7 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, str
|
|
108 |
history = [construct_system(system_prompt), *history]
|
109 |
|
110 |
payload = {
|
111 |
-
"model":
|
112 |
"messages": history, # [{"role": "user", "content": f"{inputs}"}],
|
113 |
"temperature": temperature, # 1.0,
|
114 |
"top_p": top_p, # 1.0,
|
@@ -124,40 +124,40 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, str
|
|
124 |
response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
|
125 |
return response
|
126 |
|
127 |
-
def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot,
|
128 |
def get_return_value():
|
129 |
-
return chatbot, history, status_text,
|
130 |
|
131 |
print("实时回答模式")
|
132 |
-
token_counter = 0
|
133 |
partial_words = ""
|
134 |
counter = 0
|
135 |
status_text = "开始实时传输回答……"
|
136 |
history.append(construct_user(inputs))
|
|
|
|
|
137 |
user_token_count = 0
|
138 |
-
if len(
|
139 |
system_prompt_token_count = count_token(system_prompt)
|
140 |
user_token_count = count_token(inputs) + system_prompt_token_count
|
141 |
else:
|
142 |
user_token_count = count_token(inputs)
|
|
|
143 |
print(f"输入token计数: {user_token_count}")
|
|
|
144 |
try:
|
145 |
-
response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
|
146 |
except requests.exceptions.ConnectTimeout:
|
147 |
-
|
148 |
-
status_text = standard_error_msg + "连接超时,无法获取对话。" + error_retrieve_prompt
|
149 |
yield get_return_value()
|
150 |
return
|
151 |
except requests.exceptions.ReadTimeout:
|
152 |
-
|
153 |
-
status_text = standard_error_msg + "读取超时,无法获取对话。" + error_retrieve_prompt
|
154 |
yield get_return_value()
|
155 |
return
|
156 |
|
157 |
-
chatbot.append((parse_text(inputs), ""))
|
158 |
yield get_return_value()
|
159 |
|
160 |
-
for chunk in response.iter_lines():
|
161 |
if counter == 0:
|
162 |
counter += 1
|
163 |
continue
|
@@ -169,77 +169,93 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
|
|
169 |
try:
|
170 |
chunk = json.loads(chunk[6:])
|
171 |
except json.JSONDecodeError:
|
|
|
172 |
status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
|
173 |
yield get_return_value()
|
174 |
-
|
175 |
# decode each line as response data is in bytes
|
176 |
if chunklength > 6 and "delta" in chunk['choices'][0]:
|
177 |
finish_reason = chunk['choices'][0]['finish_reason']
|
178 |
-
status_text = construct_token_message(sum(
|
179 |
if finish_reason == "stop":
|
180 |
-
print("生成完毕")
|
181 |
yield get_return_value()
|
182 |
break
|
183 |
try:
|
184 |
partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
|
185 |
except KeyError:
|
186 |
-
status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(
|
187 |
yield get_return_value()
|
188 |
break
|
189 |
-
|
190 |
-
history.append(construct_assistant(" " + partial_words))
|
191 |
-
else:
|
192 |
-
history[-1] = construct_assistant(partial_words)
|
193 |
chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
|
194 |
-
|
195 |
yield get_return_value()
|
196 |
|
197 |
|
198 |
-
def predict_all(openai_api_key, system_prompt, history, inputs, chatbot,
|
199 |
print("一次性回答模式")
|
200 |
history.append(construct_user(inputs))
|
|
|
|
|
|
|
201 |
try:
|
202 |
-
response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False)
|
203 |
except requests.exceptions.ConnectTimeout:
|
204 |
-
status_text = standard_error_msg + error_retrieve_prompt
|
205 |
-
return chatbot, history, status_text,
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
response = json.loads(response.text)
|
207 |
content = response["choices"][0]["message"]["content"]
|
208 |
-
history
|
209 |
chatbot.append((parse_text(inputs), parse_text(content)))
|
210 |
total_token_count = response["usage"]["total_tokens"]
|
211 |
-
|
212 |
status_text = construct_token_message(total_token_count)
|
213 |
-
|
214 |
-
return chatbot, history, status_text, previous_token_count
|
215 |
|
216 |
|
217 |
-
def predict(openai_api_key, system_prompt, history, inputs, chatbot,
|
218 |
print("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
if stream:
|
220 |
print("使用流式传输")
|
221 |
-
iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot,
|
222 |
-
for chatbot, history, status_text,
|
223 |
-
yield chatbot, history, status_text,
|
224 |
else:
|
225 |
print("不使用流式传输")
|
226 |
-
chatbot, history, status_text,
|
227 |
-
yield chatbot, history, status_text,
|
228 |
-
print(f"传输完毕。当前token计数为{
|
229 |
-
|
|
|
230 |
if stream:
|
231 |
max_token = max_token_streaming
|
232 |
else:
|
233 |
max_token = max_token_all
|
234 |
-
if sum(
|
235 |
-
print(f"精简token中{
|
236 |
-
iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot,
|
237 |
-
for chatbot, history, status_text,
|
238 |
status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
|
239 |
-
yield chatbot, history, status_text,
|
240 |
|
241 |
|
242 |
-
def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False):
|
243 |
print("重试中……")
|
244 |
if len(history) == 0:
|
245 |
yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
|
@@ -247,15 +263,15 @@ def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, t
|
|
247 |
history.pop()
|
248 |
inputs = history.pop()["content"]
|
249 |
token_count.pop()
|
250 |
-
iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream)
|
251 |
print("重试完毕")
|
252 |
for x in iter:
|
253 |
yield x
|
254 |
|
255 |
|
256 |
-
def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False):
|
257 |
print("开始减少token数量……")
|
258 |
-
iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, should_check_token_count=False)
|
259 |
for chatbot, history, status_text, previous_token_count in iter:
|
260 |
history = history[-2:]
|
261 |
token_count = previous_token_count[-1:]
|
@@ -265,7 +281,7 @@ def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_cou
|
|
265 |
print("减少token数量完毕")
|
266 |
|
267 |
|
268 |
-
def delete_last_conversation(chatbot, history, previous_token_count
|
269 |
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
270 |
print("由于包含报错信息,只删除chatbot记录")
|
271 |
chatbot.pop()
|
@@ -280,7 +296,7 @@ def delete_last_conversation(chatbot, history, previous_token_count, streaming):
|
|
280 |
if len(previous_token_count) > 0:
|
281 |
print("删除了一组对话的token计数记录")
|
282 |
previous_token_count.pop()
|
283 |
-
return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count)
|
284 |
|
285 |
|
286 |
def save_chat_history(filename, system, history, chatbot):
|
|
|
99 |
def construct_token_message(token, stream=False):
|
100 |
return f"Token 计数: {token}"
|
101 |
|
102 |
+
def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model):
|
103 |
headers = {
|
104 |
"Content-Type": "application/json",
|
105 |
"Authorization": f"Bearer {openai_api_key}"
|
|
|
108 |
history = [construct_system(system_prompt), *history]
|
109 |
|
110 |
payload = {
|
111 |
+
"model": selected_model,
|
112 |
"messages": history, # [{"role": "user", "content": f"{inputs}"}],
|
113 |
"temperature": temperature, # 1.0,
|
114 |
"top_p": top_p, # 1.0,
|
|
|
124 |
response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
|
125 |
return response
|
126 |
|
127 |
+
def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
|
128 |
def get_return_value():
|
129 |
+
return chatbot, history, status_text, all_token_counts
|
130 |
|
131 |
print("实时回答模式")
|
|
|
132 |
partial_words = ""
|
133 |
counter = 0
|
134 |
status_text = "开始实时传输回答……"
|
135 |
history.append(construct_user(inputs))
|
136 |
+
history.append(construct_assistant(""))
|
137 |
+
chatbot.append((parse_text(inputs), ""))
|
138 |
user_token_count = 0
|
139 |
+
if len(all_token_counts) == 0:
|
140 |
system_prompt_token_count = count_token(system_prompt)
|
141 |
user_token_count = count_token(inputs) + system_prompt_token_count
|
142 |
else:
|
143 |
user_token_count = count_token(inputs)
|
144 |
+
all_token_counts.append(user_token_count)
|
145 |
print(f"输入token计数: {user_token_count}")
|
146 |
+
yield get_return_value()
|
147 |
try:
|
148 |
+
response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True, selected_model)
|
149 |
except requests.exceptions.ConnectTimeout:
|
150 |
+
status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
|
|
151 |
yield get_return_value()
|
152 |
return
|
153 |
except requests.exceptions.ReadTimeout:
|
154 |
+
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
|
|
155 |
yield get_return_value()
|
156 |
return
|
157 |
|
|
|
158 |
yield get_return_value()
|
159 |
|
160 |
+
for chunk in tqdm(response.iter_lines()):
|
161 |
if counter == 0:
|
162 |
counter += 1
|
163 |
continue
|
|
|
169 |
try:
|
170 |
chunk = json.loads(chunk[6:])
|
171 |
except json.JSONDecodeError:
|
172 |
+
print(chunk)
|
173 |
status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
|
174 |
yield get_return_value()
|
175 |
+
continue
|
176 |
# decode each line as response data is in bytes
|
177 |
if chunklength > 6 and "delta" in chunk['choices'][0]:
|
178 |
finish_reason = chunk['choices'][0]['finish_reason']
|
179 |
+
status_text = construct_token_message(sum(all_token_counts), stream=True)
|
180 |
if finish_reason == "stop":
|
|
|
181 |
yield get_return_value()
|
182 |
break
|
183 |
try:
|
184 |
partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
|
185 |
except KeyError:
|
186 |
+
status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(all_token_counts))
|
187 |
yield get_return_value()
|
188 |
break
|
189 |
+
history[-1] = construct_assistant(partial_words)
|
|
|
|
|
|
|
190 |
chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
|
191 |
+
all_token_counts[-1] += 1
|
192 |
yield get_return_value()
|
193 |
|
194 |
|
195 |
+
def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
|
196 |
print("一次性回答模式")
|
197 |
history.append(construct_user(inputs))
|
198 |
+
history.append(construct_assistant(""))
|
199 |
+
chatbot.append((parse_text(inputs), ""))
|
200 |
+
all_token_counts.append(count_token(inputs))
|
201 |
try:
|
202 |
+
response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False, selected_model)
|
203 |
except requests.exceptions.ConnectTimeout:
|
204 |
+
status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
205 |
+
return chatbot, history, status_text, all_token_counts
|
206 |
+
except requests.exceptions.ProxyError:
|
207 |
+
status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
|
208 |
+
return chatbot, history, status_text, all_token_counts
|
209 |
+
except requests.exceptions.SSLError:
|
210 |
+
status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
|
211 |
+
return chatbot, history, status_text, all_token_counts
|
212 |
response = json.loads(response.text)
|
213 |
content = response["choices"][0]["message"]["content"]
|
214 |
+
history[-1] = construct_assistant(content)
|
215 |
chatbot.append((parse_text(inputs), parse_text(content)))
|
216 |
total_token_count = response["usage"]["total_tokens"]
|
217 |
+
all_token_counts[-1] = total_token_count - sum(all_token_counts)
|
218 |
status_text = construct_token_message(total_token_count)
|
219 |
+
return chatbot, history, status_text, all_token_counts
|
|
|
220 |
|
221 |
|
222 |
+
def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], should_check_token_count = True): # repetition_penalty, top_k
|
223 |
print("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
|
224 |
+
if len(openai_api_key) != 51:
|
225 |
+
status_text = standard_error_msg + no_apikey_msg
|
226 |
+
print(status_text)
|
227 |
+
history.append(construct_user(inputs))
|
228 |
+
history.append("")
|
229 |
+
chatbot.append((parse_text(inputs), ""))
|
230 |
+
all_token_counts.append(0)
|
231 |
+
yield chatbot, history, status_text, all_token_counts
|
232 |
+
return
|
233 |
+
yield chatbot, history, "开始生成回答……", all_token_counts
|
234 |
if stream:
|
235 |
print("使用流式传输")
|
236 |
+
iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
|
237 |
+
for chatbot, history, status_text, all_token_counts in iter:
|
238 |
+
yield chatbot, history, status_text, all_token_counts
|
239 |
else:
|
240 |
print("不使用流式传输")
|
241 |
+
chatbot, history, status_text, all_token_counts = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
|
242 |
+
yield chatbot, history, status_text, all_token_counts
|
243 |
+
print(f"传输完毕。当前token计数为{all_token_counts}")
|
244 |
+
if len(history) > 1 and history[-1]['content'] != inputs:
|
245 |
+
print("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
|
246 |
if stream:
|
247 |
max_token = max_token_streaming
|
248 |
else:
|
249 |
max_token = max_token_all
|
250 |
+
if sum(all_token_counts) > max_token and should_check_token_count:
|
251 |
+
print(f"精简token中{all_token_counts}/{max_token}")
|
252 |
+
iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, all_token_counts, top_p, temperature, stream=False, hidden=True)
|
253 |
+
for chatbot, history, status_text, all_token_counts in iter:
|
254 |
status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
|
255 |
+
yield chatbot, history, status_text, all_token_counts
|
256 |
|
257 |
|
258 |
+
def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0]):
|
259 |
print("重试中……")
|
260 |
if len(history) == 0:
|
261 |
yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
|
|
|
263 |
history.pop()
|
264 |
inputs = history.pop()["content"]
|
265 |
token_count.pop()
|
266 |
+
iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream, selected_model=selected_model)
|
267 |
print("重试完毕")
|
268 |
for x in iter:
|
269 |
yield x
|
270 |
|
271 |
|
272 |
+
def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False, selected_model = MODELS[0]):
|
273 |
print("开始减少token数量……")
|
274 |
+
iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, selected_model = selected_model, should_check_token_count=False)
|
275 |
for chatbot, history, status_text, previous_token_count in iter:
|
276 |
history = history[-2:]
|
277 |
token_count = previous_token_count[-1:]
|
|
|
281 |
print("减少token数量完毕")
|
282 |
|
283 |
|
284 |
+
def delete_last_conversation(chatbot, history, previous_token_count):
|
285 |
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
286 |
print("由于包含报错信息,只删除chatbot记录")
|
287 |
chatbot.pop()
|
|
|
296 |
if len(previous_token_count) > 0:
|
297 |
print("删除了一组对话的token计数记录")
|
298 |
previous_token_count.pop()
|
299 |
+
return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count))
|
300 |
|
301 |
|
302 |
def save_chat_history(filename, system, history, chatbot):
|