JohnSmith9982 commited on
Commit
dc90d99
·
1 Parent(s): a51e754

回滚版本

Browse files
Files changed (4) hide show
  1. app.py +5 -11
  2. presets.py +15 -47
  3. requirements.txt +1 -3
  4. utils.py +418 -24
app.py CHANGED
@@ -1,17 +1,14 @@
1
  # -*- coding:utf-8 -*-
 
2
  import os
3
  import logging
4
  import sys
5
-
6
- import gradio as gr
7
-
8
  from utils import *
9
  from presets import *
10
- from overwrites import *
11
- from chat_func import *
12
 
13
  logging.basicConfig(
14
- level=logging.DEBUG,
15
  format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
16
  )
17
 
@@ -52,7 +49,6 @@ else:
52
  authflag = True
53
 
54
  gr.Chatbot.postprocess = postprocess
55
- PromptHelper.compact_text_chunks = compact_text_chunks
56
 
57
  with open("custom.css", "r", encoding="utf-8") as f:
58
  customCSS = f.read()
@@ -160,7 +156,7 @@ with gr.Blocks(
160
  value=hide_middle_chars(my_api_key),
161
  type="password",
162
  visible=not HIDE_MY_KEY,
163
- label="API-Key(按Enter提交)",
164
  )
165
  model_select_dropdown = gr.Dropdown(
166
  label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
@@ -169,7 +165,7 @@ with gr.Blocks(
169
  label="实时传输回答", value=True, visible=enable_streaming_option
170
  )
171
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
172
- index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
173
 
174
  with gr.Tab(label="Prompt"):
175
  systemPromptTxt = gr.Textbox(
@@ -290,7 +286,6 @@ with gr.Blocks(
290
  use_streaming_checkbox,
291
  model_select_dropdown,
292
  use_websearch_checkbox,
293
- index_files
294
  ],
295
  [chatbot, history, status_display, token_count],
296
  show_progress=True,
@@ -311,7 +306,6 @@ with gr.Blocks(
311
  use_streaming_checkbox,
312
  model_select_dropdown,
313
  use_websearch_checkbox,
314
- index_files
315
  ],
316
  [chatbot, history, status_display, token_count],
317
  show_progress=True,
 
1
  # -*- coding:utf-8 -*-
2
+ import gradio as gr
3
  import os
4
  import logging
5
  import sys
6
+ import argparse
 
 
7
  from utils import *
8
  from presets import *
 
 
9
 
10
  logging.basicConfig(
11
+ level=logging.INFO,
12
  format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
13
  )
14
 
 
49
  authflag = True
50
 
51
  gr.Chatbot.postprocess = postprocess
 
52
 
53
  with open("custom.css", "r", encoding="utf-8") as f:
54
  customCSS = f.read()
 
156
  value=hide_middle_chars(my_api_key),
157
  type="password",
158
  visible=not HIDE_MY_KEY,
159
+ label="API-Key",
160
  )
161
  model_select_dropdown = gr.Dropdown(
162
  label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
 
165
  label="实时传输回答", value=True, visible=enable_streaming_option
166
  )
167
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
168
+ index_files = gr.File(label="上传索引文件", type="file", multiple=True)
169
 
170
  with gr.Tab(label="Prompt"):
171
  systemPromptTxt = gr.Textbox(
 
286
  use_streaming_checkbox,
287
  model_select_dropdown,
288
  use_websearch_checkbox,
 
289
  ],
290
  [chatbot, history, status_display, token_count],
291
  show_progress=True,
 
306
  use_streaming_checkbox,
307
  model_select_dropdown,
308
  use_websearch_checkbox,
 
309
  ],
310
  [chatbot, history, status_display, token_count],
311
  show_progress=True,
presets.py CHANGED
@@ -1,23 +1,4 @@
1
  # -*- coding:utf-8 -*-
2
- # 错误信息
3
- standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
4
- error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
5
- connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
6
- read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
7
- proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
8
- ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
9
- no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
10
-
11
- max_token_streaming = 3500 # 流式对话时的最大 token 数
12
- timeout_streaming = 30 # 流式对话时的超时时间
13
- max_token_all = 3500 # 非流式对话时的最大 token 数
14
- timeout_all = 200 # 非流式对话时的超时时间
15
- enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
16
- HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
17
-
18
- SIM_K = 5
19
- INDEX_QUERY_TEMPRATURE = 1.0
20
-
21
  title = """<h1 align="left" style="min-width:200px; margin-top:0;">川虎ChatGPT 🚀</h1>"""
22
  description = """\
23
  <div align="center" style="margin:16px 0">
@@ -31,7 +12,6 @@ description = """\
31
  """
32
 
33
  summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
34
-
35
  MODELS = [
36
  "gpt-3.5-turbo",
37
  "gpt-3.5-turbo-0301",
@@ -41,8 +21,7 @@ MODELS = [
41
  "gpt-4-32k-0314",
42
  ] # 可选的模型
43
 
44
-
45
- WEBSEARCH_PTOMPT_TEMPLATE = """\
46
  Web search results:
47
 
48
  {web_results}
@@ -52,29 +31,18 @@ Instructions: Using the provided web search results, write a comprehensive reply
52
  Query: {query}
53
  Reply in 中文"""
54
 
55
- PROMPT_TEMPLATE = """\
56
- Context information is below.
57
- ---------------------
58
- {context_str}
59
- ---------------------
60
- Current date: {current_date}.
61
- Using the provided context information, write a comprehensive reply to the given query.
62
- Make sure to cite results using [number] notation after the reference.
63
- If the provided context information refer to multiple subjects with the same name, write separate answers for each subject.
64
- Use prior knowledge only if the given context didn't provide enough information.
65
- Answer the question: {query_str}
66
- Reply in 中文
67
- """
68
 
69
- REFINE_TEMPLATE = """\
70
- The original question is as follows: {query_str}
71
- We have provided an existing answer: {existing_answer}
72
- We have the opportunity to refine the existing answer
73
- (only if needed) with some more context below.
74
- ------------
75
- {context_msg}
76
- ------------
77
- Given the new context, refine the original answer to better
78
- Answer in the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch.
79
- If the context isn't useful, return the original answer.
80
- """
 
1
  # -*- coding:utf-8 -*-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  title = """<h1 align="left" style="min-width:200px; margin-top:0;">川虎ChatGPT 🚀</h1>"""
3
  description = """\
4
  <div align="center" style="margin:16px 0">
 
12
  """
13
 
14
  summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
 
15
  MODELS = [
16
  "gpt-3.5-turbo",
17
  "gpt-3.5-turbo-0301",
 
21
  "gpt-4-32k-0314",
22
  ] # 可选的模型
23
 
24
+ websearch_prompt = """\
 
25
  Web search results:
26
 
27
  {web_results}
 
31
  Query: {query}
32
  Reply in 中文"""
33
 
34
+ # 错误信息
35
+ standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
36
+ error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
37
+ connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
38
+ read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
39
+ proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
40
+ ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
41
+ no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51
 
 
 
 
 
42
 
43
+ max_token_streaming = 3500 # 流式对话时的最大 token 数
44
+ timeout_streaming = 30 # 流式对话时的超时时间
45
+ max_token_all = 3500 # 非流式对话时的最大 token
46
+ timeout_all = 200 # 非流式对话时的超时时间
47
+ enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
48
+ HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
 
 
 
 
 
 
requirements.txt CHANGED
@@ -6,6 +6,4 @@ socksio
6
  tqdm
7
  colorama
8
  duckduckgo_search
9
- Pygments
10
- llama_index
11
- langchain
 
6
  tqdm
7
  colorama
8
  duckduckgo_search
9
+ Pygments
 
 
utils.py CHANGED
@@ -3,16 +3,23 @@ from __future__ import annotations
3
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
4
  import logging
5
  import json
 
 
 
6
  import os
7
- import datetime
8
- import hashlib
9
- import csv
10
 
11
- import gradio as gr
 
 
12
  from pypinyin import lazy_pinyin
13
- import tiktoken
14
-
15
  from presets import *
 
 
 
 
 
16
 
17
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
18
 
@@ -30,6 +37,27 @@ HISTORY_DIR = "history"
30
  TEMPLATES_DIR = "templates"
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def count_token(message):
34
  encoding = tiktoken.get_encoding("cl100k_base")
35
  input_str = f"role: {message['role']}, content: {message['content']}"
@@ -74,6 +102,389 @@ def construct_token_message(token, stream=False):
74
  return f"Token 计数: {token}"
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def delete_last_conversation(chatbot, history, previous_token_count):
78
  if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
79
  logging.info("由于包含报错信息,只删除chatbot记录")
@@ -232,7 +643,6 @@ def reset_state():
232
  def reset_textbox():
233
  return gr.update(value="")
234
 
235
-
236
  def reset_default():
237
  global API_URL
238
  API_URL = "https://api.openai.com/v1/chat/completions"
@@ -240,7 +650,6 @@ def reset_default():
240
  os.environ.pop("https_proxy", None)
241
  return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
242
 
243
-
244
  def change_api_url(url):
245
  global API_URL
246
  API_URL = url
@@ -248,37 +657,22 @@ def change_api_url(url):
248
  logging.info(msg)
249
  return msg
250
 
251
-
252
  def change_proxy(proxy):
253
  os.environ["HTTPS_PROXY"] = proxy
254
  msg = f"代理更改为了{proxy}"
255
  logging.info(msg)
256
  return msg
257
 
258
-
259
  def hide_middle_chars(s):
260
  if len(s) <= 8:
261
  return s
262
  else:
263
  head = s[:4]
264
  tail = s[-4:]
265
- hidden = "*" * (len(s) - 8)
266
  return head + hidden + tail
267
 
268
-
269
  def submit_key(key):
270
- key = key.strip()
271
  msg = f"API密钥更改为了{hide_middle_chars(key)}"
272
  logging.info(msg)
273
  return key, msg
274
-
275
-
276
- def sha1sum(filename):
277
- sha1 = hashlib.sha1()
278
- sha1.update(filename.encode("utf-8"))
279
- return sha1.hexdigest()
280
-
281
-
282
- def replace_today(prompt):
283
- today = datetime.datetime.today().strftime("%Y-%m-%d")
284
- return prompt.replace("{current_date}", today)
 
3
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
4
  import logging
5
  import json
6
+ import gradio as gr
7
+
8
+ # import openai
9
  import os
10
+ import traceback
11
+ import requests
 
12
 
13
+ # import markdown
14
+ import csv
15
+ import mdtex2html
16
  from pypinyin import lazy_pinyin
 
 
17
  from presets import *
18
+ import tiktoken
19
+ from tqdm import tqdm
20
+ import colorama
21
+ from duckduckgo_search import ddg
22
+ import datetime
23
 
24
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
25
 
 
37
  TEMPLATES_DIR = "templates"
38
 
39
 
40
+ def postprocess(
41
+ self, y: List[Tuple[str | None, str | None]]
42
+ ) -> List[Tuple[str | None, str | None]]:
43
+ """
44
+ Parameters:
45
+ y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
46
+ Returns:
47
+ List of tuples representing the message and response. Each message and response will be a string of HTML.
48
+ """
49
+ if y is None:
50
+ return []
51
+ for i, (message, response) in enumerate(y):
52
+ y[i] = (
53
+ # None if message is None else markdown.markdown(message),
54
+ # None if response is None else markdown.markdown(response),
55
+ None if message is None else message,
56
+ None if response is None else mdtex2html.convert(response, extensions=['fenced_code','codehilite','tables']),
57
+ )
58
+ return y
59
+
60
+
61
  def count_token(message):
62
  encoding = tiktoken.get_encoding("cl100k_base")
63
  input_str = f"role: {message['role']}, content: {message['content']}"
 
102
  return f"Token 计数: {token}"
103
 
104
 
105
+ def get_response(
106
+ openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
107
+ ):
108
+ headers = {
109
+ "Content-Type": "application/json",
110
+ "Authorization": f"Bearer {openai_api_key}",
111
+ }
112
+
113
+ history = [construct_system(system_prompt), *history]
114
+
115
+ payload = {
116
+ "model": selected_model,
117
+ "messages": history, # [{"role": "user", "content": f"{inputs}"}],
118
+ "temperature": temperature, # 1.0,
119
+ "top_p": top_p, # 1.0,
120
+ "n": 1,
121
+ "stream": stream,
122
+ "presence_penalty": 0,
123
+ "frequency_penalty": 0,
124
+ }
125
+ if stream:
126
+ timeout = timeout_streaming
127
+ else:
128
+ timeout = timeout_all
129
+
130
+ # 获取环境变量中的代理设置
131
+ http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
132
+ https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
133
+
134
+ # 如果存在代理设置,使用它们
135
+ proxies = {}
136
+ if http_proxy:
137
+ logging.info(f"Using HTTP proxy: {http_proxy}")
138
+ proxies["http"] = http_proxy
139
+ if https_proxy:
140
+ logging.info(f"Using HTTPS proxy: {https_proxy}")
141
+ proxies["https"] = https_proxy
142
+
143
+ # 如果有代理,使用代理发送请求,否则使用默认设置发送请求
144
+ if proxies:
145
+ response = requests.post(
146
+ API_URL,
147
+ headers=headers,
148
+ json=payload,
149
+ stream=True,
150
+ timeout=timeout,
151
+ proxies=proxies,
152
+ )
153
+ else:
154
+ response = requests.post(
155
+ API_URL,
156
+ headers=headers,
157
+ json=payload,
158
+ stream=True,
159
+ timeout=timeout,
160
+ )
161
+ return response
162
+
163
+
164
+ def stream_predict(
165
+ openai_api_key,
166
+ system_prompt,
167
+ history,
168
+ inputs,
169
+ chatbot,
170
+ all_token_counts,
171
+ top_p,
172
+ temperature,
173
+ selected_model,
174
+ ):
175
+ def get_return_value():
176
+ return chatbot, history, status_text, all_token_counts
177
+
178
+ logging.info("实时回答模式")
179
+ partial_words = ""
180
+ counter = 0
181
+ status_text = "开始实时传输回答……"
182
+ history.append(construct_user(inputs))
183
+ history.append(construct_assistant(""))
184
+ chatbot.append((parse_text(inputs), ""))
185
+ user_token_count = 0
186
+ if len(all_token_counts) == 0:
187
+ system_prompt_token_count = count_token(construct_system(system_prompt))
188
+ user_token_count = (
189
+ count_token(construct_user(inputs)) + system_prompt_token_count
190
+ )
191
+ else:
192
+ user_token_count = count_token(construct_user(inputs))
193
+ all_token_counts.append(user_token_count)
194
+ logging.info(f"输入token计数: {user_token_count}")
195
+ yield get_return_value()
196
+ try:
197
+ response = get_response(
198
+ openai_api_key,
199
+ system_prompt,
200
+ history,
201
+ temperature,
202
+ top_p,
203
+ True,
204
+ selected_model,
205
+ )
206
+ except requests.exceptions.ConnectTimeout:
207
+ status_text = (
208
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
209
+ )
210
+ yield get_return_value()
211
+ return
212
+ except requests.exceptions.ReadTimeout:
213
+ status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
214
+ yield get_return_value()
215
+ return
216
+
217
+ yield get_return_value()
218
+ error_json_str = ""
219
+
220
+ for chunk in tqdm(response.iter_lines()):
221
+ if counter == 0:
222
+ counter += 1
223
+ continue
224
+ counter += 1
225
+ # check whether each line is non-empty
226
+ if chunk:
227
+ chunk = chunk.decode()
228
+ chunklength = len(chunk)
229
+ try:
230
+ chunk = json.loads(chunk[6:])
231
+ except json.JSONDecodeError:
232
+ logging.info(chunk)
233
+ error_json_str += chunk
234
+ status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
235
+ yield get_return_value()
236
+ continue
237
+ # decode each line as response data is in bytes
238
+ if chunklength > 6 and "delta" in chunk["choices"][0]:
239
+ finish_reason = chunk["choices"][0]["finish_reason"]
240
+ status_text = construct_token_message(
241
+ sum(all_token_counts), stream=True
242
+ )
243
+ if finish_reason == "stop":
244
+ yield get_return_value()
245
+ break
246
+ try:
247
+ partial_words = (
248
+ partial_words + chunk["choices"][0]["delta"]["content"]
249
+ )
250
+ except KeyError:
251
+ status_text = (
252
+ standard_error_msg
253
+ + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
254
+ + str(sum(all_token_counts))
255
+ )
256
+ yield get_return_value()
257
+ break
258
+ history[-1] = construct_assistant(partial_words)
259
+ chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
260
+ all_token_counts[-1] += 1
261
+ yield get_return_value()
262
+
263
+
264
+ def predict_all(
265
+ openai_api_key,
266
+ system_prompt,
267
+ history,
268
+ inputs,
269
+ chatbot,
270
+ all_token_counts,
271
+ top_p,
272
+ temperature,
273
+ selected_model,
274
+ ):
275
+ logging.info("一次性回答模式")
276
+ history.append(construct_user(inputs))
277
+ history.append(construct_assistant(""))
278
+ chatbot.append((parse_text(inputs), ""))
279
+ all_token_counts.append(count_token(construct_user(inputs)))
280
+ try:
281
+ response = get_response(
282
+ openai_api_key,
283
+ system_prompt,
284
+ history,
285
+ temperature,
286
+ top_p,
287
+ False,
288
+ selected_model,
289
+ )
290
+ except requests.exceptions.ConnectTimeout:
291
+ status_text = (
292
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
293
+ )
294
+ return chatbot, history, status_text, all_token_counts
295
+ except requests.exceptions.ProxyError:
296
+ status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
297
+ return chatbot, history, status_text, all_token_counts
298
+ except requests.exceptions.SSLError:
299
+ status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
300
+ return chatbot, history, status_text, all_token_counts
301
+ response = json.loads(response.text)
302
+ content = response["choices"][0]["message"]["content"]
303
+ history[-1] = construct_assistant(content)
304
+ chatbot[-1] = (parse_text(inputs), parse_text(content))
305
+ total_token_count = response["usage"]["total_tokens"]
306
+ all_token_counts[-1] = total_token_count - sum(all_token_counts)
307
+ status_text = construct_token_message(total_token_count)
308
+ return chatbot, history, status_text, all_token_counts
309
+
310
+
311
+ def predict(
312
+ openai_api_key,
313
+ system_prompt,
314
+ history,
315
+ inputs,
316
+ chatbot,
317
+ all_token_counts,
318
+ top_p,
319
+ temperature,
320
+ stream=False,
321
+ selected_model=MODELS[0],
322
+ use_websearch_checkbox=False,
323
+ should_check_token_count=True,
324
+ ): # repetition_penalty, top_k
325
+ logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
326
+ if use_websearch_checkbox:
327
+ results = ddg(inputs, max_results=3)
328
+ web_results = []
329
+ for idx, result in enumerate(results):
330
+ logging.info(f"搜索结果{idx + 1}:{result}")
331
+ web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
332
+ web_results = "\n\n".join(web_results)
333
+ today = datetime.datetime.today().strftime("%Y-%m-%d")
334
+ inputs = (
335
+ websearch_prompt.replace("{current_date}", today)
336
+ .replace("{query}", inputs)
337
+ .replace("{web_results}", web_results)
338
+ )
339
+ if len(openai_api_key) != 51:
340
+ status_text = standard_error_msg + no_apikey_msg
341
+ logging.info(status_text)
342
+ chatbot.append((parse_text(inputs), ""))
343
+ if len(history) == 0:
344
+ history.append(construct_user(inputs))
345
+ history.append("")
346
+ all_token_counts.append(0)
347
+ else:
348
+ history[-2] = construct_user(inputs)
349
+ yield chatbot, history, status_text, all_token_counts
350
+ return
351
+ if stream:
352
+ yield chatbot, history, "开始生成回答……", all_token_counts
353
+ if stream:
354
+ logging.info("使用流式传输")
355
+ iter = stream_predict(
356
+ openai_api_key,
357
+ system_prompt,
358
+ history,
359
+ inputs,
360
+ chatbot,
361
+ all_token_counts,
362
+ top_p,
363
+ temperature,
364
+ selected_model,
365
+ )
366
+ for chatbot, history, status_text, all_token_counts in iter:
367
+ yield chatbot, history, status_text, all_token_counts
368
+ else:
369
+ logging.info("不使用流式传输")
370
+ chatbot, history, status_text, all_token_counts = predict_all(
371
+ openai_api_key,
372
+ system_prompt,
373
+ history,
374
+ inputs,
375
+ chatbot,
376
+ all_token_counts,
377
+ top_p,
378
+ temperature,
379
+ selected_model,
380
+ )
381
+ yield chatbot, history, status_text, all_token_counts
382
+ logging.info(f"传输完毕。当前token计数为{all_token_counts}")
383
+ if len(history) > 1 and history[-1]["content"] != inputs:
384
+ logging.info(
385
+ "回答为:"
386
+ + colorama.Fore.BLUE
387
+ + f"{history[-1]['content']}"
388
+ + colorama.Style.RESET_ALL
389
+ )
390
+ if stream:
391
+ max_token = max_token_streaming
392
+ else:
393
+ max_token = max_token_all
394
+ if sum(all_token_counts) > max_token and should_check_token_count:
395
+ status_text = f"精简token中{all_token_counts}/{max_token}"
396
+ logging.info(status_text)
397
+ yield chatbot, history, status_text, all_token_counts
398
+ iter = reduce_token_size(
399
+ openai_api_key,
400
+ system_prompt,
401
+ history,
402
+ chatbot,
403
+ all_token_counts,
404
+ top_p,
405
+ temperature,
406
+ stream=False,
407
+ selected_model=selected_model,
408
+ hidden=True,
409
+ )
410
+ for chatbot, history, status_text, all_token_counts in iter:
411
+ status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
412
+ yield chatbot, history, status_text, all_token_counts
413
+
414
+
415
+ def retry(
416
+ openai_api_key,
417
+ system_prompt,
418
+ history,
419
+ chatbot,
420
+ token_count,
421
+ top_p,
422
+ temperature,
423
+ stream=False,
424
+ selected_model=MODELS[0],
425
+ ):
426
+ logging.info("重试中……")
427
+ if len(history) == 0:
428
+ yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
429
+ return
430
+ history.pop()
431
+ inputs = history.pop()["content"]
432
+ token_count.pop()
433
+ iter = predict(
434
+ openai_api_key,
435
+ system_prompt,
436
+ history,
437
+ inputs,
438
+ chatbot,
439
+ token_count,
440
+ top_p,
441
+ temperature,
442
+ stream=stream,
443
+ selected_model=selected_model,
444
+ )
445
+ logging.info("重试完毕")
446
+ for x in iter:
447
+ yield x
448
+
449
+
450
+ def reduce_token_size(
451
+ openai_api_key,
452
+ system_prompt,
453
+ history,
454
+ chatbot,
455
+ token_count,
456
+ top_p,
457
+ temperature,
458
+ stream=False,
459
+ selected_model=MODELS[0],
460
+ hidden=False,
461
+ ):
462
+ logging.info("开始减少token数量……")
463
+ iter = predict(
464
+ openai_api_key,
465
+ system_prompt,
466
+ history,
467
+ summarize_prompt,
468
+ chatbot,
469
+ token_count,
470
+ top_p,
471
+ temperature,
472
+ stream=stream,
473
+ selected_model=selected_model,
474
+ should_check_token_count=False,
475
+ )
476
+ logging.info(f"chatbot: {chatbot}")
477
+ for chatbot, history, status_text, previous_token_count in iter:
478
+ history = history[-2:]
479
+ token_count = previous_token_count[-1:]
480
+ if hidden:
481
+ chatbot.pop()
482
+ yield chatbot, history, construct_token_message(
483
+ sum(token_count), stream=stream
484
+ ), token_count
485
+ logging.info("减少token数量完毕")
486
+
487
+
488
  def delete_last_conversation(chatbot, history, previous_token_count):
489
  if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
490
  logging.info("由于包含报错信息,只删除chatbot记录")
 
643
  def reset_textbox():
644
  return gr.update(value="")
645
 
 
646
  def reset_default():
647
  global API_URL
648
  API_URL = "https://api.openai.com/v1/chat/completions"
 
650
  os.environ.pop("https_proxy", None)
651
  return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
652
 
 
653
  def change_api_url(url):
654
  global API_URL
655
  API_URL = url
 
657
  logging.info(msg)
658
  return msg
659
 
 
660
  def change_proxy(proxy):
661
  os.environ["HTTPS_PROXY"] = proxy
662
  msg = f"代理更改为了{proxy}"
663
  logging.info(msg)
664
  return msg
665
 
 
666
  def hide_middle_chars(s):
667
  if len(s) <= 8:
668
  return s
669
  else:
670
  head = s[:4]
671
  tail = s[-4:]
672
+ hidden = '*' * (len(s) - 8)
673
  return head + hidden + tail
674
 
 
675
  def submit_key(key):
 
676
  msg = f"API密钥更改为了{hide_middle_chars(key)}"
677
  logging.info(msg)
678
  return key, msg