AllenYkl committed on
Commit
da24d95
·
1 Parent(s): 9bc5b0a

Update bin_public/app/chat_func.py

Browse files
Files changed (1) hide show
  1. bin_public/app/chat_func.py +28 -17
bin_public/app/chat_func.py CHANGED
@@ -1,13 +1,12 @@
1
  # -*- coding:utf-8 -*-
2
  from __future__ import annotations
3
 
4
- import requests
5
  import urllib3
6
 
7
  from tqdm import tqdm
8
  from duckduckgo_search import ddg
9
  from llama_func import *
10
- from bin_public.utils.tools import *
11
  from bin_public.utils.utils_db import *
12
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
13
 
@@ -94,7 +93,8 @@ def stream_predict(
94
  top_p,
95
  temperature,
96
  selected_model,
97
- fake_input=None
 
98
  ):
99
  def get_return_value():
100
  return chatbot, history, status_text, all_token_counts
@@ -106,9 +106,9 @@ def stream_predict(
106
  history.append(construct_user(inputs))
107
  history.append(construct_assistant(""))
108
  if fake_input:
109
- chatbot.append((parse_text(fake_input), ""))
110
  else:
111
- chatbot.append((parse_text(inputs), ""))
112
  user_token_count = 0
113
  if len(all_token_counts) == 0:
114
  system_prompt_token_count = count_token(construct_system(system_prompt))
@@ -183,7 +183,7 @@ def stream_predict(
183
  yield get_return_value()
184
  break
185
  history[-1] = construct_assistant(partial_words)
186
- chatbot[-1] = (chatbot[-1][0], parse_text(partial_words))
187
  all_token_counts[-1] += 1
188
  yield get_return_value()
189
 
@@ -198,15 +198,16 @@ def predict_all(
198
  top_p,
199
  temperature,
200
  selected_model,
201
- fake_input=None
 
202
  ):
203
  logging.info("一次性回答模式")
204
  history.append(construct_user(inputs))
205
  history.append(construct_assistant(""))
206
  if fake_input:
207
- chatbot.append((parse_text(fake_input), ""))
208
  else:
209
- chatbot.append((parse_text(inputs), ""))
210
  all_token_counts.append(count_token(construct_user(inputs)))
211
  try:
212
  response = get_response(
@@ -232,7 +233,7 @@ def predict_all(
232
  response = json.loads(response.text)
233
  content = response["choices"][0]["message"]["content"]
234
  history[-1] = construct_assistant(content)
235
- chatbot[-1] = (chatbot[-1][0], parse_text(content))
236
  total_token_count = response["usage"]["total_tokens"]
237
  all_token_counts[-1] = total_token_count - sum(all_token_counts)
238
  status_text = construct_token_message(total_token_count)
@@ -287,7 +288,7 @@ def predict(
287
  if len(openai_api_key) != 51:
288
  status_text = standard_error_msg + no_apikey_msg
289
  logging.info(status_text)
290
- chatbot.append((parse_text(inputs), ""))
291
  if len(history) == 0:
292
  history.append(construct_user(inputs))
293
  history.append("")
@@ -341,12 +342,22 @@ def predict(
341
  holo_query_insert_chat_message(invite_code, inputs, history[-1]['content'], token, history)
342
 
343
  if use_websearch:
344
- response = history[-1]['content']
345
- response += "\n\n" + "\n".join(link_references)
346
- logging.info(f"Added link references.")
347
- logging.info(response)
348
- chatbot[-1] = (parse_text(old_inputs), response)
349
- yield chatbot, history, status_text, all_token_counts
 
 
 
 
 
 
 
 
 
 
350
 
351
  if stream:
352
  max_token = max_token_streaming
 
1
  # -*- coding:utf-8 -*-
2
  from __future__ import annotations
3
 
 
4
  import urllib3
5
 
6
  from tqdm import tqdm
7
  from duckduckgo_search import ddg
8
  from llama_func import *
9
+ from bin_public.utils.utils import *
10
  from bin_public.utils.utils_db import *
11
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
12
 
 
93
  top_p,
94
  temperature,
95
  selected_model,
96
+ fake_input=None,
97
+ display_append=""
98
  ):
99
  def get_return_value():
100
  return chatbot, history, status_text, all_token_counts
 
106
  history.append(construct_user(inputs))
107
  history.append(construct_assistant(""))
108
  if fake_input:
109
+ chatbot.append((fake_input, ""))
110
  else:
111
+ chatbot.append((inputs, ""))
112
  user_token_count = 0
113
  if len(all_token_counts) == 0:
114
  system_prompt_token_count = count_token(construct_system(system_prompt))
 
183
  yield get_return_value()
184
  break
185
  history[-1] = construct_assistant(partial_words)
186
+ chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
187
  all_token_counts[-1] += 1
188
  yield get_return_value()
189
 
 
198
  top_p,
199
  temperature,
200
  selected_model,
201
+ fake_input=None,
202
+ display_append=""
203
  ):
204
  logging.info("一次性回答模式")
205
  history.append(construct_user(inputs))
206
  history.append(construct_assistant(""))
207
  if fake_input:
208
+ chatbot.append((fake_input, ""))
209
  else:
210
+ chatbot.append((inputs, ""))
211
  all_token_counts.append(count_token(construct_user(inputs)))
212
  try:
213
  response = get_response(
 
233
  response = json.loads(response.text)
234
  content = response["choices"][0]["message"]["content"]
235
  history[-1] = construct_assistant(content)
236
+ chatbot[-1] = (chatbot[-1][0], content+display_append)
237
  total_token_count = response["usage"]["total_tokens"]
238
  all_token_counts[-1] = total_token_count - sum(all_token_counts)
239
  status_text = construct_token_message(total_token_count)
 
288
  if len(openai_api_key) != 51:
289
  status_text = standard_error_msg + no_apikey_msg
290
  logging.info(status_text)
291
+ chatbot.append(inputs, "")
292
  if len(history) == 0:
293
  history.append(construct_user(inputs))
294
  history.append("")
 
342
  holo_query_insert_chat_message(invite_code, inputs, history[-1]['content'], token, history)
343
 
344
  if use_websearch:
345
+ search_results = ddg(inputs, max_results=5)
346
+ old_inputs = inputs
347
+ web_results = []
348
+ for idx, result in enumerate(search_results):
349
+ logging.info(f"搜索结果{idx + 1}:{result}")
350
+ domain_name = urllib3.util.parse_url(result["href"]).host
351
+ web_results.append(f'[{idx + 1}]"{result["body"]}"\nURL: {result["href"]}')
352
+ link_references.append(f"{idx + 1}. [{domain_name}]({result['href']})\n")
353
+ link_references = "\n\n" + "".join(link_references)
354
+ inputs = (
355
+ replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
356
+ .replace("{query}", inputs)
357
+ .replace("{web_results}", "\n\n".join(web_results))
358
+ )
359
+ else:
360
+ link_references = ""
361
 
362
  if stream:
363
  max_token = max_token_streaming