Spaces:
Running
Running
Update bin_public/app/chat_func.py
Browse files- bin_public/app/chat_func.py +28 -17
bin_public/app/chat_func.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
# -*- coding:utf-8 -*-
|
2 |
from __future__ import annotations
|
3 |
|
4 |
-
import requests
|
5 |
import urllib3
|
6 |
|
7 |
from tqdm import tqdm
|
8 |
from duckduckgo_search import ddg
|
9 |
from llama_func import *
|
10 |
-
from bin_public.utils.
|
11 |
from bin_public.utils.utils_db import *
|
12 |
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
13 |
|
@@ -94,7 +93,8 @@ def stream_predict(
|
|
94 |
top_p,
|
95 |
temperature,
|
96 |
selected_model,
|
97 |
-
fake_input=None
|
|
|
98 |
):
|
99 |
def get_return_value():
|
100 |
return chatbot, history, status_text, all_token_counts
|
@@ -106,9 +106,9 @@ def stream_predict(
|
|
106 |
history.append(construct_user(inputs))
|
107 |
history.append(construct_assistant(""))
|
108 |
if fake_input:
|
109 |
-
chatbot.append((
|
110 |
else:
|
111 |
-
chatbot.append((
|
112 |
user_token_count = 0
|
113 |
if len(all_token_counts) == 0:
|
114 |
system_prompt_token_count = count_token(construct_system(system_prompt))
|
@@ -183,7 +183,7 @@ def stream_predict(
|
|
183 |
yield get_return_value()
|
184 |
break
|
185 |
history[-1] = construct_assistant(partial_words)
|
186 |
-
chatbot[-1] = (chatbot[-1][0],
|
187 |
all_token_counts[-1] += 1
|
188 |
yield get_return_value()
|
189 |
|
@@ -198,15 +198,16 @@ def predict_all(
|
|
198 |
top_p,
|
199 |
temperature,
|
200 |
selected_model,
|
201 |
-
fake_input=None
|
|
|
202 |
):
|
203 |
logging.info("一次性回答模式")
|
204 |
history.append(construct_user(inputs))
|
205 |
history.append(construct_assistant(""))
|
206 |
if fake_input:
|
207 |
-
chatbot.append((
|
208 |
else:
|
209 |
-
chatbot.append((
|
210 |
all_token_counts.append(count_token(construct_user(inputs)))
|
211 |
try:
|
212 |
response = get_response(
|
@@ -232,7 +233,7 @@ def predict_all(
|
|
232 |
response = json.loads(response.text)
|
233 |
content = response["choices"][0]["message"]["content"]
|
234 |
history[-1] = construct_assistant(content)
|
235 |
-
chatbot[-1] = (chatbot[-1][0],
|
236 |
total_token_count = response["usage"]["total_tokens"]
|
237 |
all_token_counts[-1] = total_token_count - sum(all_token_counts)
|
238 |
status_text = construct_token_message(total_token_count)
|
@@ -287,7 +288,7 @@ def predict(
|
|
287 |
if len(openai_api_key) != 51:
|
288 |
status_text = standard_error_msg + no_apikey_msg
|
289 |
logging.info(status_text)
|
290 |
-
chatbot.append(
|
291 |
if len(history) == 0:
|
292 |
history.append(construct_user(inputs))
|
293 |
history.append("")
|
@@ -341,12 +342,22 @@ def predict(
|
|
341 |
holo_query_insert_chat_message(invite_code, inputs, history[-1]['content'], token, history)
|
342 |
|
343 |
if use_websearch:
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
if stream:
|
352 |
max_token = max_token_streaming
|
|
|
1 |
# -*- coding:utf-8 -*-
|
2 |
from __future__ import annotations
|
3 |
|
|
|
4 |
import urllib3
|
5 |
|
6 |
from tqdm import tqdm
|
7 |
from duckduckgo_search import ddg
|
8 |
from llama_func import *
|
9 |
+
from bin_public.utils.utils import *
|
10 |
from bin_public.utils.utils_db import *
|
11 |
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
12 |
|
|
|
93 |
top_p,
|
94 |
temperature,
|
95 |
selected_model,
|
96 |
+
fake_input=None,
|
97 |
+
display_append=""
|
98 |
):
|
99 |
def get_return_value():
|
100 |
return chatbot, history, status_text, all_token_counts
|
|
|
106 |
history.append(construct_user(inputs))
|
107 |
history.append(construct_assistant(""))
|
108 |
if fake_input:
|
109 |
+
chatbot.append((fake_input, ""))
|
110 |
else:
|
111 |
+
chatbot.append((inputs, ""))
|
112 |
user_token_count = 0
|
113 |
if len(all_token_counts) == 0:
|
114 |
system_prompt_token_count = count_token(construct_system(system_prompt))
|
|
|
183 |
yield get_return_value()
|
184 |
break
|
185 |
history[-1] = construct_assistant(partial_words)
|
186 |
+
chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
|
187 |
all_token_counts[-1] += 1
|
188 |
yield get_return_value()
|
189 |
|
|
|
198 |
top_p,
|
199 |
temperature,
|
200 |
selected_model,
|
201 |
+
fake_input=None,
|
202 |
+
display_append=""
|
203 |
):
|
204 |
logging.info("一次性回答模式")
|
205 |
history.append(construct_user(inputs))
|
206 |
history.append(construct_assistant(""))
|
207 |
if fake_input:
|
208 |
+
chatbot.append((fake_input, ""))
|
209 |
else:
|
210 |
+
chatbot.append((inputs, ""))
|
211 |
all_token_counts.append(count_token(construct_user(inputs)))
|
212 |
try:
|
213 |
response = get_response(
|
|
|
233 |
response = json.loads(response.text)
|
234 |
content = response["choices"][0]["message"]["content"]
|
235 |
history[-1] = construct_assistant(content)
|
236 |
+
chatbot[-1] = (chatbot[-1][0], content+display_append)
|
237 |
total_token_count = response["usage"]["total_tokens"]
|
238 |
all_token_counts[-1] = total_token_count - sum(all_token_counts)
|
239 |
status_text = construct_token_message(total_token_count)
|
|
|
288 |
if len(openai_api_key) != 51:
|
289 |
status_text = standard_error_msg + no_apikey_msg
|
290 |
logging.info(status_text)
|
291 |
+
chatbot.append(inputs, "")
|
292 |
if len(history) == 0:
|
293 |
history.append(construct_user(inputs))
|
294 |
history.append("")
|
|
|
342 |
holo_query_insert_chat_message(invite_code, inputs, history[-1]['content'], token, history)
|
343 |
|
344 |
if use_websearch:
|
345 |
+
search_results = ddg(inputs, max_results=5)
|
346 |
+
old_inputs = inputs
|
347 |
+
web_results = []
|
348 |
+
for idx, result in enumerate(search_results):
|
349 |
+
logging.info(f"搜索结果{idx + 1}:{result}")
|
350 |
+
domain_name = urllib3.util.parse_url(result["href"]).host
|
351 |
+
web_results.append(f'[{idx + 1}]"{result["body"]}"\nURL: {result["href"]}')
|
352 |
+
link_references.append(f"{idx + 1}. [{domain_name}]({result['href']})\n")
|
353 |
+
link_references = "\n\n" + "".join(link_references)
|
354 |
+
inputs = (
|
355 |
+
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
356 |
+
.replace("{query}", inputs)
|
357 |
+
.replace("{web_results}", "\n\n".join(web_results))
|
358 |
+
)
|
359 |
+
else:
|
360 |
+
link_references = ""
|
361 |
|
362 |
if stream:
|
363 |
max_token = max_token_streaming
|