Tuchuanhuhuhu committed · Commit 76a432f · 1 Parent(s): c9a9fba
feat: 保存更多参数 (save more parameters)

Files changed:
- modules/models/base_model.py (+195 -93)
- modules/utils.py (+165 -66)
modules/models/base_model.py
CHANGED
@@ -70,13 +70,13 @@ class CallbackToIterator:
 
 
 def get_action_description(text):
-    match = re.search(
+    match = re.search("```(.*?)```", text, re.S)
     json_text = match.group(1)
     # 把json转化为python字典
     json_dict = json.loads(json_text)
     # 提取'action'和'action_input'的值
-    action_name = json_dict[
-    action_input = json_dict[
+    action_name = json_dict["action"]
+    action_input = json_dict["action_input"]
     if action_name != "Final Answer":
         return f'<!-- S O PREFIX --><p class="agent-prefix">{action_name}: {action_input}\n</p><!-- E O PREFIX -->'
     else:
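For context: get_action_description() assumes the agent reply wraps a JSON action in a fenced block. A minimal standalone sketch of the parsing step above, with a made-up reply:

```python
import json
import re

# Made-up agent reply of the form the function above expects.
text = '```{"action": "Search", "action_input": "weather in Berlin"}```'
match = re.search("```(.*?)```", text, re.S)
json_dict = json.loads(match.group(1))
print(json_dict["action"], json_dict["action_input"])
# -> Search weather in Berlin
```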
@@ -84,7 +84,6 @@ def get_action_description(text):
 
 
 class ChuanhuCallbackHandler(BaseCallbackHandler):
-
     def __init__(self, callback) -> None:
         """Initialize callback handler."""
         self.callback = callback
@@ -124,7 +123,12 @@ class ChuanhuCallbackHandler(BaseCallbackHandler):
         """Run on new LLM token. Only available when streaming is enabled."""
         self.callback(token)
 
-    def on_chat_model_start(
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        **kwargs: Any,
+    ) -> Any:
         """Run when a chat model starts running."""
         pass
 
@@ -228,24 +232,26 @@ class BaseLLMModel:
         self.need_api_key = False
         self.single_turn = False
         self.history_file_path = get_first_history_name(user)
+        self.user_name = user
 
         self.temperature = temperature
         self.top_p = top_p
         self.n_choices = n_choices
         self.stop_sequence = stop
-        self.max_generation_token =
+        self.max_generation_token = max_generation_token
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.logit_bias = logit_bias
         self.user_identifier = user
 
+        self.metadata = {}
+
     def get_answer_stream_iter(self):
-        """stream
-
-
+        """Implement stream prediction.
+        Conversations are stored in self.history, with the most recent question in OpenAI format.
+        Should return a generator that yields the next word (str) in the answer.
         """
-        logging.warning(
-            "stream predict not implemented, using at once predict instead")
+        logging.warning("Stream prediction is not implemented. Using at once prediction instead.")
         response, _ = self.get_answer_at_once()
         yield response
 
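The rewritten docstring pins down the contract for subclasses: read the latest question from self.history and yield incremental chunks of the answer. A hypothetical subclass sketch (EchoModel is illustrative, not part of this repo):

```python
class EchoModel(BaseLLMModel):
    def get_answer_stream_iter(self):
        # Most recent question, stored in OpenAI message format.
        question = self.history[-1]["content"]
        for word in str(question).split():
            yield word + " "
```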
@@ -256,8 +262,7 @@ class BaseLLMModel:
             the answer (str)
             total token count (int)
         """
-        logging.warning(
-            "at once predict not implemented, using stream predict instead")
+        logging.warning("at once predict not implemented, using stream predict instead")
         response_iter = self.get_answer_stream_iter()
         count = 0
         for response in response_iter:
@@ -291,7 +296,9 @@ class BaseLLMModel:
         stream_iter = self.get_answer_stream_iter()
 
         if display_append:
-            display_append =
+            display_append = (
+                '\n\n<hr class="append-display no-in-raw" />' + display_append
+            )
         partial_text = ""
         token_increment = 1
         for partial_text in stream_iter:
@@ -322,11 +329,9 @@ class BaseLLMModel:
             self.history[-2] = construct_user(fake_input)
         chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
         if fake_input is not None:
-            self.all_token_counts[-1] += count_token(
-                construct_assistant(ai_reply))
+            self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
         else:
-            self.all_token_counts[-1] = total_token_count -
-                sum(self.all_token_counts)
+            self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
         status_text = self.token_message()
         return chatbot, status_text
 
@@ -349,46 +354,80 @@ class BaseLLMModel:
         from langchain.prompts import PromptTemplate
         from langchain.chat_models import ChatOpenAI
         from langchain.callbacks import StdOutCallbackHandler
-
-
-
+
+        prompt_template = (
+            "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN "
+            + language
+            + ":"
+        )
+        PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
         llm = ChatOpenAI()
         chain = load_summarize_chain(
-            llm,
-
-
+            llm,
+            chain_type="map_reduce",
+            return_intermediate_steps=True,
+            map_prompt=PROMPT,
+            combine_prompt=PROMPT,
+        )
+        summary = chain(
+            {"input_documents": list(index.docstore.__dict__["_dict"].values())},
+            return_only_outputs=True,
+        )["output_text"]
         print(i18n("总结") + f": {summary}")
-        chatbot.append([i18n("上传了")+str(len(files))+"个文件", summary])
+        chatbot.append([i18n("上传了") + str(len(files)) + "个文件", summary])
         return chatbot, status
 
-    def prepare_inputs(
+    def prepare_inputs(
+        self,
+        real_inputs,
+        use_websearch,
+        files,
+        reply_language,
+        chatbot,
+        load_from_cache_if_possible=True,
+    ):
         display_append = []
         limited_context = False
         if type(real_inputs) == list:
-            fake_inputs = real_inputs[0][
+            fake_inputs = real_inputs[0]["text"]
         else:
             fake_inputs = real_inputs
         if files:
             from langchain.embeddings.huggingface import HuggingFaceEmbeddings
             from langchain.vectorstores.base import VectorStoreRetriever
+
             limited_context = True
             msg = "加载索引中……"
             logging.info(msg)
-            index = construct_index(
+            index = construct_index(
+                self.api_key,
+                file_src=files,
+                load_from_cache_if_possible=load_from_cache_if_possible,
+            )
             assert index is not None, "获取索引失败"
             msg = "索引获取成功,生成回答中……"
             logging.info(msg)
             with retrieve_proxy():
-                retriever = VectorStoreRetriever(
+                retriever = VectorStoreRetriever(
+                    vectorstore=index, search_type="similarity", search_kwargs={"k": 6}
+                )
                 # retriever = VectorStoreRetriever(vectorstore=index, search_type="similarity_score_threshold", search_kwargs={
                 #     "k": 6, "score_threshold": 0.2})
                 try:
-                    relevant_documents = retriever.get_relevant_documents(
-                        fake_inputs)
+                    relevant_documents = retriever.get_relevant_documents(fake_inputs)
                 except AssertionError:
-                    return self.prepare_inputs(
-
-
+                    return self.prepare_inputs(
+                        fake_inputs,
+                        use_websearch,
+                        files,
+                        reply_language,
+                        chatbot,
+                        load_from_cache_if_possible=False,
+                    )
+            reference_results = [
+                [d.page_content.strip("�"), os.path.basename(d.metadata["source"])]
+                for d in relevant_documents
+            ]
             reference_results = add_source_numbers(reference_results)
             display_append = add_details(reference_results)
             display_append = "\n\n" + "".join(display_append)
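Worth noting in the hunk above: if retrieval against a cached index raises AssertionError, prepare_inputs() calls itself once more with load_from_cache_if_possible=False to rebuild the index from scratch. The same fallback pattern in isolation (retriever_factory is a stand-in name):

```python
def retrieve_with_fallback(retriever_factory, query):
    # A first attempt may hit a stale cached index; rebuild once on failure.
    try:
        retriever = retriever_factory(load_from_cache_if_possible=True)
        return retriever.get_relevant_documents(query)
    except AssertionError:
        retriever = retriever_factory(load_from_cache_if_possible=False)
        return retriever.get_relevant_documents(query)
```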
@@ -415,16 +454,17 @@ class BaseLLMModel:
             reference_results = []
             for idx, result in enumerate(search_results):
                 logging.debug(f"搜索结果{idx + 1}:{result}")
-                domain_name = urllib3.util.parse_url(result[
-                reference_results.append([result[
+                domain_name = urllib3.util.parse_url(result["href"]).host
+                reference_results.append([result["body"], result["href"]])
                 display_append.append(
                     # f"{idx+1}. [{domain_name}]({result['href']})\n"
                     f"<a href=\"{result['href']}\" target=\"_blank\">{idx+1}. {result['title']}</a>"
                 )
             reference_results = add_source_numbers(reference_results)
             # display_append = "<ol>\n\n" + "".join(display_append) + "</ol>"
-            display_append =
-                "".join(display_append) +
+            display_append = (
+                '<div class = "source-a">' + "".join(display_append) + "</div>"
+            )
             if type(real_inputs) == list:
                 real_inputs[0]["text"] = (
                     replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
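With the change above, the web-search branch wraps the numbered result links in one <div class = "source-a"> container. Sketched standalone with a made-up search result:

```python
result = {"title": "Example Domain", "href": "https://example.com", "body": "…"}
idx = 0
link = f"<a href=\"{result['href']}\" target=\"_blank\">{idx+1}. {result['title']}</a>"
print('<div class = "source-a">' + link + "</div>")
# <div class = "source-a"><a href="https://example.com" target="_blank">1. Example Domain</a></div>
```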
@@ -453,33 +493,54 @@ class BaseLLMModel:
         reply_language="中文",
         should_check_token_count=True,
     ):  # repetition_penalty, top_k
-
         status_text = "开始生成回答……"
         if type(inputs) == list:
-
-            "用户"
-
+            logging.info(
+                "用户"
+                + f"{self.user_name}"
+                + "的输入为:"
+                + colorama.Fore.BLUE
+                + "("
+                + str(len(inputs) - 1)
+                + " images) "
+                + f"{inputs[0]['text']}"
+                + colorama.Style.RESET_ALL
             )
         else:
             logging.info(
-                "用户"
-
+                "用户"
+                + f"{self.user_name}"
+                + "的输入为:"
+                + colorama.Fore.BLUE
+                + f"{inputs}"
+                + colorama.Style.RESET_ALL
             )
         if should_check_token_count:
             if type(inputs) == list:
-
+                yield chatbot + [(inputs[0]["text"], "")], status_text
             else:
                 yield chatbot + [(inputs, "")], status_text
         if reply_language == "跟随问题语言(不稳定)":
             reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
 
-
-
+        (
+            limited_context,
+            fake_inputs,
+            display_append,
+            inputs,
+            chatbot,
+        ) = self.prepare_inputs(
+            real_inputs=inputs,
+            use_websearch=use_websearch,
+            files=files,
+            reply_language=reply_language,
+            chatbot=chatbot,
+        )
         yield chatbot + [(fake_inputs, "")], status_text
 
         if (
-            self.need_api_key
-            self.api_key is None
+            self.need_api_key
+            and self.api_key is None
             and not shared.state.multi_api_key
         ):
             status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
@@ -684,11 +745,16 @@ class BaseLLMModel:
         self.history = []
         self.all_token_counts = []
         self.interrupted = False
-        self.history_file_path = new_auto_history_filename(self.
+        self.history_file_path = new_auto_history_filename(self.user_name)
         history_name = self.history_file_path[:-5]
-        choices = [history_name] + get_history_names(self.
+        choices = [history_name] + get_history_names(self.user_name)
         system_prompt = self.system_prompt if remain_system_prompt else ""
-        return
+        return (
+            [],
+            self.token_message([0]),
+            gr.Radio.update(choices=choices, value=history_name),
+            system_prompt,
+        )
 
     def delete_first_conversation(self):
         if self.history:
@@ -719,7 +785,12 @@ class BaseLLMModel:
         token_sum = 0
         for i in range(len(token_lst)):
             token_sum += sum(token_lst[: i + 1])
-        return
+        return (
+            i18n("Token 计数: ")
+            + f"{sum(token_lst)}"
+            + i18n(",本次对话累计消耗了 ")
+            + f"{token_sum} tokens"
+        )
 
     def rename_chat_history(self, filename, chatbot):
         if filename == "":
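token_message() reports two numbers: the tokens currently in the history, and the cumulative cost, which re-counts earlier turns because every request re-sends them. A worked example with made-up per-turn counts:

```python
token_lst = [12, 30, 25]          # per-turn token counts
token_sum = 0
for i in range(len(token_lst)):
    token_sum += sum(token_lst[: i + 1])
print(sum(token_lst), token_sum)  # 67 tokens in history, 121 consumed in total
```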
@@ -729,78 +800,103 @@ class BaseLLMModel:
         self.delete_chat_history(self.history_file_path)
         # 命名重复检测
         repeat_file_index = 2
-        full_path = os.path.join(HISTORY_DIR, self.
+        full_path = os.path.join(HISTORY_DIR, self.user_name, filename)
         while os.path.exists(full_path):
-            full_path = os.path.join(
+            full_path = os.path.join(
+                HISTORY_DIR, self.user_name, f"{repeat_file_index}_{filename}"
+            )
             repeat_file_index += 1
         filename = os.path.basename(full_path)
 
         self.history_file_path = filename
-        save_file(filename, self
-        return init_history_list(self.
+        save_file(filename, self, chatbot)
+        return init_history_list(self.user_name)
 
-    def auto_name_chat_history(
+    def auto_name_chat_history(
+        self, name_chat_method, user_question, chatbot, single_turn_checkbox
+    ):
         if len(self.history) == 2 and not single_turn_checkbox:
             user_question = self.history[0]["content"]
             if type(user_question) == list:
                 user_question = user_question[0]["text"]
             filename = replace_special_symbols(user_question)[:16] + ".json"
-            return self.rename_chat_history(filename, chatbot, self.
+            return self.rename_chat_history(filename, chatbot, self.user_name)
         else:
             return gr.update()
 
     def auto_save(self, chatbot):
-        save_file(self.history_file_path, self
-            self.history, chatbot, self.user_identifier)
+        save_file(self.history_file_path, self, chatbot)
 
     def export_markdown(self, filename, chatbot):
         if filename == "":
             return
         if not filename.endswith(".md"):
             filename += ".md"
-        save_file(filename, self
+        save_file(filename, self, chatbot)
 
     def load_chat_history(self, new_history_file_path=None):
-        logging.debug(f"{self.
+        logging.debug(f"{self.user_name} 加载对话历史中……")
         if new_history_file_path is not None:
             if type(new_history_file_path) != str:
-                # copy file from new_history_file_path.name to os.path.join(HISTORY_DIR, self.
+                # copy file from new_history_file_path.name to os.path.join(HISTORY_DIR, self.user_name)
                 new_history_file_path = new_history_file_path.name
-                shutil.copyfile(
-
+                shutil.copyfile(
+                    new_history_file_path,
+                    os.path.join(
+                        HISTORY_DIR,
+                        self.user_name,
+                        os.path.basename(new_history_file_path),
+                    ),
+                )
                 self.history_file_path = os.path.basename(new_history_file_path)
             else:
                 self.history_file_path = new_history_file_path
         try:
             if self.history_file_path == os.path.basename(self.history_file_path):
                 history_file_path = os.path.join(
-                    HISTORY_DIR, self.
+                    HISTORY_DIR, self.user_name, self.history_file_path
+                )
             else:
                 history_file_path = self.history_file_path
             if not self.history_file_path.endswith(".json"):
                 history_file_path += ".json"
             with open(history_file_path, "r", encoding="utf-8") as f:
-
+                saved_json = json.load(f)
             try:
-                if type(
+                if type(saved_json["history"][0]) == str:
                     logging.info("历史记录格式为旧版,正在转换……")
                     new_history = []
-                    for index, item in enumerate(
+                    for index, item in enumerate(saved_json["history"]):
                         if index % 2 == 0:
                             new_history.append(construct_user(item))
                         else:
                             new_history.append(construct_assistant(item))
-
+                    saved_json["history"] = new_history
                     logging.info(new_history)
             except:
                 pass
-            if len(
+            if len(saved_json["chatbot"]) < len(saved_json["history"]) // 2:
                 logging.info("Trimming corrupted history...")
-
-                logging.info(f"Trimmed history: {
-            logging.debug(f"{self.
-            self.history =
-
+                saved_json["history"] = saved_json["history"][-len(saved_json["chatbot"]) :]
+                logging.info(f"Trimmed history: {saved_json['history']}")
+            logging.debug(f"{self.user_name} 加载对话历史完毕")
+            self.history = saved_json["history"]
+            self.single_turn = saved_json.get("single_turn", False)
+            self.temperature = saved_json.get("temperature", 1.0)
+            self.top_p = saved_json.get("top_p", None)
+            self.n_choices = saved_json.get("n_choices", 1)
+            self.stop_sequence = saved_json.get("stop_sequence", None)
+            self.max_generation_token = saved_json.get("max_generation_token", None)
+            self.presence_penalty = saved_json.get("presence_penalty", 0)
+            self.frequency_penalty = saved_json.get("frequency_penalty", 0)
+            self.logit_bias = saved_json.get("logit_bias", None)
+            self.user_identifier = saved_json.get("user_identifier", self.user_name)
+            self.metadata = saved_json.get("metadata", {})
+            return (
+                os.path.basename(self.history_file_path),
+                saved_json["system"],
+                saved_json["chatbot"],
+            )
         except:
             # 没有对话历史或者对话历史解析失败
             logging.info(f"没有找到对话历史记录 {self.history_file_path}")
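This is the loading half of the commit: each parameter that save_file() now writes is read back with dict.get() and a default, so history files from before this change still load. The pattern in isolation:

```python
saved_json = {"system": "", "history": [], "chatbot": []}  # an old-format file
temperature = saved_json.get("temperature", 1.0)  # falls back to the default
metadata = saved_json.get("metadata", {})
print(temperature, metadata)  # 1.0 {}
```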
@@ -814,23 +910,28 @@ class BaseLLMModel:
         if not filename.endswith(".json"):
             filename += ".json"
         if filename == os.path.basename(filename):
-            history_file_path = os.path.join(
+            history_file_path = os.path.join(
+                HISTORY_DIR, self.user_name, filename
+            )
         else:
             history_file_path = filename
         md_history_file_path = history_file_path[:-5] + ".md"
         try:
             os.remove(history_file_path)
             os.remove(md_history_file_path)
-            return i18n("删除对话历史成功"), get_history_list(self.
+            return i18n("删除对话历史成功"), get_history_list(self.user_name), []
         except:
             logging.info(f"删除对话历史失败 {history_file_path}")
-            return
+            return (
+                i18n("对话历史") + filename + i18n("已经被删除啦"),
+                get_history_list(self.user_name),
+                [],
+            )
 
     def auto_load(self):
-        filepath = get_history_filepath(self.
+        filepath = get_history_filepath(self.user_name)
         if not filepath:
-            self.history_file_path = new_auto_history_filename(
-                self.user_identifier)
+            self.history_file_path = new_auto_history_filename(self.user_name)
         else:
             self.history_file_path = filepath
         filename, system_prompt, chatbot = self.load_chat_history()
@@ -838,18 +939,15 @@ class BaseLLMModel:
         return filename, system_prompt, chatbot
 
     def like(self):
-        """like the last response, implement if needed
-        """
+        """like the last response, implement if needed"""
         return gr.update()
 
     def dislike(self):
-        """dislike the last response, implement if needed
-        """
+        """dislike the last response, implement if needed"""
         return gr.update()
 
     def deinitialize(self):
-        """deinitialize the model, implement if needed
-        """
+        """deinitialize the model, implement if needed"""
         pass
 
 
@@ -874,7 +972,8 @@ class Base_Chat_Langchain_Client(BaseLLMModel):
 
     def get_answer_at_once(self):
         assert isinstance(
-            self.model, BaseChatModel
+            self.model, BaseChatModel
+        ), "model is not instance of LangChain BaseChatModel"
         history = self._get_langchain_style_history()
         response = self.model.generate(history)
         return response.content, sum(response.content)
@@ -882,13 +981,16 @@ class Base_Chat_Langchain_Client(BaseLLMModel):
     def get_answer_stream_iter(self):
         it = CallbackToIterator()
         assert isinstance(
-            self.model, BaseChatModel
+            self.model, BaseChatModel
+        ), "model is not instance of LangChain BaseChatModel"
         history = self._get_langchain_style_history()
 
         def thread_func():
-            self.model(
-                ChuanhuCallbackHandler(it.callback)]
+            self.model(
+                messages=history, callbacks=[ChuanhuCallbackHandler(it.callback)]
+            )
             it.finish()
+
         t = Thread(target=thread_func)
         t.start()
         partial_text = ""
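The streaming path above pushes tokens from a LangChain callback running in a worker thread into an iterator drained by the UI loop. A minimal assumed re-implementation of that bridge (the repo's CallbackToIterator may differ in detail):

```python
from queue import Queue
from threading import Thread

class MiniCallbackToIterator:
    _DONE = object()  # sentinel marking the end of the stream

    def __init__(self):
        self.queue = Queue()

    def callback(self, token):
        self.queue.put(token)

    def finish(self):
        self.queue.put(self._DONE)

    def __iter__(self):
        while (item := self.queue.get()) is not self._DONE:
            yield item

it = MiniCallbackToIterator()
Thread(target=lambda: ([it.callback(t) for t in ("Hello", " world")], it.finish())).start()
print("".join(it))  # Hello world
```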
modules/utils.py
CHANGED
@@ -31,97 +31,127 @@ if TYPE_CHECKING:
     headers: List[str]
     data: List[List[str | int | bool]]
 
+
 def predict(current_model, *args):
     iter = current_model.predict(*args)
     for i in iter:
         yield i
 
+
 def billing_info(current_model):
     return current_model.billing_info()
 
+
 def set_key(current_model, *args):
     return current_model.set_key(*args)
 
+
 def load_chat_history(current_model, *args):
     return current_model.load_chat_history(*args)
 
+
 def delete_chat_history(current_model, *args):
     return current_model.delete_chat_history(*args)
 
+
 def interrupt(current_model, *args):
     return current_model.interrupt(*args)
 
+
 def reset(current_model, *args):
     return current_model.reset(*args)
 
+
 def retry(current_model, *args):
     iter = current_model.retry(*args)
     for i in iter:
         yield i
 
+
 def delete_first_conversation(current_model, *args):
     return current_model.delete_first_conversation(*args)
 
+
 def delete_last_conversation(current_model, *args):
     return current_model.delete_last_conversation(*args)
 
+
 def set_system_prompt(current_model, *args):
     return current_model.set_system_prompt(*args)
 
+
 def rename_chat_history(current_model, *args):
     return current_model.rename_chat_history(*args)
 
+
 def auto_name_chat_history(current_model, *args):
     return current_model.auto_name_chat_history(*args)
 
+
 def export_markdown(current_model, *args):
     return current_model.export_markdown(*args)
 
+
 def upload_chat_history(current_model, *args):
     return current_model.load_chat_history(*args)
 
+
 def set_token_upper_limit(current_model, *args):
     return current_model.set_token_upper_limit(*args)
 
+
 def set_temperature(current_model, *args):
     current_model.set_temperature(*args)
 
+
 def set_top_p(current_model, *args):
     current_model.set_top_p(*args)
 
+
 def set_n_choices(current_model, *args):
     current_model.set_n_choices(*args)
 
+
 def set_stop_sequence(current_model, *args):
     current_model.set_stop_sequence(*args)
 
+
 def set_max_tokens(current_model, *args):
     current_model.set_max_tokens(*args)
 
+
 def set_presence_penalty(current_model, *args):
     current_model.set_presence_penalty(*args)
 
+
 def set_frequency_penalty(current_model, *args):
     current_model.set_frequency_penalty(*args)
 
+
 def set_logit_bias(current_model, *args):
     current_model.set_logit_bias(*args)
 
+
 def set_user_identifier(current_model, *args):
     current_model.set_user_identifier(*args)
 
+
 def set_single_turn(current_model, *args):
     current_model.set_single_turn(*args)
 
+
 def handle_file_upload(current_model, *args):
     return current_model.handle_file_upload(*args)
 
+
 def handle_summarize_index(current_model, *args):
     return current_model.summarize_index(*args)
 
+
 def like(current_model, *args):
     return current_model.like(*args)
 
+
 def dislike(current_model, *args):
     return current_model.dislike(*args)
 
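These one-line shims exist because Gradio event handlers are module-level callables: the active model arrives as a gr.State value and the shim dispatches to the bound method. A hypothetical wiring sketch (all component names are illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    current_model = gr.State()            # holds the active BaseLLMModel
    chatbot = gr.Chatbot()
    status = gr.Markdown()
    history_selector = gr.Radio()
    system_prompt_box = gr.Textbox()
    reset_btn = gr.Button("Reset")
    # `reset` forwards to current_model.reset(), whose 4-tuple return
    # (see the base_model.py hunk above) maps onto these four outputs.
    reset_btn.click(
        reset,
        inputs=[current_model],
        outputs=[chatbot, status, history_selector, system_prompt_box],
    )
```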
@@ -134,7 +164,7 @@ def count_token(input_str):
     return length
 
 
-def markdown_to_html_with_syntax_highlight(md_str):
+def markdown_to_html_with_syntax_highlight(md_str): # deprecated
     def replacer(match):
         lang = match.group(1) or "text"
         code = match.group(2)
@@ -156,7 +186,7 @@ def markdown_to_html_with_syntax_highlight(md_str): # deprecated
     return html_str
 
 
-def normalize_markdown(md_text: str) -> str:
+def normalize_markdown(md_text: str) -> str: # deprecated
     lines = md_text.split("\n")
     normalized_lines = []
     inside_list = False
@@ -180,7 +210,7 @@ def normalize_markdown(md_text: str) -> str: # deprecated
     return "\n".join(normalized_lines)
 
 
-def convert_mdtext(md_text):
+def convert_mdtext(md_text): # deprecated
     code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
     inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
     code_blocks = code_block_pattern.findall(md_text)
@@ -209,16 +239,22 @@ def clip_rawtext(chat_message, need_escape=True):
     # first, clip hr line
     hr_pattern = r'\n\n<hr class="append-display no-in-raw" />(.*?)'
     hr_match = re.search(hr_pattern, chat_message, re.DOTALL)
-    message_clipped = chat_message[:hr_match.start()] if hr_match else chat_message
+    message_clipped = chat_message[: hr_match.start()] if hr_match else chat_message
     # second, avoid agent-prefix being escaped
-    agent_prefix_pattern =
+    agent_prefix_pattern = (
+        r'(<!-- S O PREFIX --><p class="agent-prefix">.*?<\/p><!-- E O PREFIX -->)'
+    )
     # agent_matches = re.findall(agent_prefix_pattern, message_clipped)
     agent_parts = re.split(agent_prefix_pattern, message_clipped, flags=re.DOTALL)
     final_message = ""
     for i, part in enumerate(agent_parts):
         if i % 2 == 0:
             if part != "" and part != "\n":
-                final_message +=
+                final_message += (
+                    f'<pre class="fake-pre">{escape_markdown(part)}</pre>'
+                    if need_escape
+                    else f'<pre class="fake-pre">{part}</pre>'
+                )
         else:
             final_message += part
     return final_message
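The fix builds the agent-prefix pattern as a parenthesized capturing group, which matters: re.split() with a capturing group keeps the matched blocks in the result at odd indices, so they can be passed through unescaped. Standalone:

```python
import re

pattern = r'(<!-- S O PREFIX --><p class="agent-prefix">.*?<\/p><!-- E O PREFIX -->)'
msg = ('thinking…<!-- S O PREFIX --><p class="agent-prefix">Search: cats\n</p>'
       '<!-- E O PREFIX -->done')
parts = re.split(pattern, msg, flags=re.DOTALL)
print(len(parts))  # 3: [plain text, matched prefix block, plain text]
```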
@@ -248,51 +284,53 @@ def convert_bot_before_marked(chat_message):
     md = f'<div class="md-message">\n\n{result}\n</div>'
     return raw + md
 
+
 def convert_user_before_marked(chat_message):
     if '<div class="user-message">' in chat_message:
         return chat_message
     else:
         return f'<div class="user-message">{escape_markdown(chat_message)}</div>'
 
+
 def escape_markdown(text):
     """
     Escape Markdown special characters to HTML-safe equivalents.
     """
     escape_chars = {
         # ' ': '&nbsp;',
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        "_": "&#95;",
+        "*": "&#42;",
+        "[": "&#91;",
+        "]": "&#93;",
+        "(": "&#40;",
+        ")": "&#41;",
+        "{": "&#123;",
+        "}": "&#125;",
+        "#": "&#35;",
+        "+": "&#43;",
+        "-": "&#45;",
+        ".": "&#46;",
+        "!": "&#33;",
+        "`": "&#96;",
+        ">": "&#62;",
+        "<": "&#60;",
+        "|": "&#124;",
+        "$": "&#36;",
+        ":": "&#58;",
+        "\n": "<br>",
     }
-    text = text.replace(
-    return
+    text = text.replace(" ", "&nbsp;")
+    return "".join(escape_chars.get(c, c) for c in text)
 
 
-def convert_asis(userinput):
+def convert_asis(userinput): # deprecated
     return (
         f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
         + ALREADY_CONVERTED_MARK
     )
 
 
-def detect_converted_mark(userinput):
+def detect_converted_mark(userinput): # deprecated
     try:
         if userinput.endswith(ALREADY_CONVERTED_MARK):
             return True
@@ -302,7 +340,7 @@ def detect_converted_mark(userinput): # deprecated
     return True
 
 
-def detect_language(code):
+def detect_language(code): # deprecated
     if code.startswith("\n"):
         first_line = ""
     else:
@@ -328,7 +366,10 @@ def construct_assistant(text):
     return construct_text("assistant", text)
 
 
-def save_file(filename,
+def save_file(filename, model, chatbot):
+    system = model.system_prompt
+    history = model.history
+    user_name = model.user_name
     os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
     if filename is None:
         filename = new_auto_history_filename(user_name)
@@ -339,22 +380,38 @@ def save_file(filename, system, history, chatbot, user_name):
     if filename == ".json":
         raise Exception("文件名不能为空")
 
-    json_s = {
-
+    json_s = {
+        "system": system,
+        "history": history,
+        "chatbot": chatbot,
+        "single_turn": model.single_turn,
+        "temperature": model.temperature,
+        "top_p": model.top_p,
+        "n_choices": model.n_choices,
+        "stop_sequence": model.stop_sequence,
+        "max_generation_token": model.max_generation_token,
+        "presence_penalty": model.presence_penalty,
+        "frequency_penalty": model.frequency_penalty,
+        "logit_bias": model.logit_bias,
+        "user_identifier": model.user_identifier,
+        "metadata": model.metadata
+    }
     if not filename == os.path.basename(filename):
         history_file_path = filename
     else:
         history_file_path = os.path.join(HISTORY_DIR, user_name, filename)
 
-    with open(history_file_path, "w", encoding=
-        json.dump(json_s, f, ensure_ascii=False)
+    with open(history_file_path, "w", encoding="utf-8") as f:
+        json.dump(json_s, f, ensure_ascii=False, indent=4)
 
     filename = os.path.basename(filename)
     filename_md = filename[:-5] + ".md"
     md_s = f"system: \n- {system} \n"
     for data in history:
         md_s += f"\n{data['role']}: \n- {data['content']} \n"
-    with open(
+    with open(
+        os.path.join(HISTORY_DIR, user_name, filename_md), "w", encoding="utf8"
+    ) as f:
         f.write(md_s)
     return os.path.join(HISTORY_DIR, user_name, filename)
 
@@ -362,8 +419,12 @@ def save_file(filename, system, history, chatbot, user_name):
 def sorted_by_pinyin(list):
     return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
 
+
 def sorted_by_last_modified_time(list, dir):
-    return sorted(
+    return sorted(
+        list, key=lambda char: os.path.getctime(os.path.join(dir, char)), reverse=True
+    )
+
 
 def get_file_names_by_type(dir, filetypes=[".json"]):
     logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes}")
@@ -373,6 +434,7 @@ def get_file_names_by_type(dir, filetypes=[".json"]):
     logging.debug(f"files are:{files}")
     return files
 
+
 def get_file_names_by_pinyin(dir, filetypes=[".json"]):
     files = get_file_names_by_type(dir, filetypes)
     if files != [""]:
@@ -380,10 +442,12 @@ def get_file_names_by_pinyin(dir, filetypes=[".json"]):
     logging.debug(f"files are:{files}")
     return files
 
+
 def get_file_names_dropdown_by_pinyin(dir, filetypes=[".json"]):
     files = get_file_names_by_pinyin(dir, filetypes)
     return gr.Dropdown.update(choices=files)
 
+
 def get_file_names_by_last_modified_time(dir, filetypes=[".json"]):
     files = get_file_names_by_type(dir, filetypes)
     if files != [""]:
@@ -397,21 +461,29 @@ def get_history_names(user_name=""):
     if user_name == "" and hide_history_when_not_logged_in:
         return []
     else:
-        history_files = get_file_names_by_last_modified_time(
-
+        history_files = get_file_names_by_last_modified_time(
+            os.path.join(HISTORY_DIR, user_name)
+        )
+        history_files = [f[: f.rfind(".")] for f in history_files]
         return history_files
 
+
 def get_first_history_name(user_name=""):
     history_names = get_history_names(user_name)
     return history_names[0] if history_names else None
 
+
 def get_history_list(user_name=""):
     history_names = get_history_names(user_name)
     return gr.Radio.update(choices=history_names)
 
+
 def init_history_list(user_name=""):
     history_names = get_history_names(user_name)
-    return gr.Radio.update(
+    return gr.Radio.update(
+        choices=history_names, value=history_names[0] if history_names else ""
+    )
+
 
 def filter_history(user_name, keyword):
     history_names = get_history_names(user_name)
@@ -421,6 +493,7 @@ def filter_history(user_name, keyword):
     except:
         return gr.update(choices=history_names)
 
+
 def load_template(filename, mode=0):
     logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
     lines = []
@@ -441,15 +514,14 @@ def load_template(filename, mode=0):
         return {row[0]: row[1] for row in lines}
     else:
         choices = sorted_by_pinyin([row[0] for row in lines])
-        return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
-            choices=choices
-        )
+        return {row[0]: row[1] for row in lines}, gr.Dropdown.update(choices=choices)
 
 
 def get_template_names():
     logging.debug("获取模板文件名列表")
     return get_file_names_by_pinyin(TEMPLATES_DIR, filetypes=[".csv", "json"])
 
+
 def get_template_dropdown():
     logging.debug("获取模板下拉菜单")
     template_names = get_template_names()
@@ -524,9 +596,7 @@ def get_geoip():
     if "error" in data.keys():
         logging.warning(f"无法获取IP地址信息。\n{data}")
         if data["reason"] == "RateLimited":
-            return (
-                i18n("您的IP区域:未知。")
-            )
+            return i18n("您的IP区域:未知。")
         else:
             return i18n("获取IP地理位置失败。原因:") + f"{data['reason']}" + i18n("。你仍然可以使用聊天功能。")
     else:
@@ -590,29 +660,36 @@ def update_chuanhu():
     if update_status == "success":
         logging.info("Successfully updated, restart needed")
         status = '<span id="update-status" class="hideK">success</span>'
-        return gr.Markdown.update(value=i18n("更新成功,请重启本程序")+status)
+        return gr.Markdown.update(value=i18n("更新成功,请重启本程序") + status)
     else:
         status = '<span id="update-status" class="hideK">failure</span>'
-        return gr.Markdown.update(
+        return gr.Markdown.update(
+            value=i18n(
+                "更新失败,请尝试[手动更新](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程#手动更新)"
+            )
+            + status
+        )
 
 
-def add_source_numbers(lst, source_name
+def add_source_numbers(lst, source_name="Source", use_source=True):
     if use_source:
-        return [
+        return [
+            f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}'
+            for idx, item in enumerate(lst)
+        ]
     else:
         return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)]
 
+
 def add_details(lst):
     nodes = []
     for index, txt in enumerate(lst):
         brief = txt[:25].replace("\n", "")
-        nodes.append(
-            f"<details><summary>{brief}...</summary><p>{txt}</p></details>"
-        )
+        nodes.append(f"<details><summary>{brief}...</summary><p>{txt}</p></details>")
     return nodes
 
+
-def sheet_to_string(sheet, sheet_name
+def sheet_to_string(sheet, sheet_name=None):
     result = []
     for index, row in sheet.iterrows():
         row_string = ""
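add_source_numbers() now expects (text, source) pairs when use_source=True, and add_details() folds each numbered entry into a collapsible block. A worked example assuming the two helpers as defined above:

```python
lst = [["Relevant passage from the PDF…", "paper.pdf"]]
numbered = add_source_numbers(lst)
# numbered[0] == '[1]\t "Relevant passage from the PDF…"\nSource: paper.pdf'
nodes = add_details(numbered)
# nodes[0] is a <details><summary>…</summary><p>…</p></details> block
```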
@@ -623,59 +700,70 @@ def sheet_to_string(sheet, sheet_name = None):
         result.append(row_string)
     return result
 
+
 def excel_to_string(file_path):
     # 读取Excel文件中的所有工作表
-    excel_file = pd.read_excel(file_path, engine=
+    excel_file = pd.read_excel(file_path, engine="openpyxl", sheet_name=None)
 
     # 初始化结果字符串
     result = []
 
     # 遍历每一个工作表
     for sheet_name, sheet_data in excel_file.items():
-
         # 处理当前工作表并添加到结果字符串
         result += sheet_to_string(sheet_data, sheet_name=sheet_name)
 
-
     return result
 
+
 def get_last_day_of_month(any_day):
     # The day 28 exists in every month. 4 days later, it's always next month
     next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
     # subtracting the number of the current day brings us back one month
     return next_month - datetime.timedelta(days=next_month.day)
 
+
 def get_model_source(model_name, alternative_source):
     if model_name == "gpt2-medium":
         return "https://huggingface.co/gpt2-medium"
 
+
 def refresh_ui_elements_on_load(current_model, selected_model_name, user_name):
     current_model.set_user_identifier(user_name)
     return toggle_like_btn_visibility(selected_model_name), *current_model.auto_load()
 
+
 def toggle_like_btn_visibility(selected_model_name):
     if selected_model_name == "xmchat":
         return gr.update(visible=True)
     else:
         return gr.update(visible=False)
 
+
 def get_corresponding_file_type_by_model_name(selected_model_name):
     if selected_model_name in ["xmchat", "GPT4 Vision"]:
         return ["image"]
     else:
         return [".pdf", ".docx", ".pptx", ".epub", ".xlsx", ".txt", "text"]
 
+
 # def toggle_file_type(selected_model_name):
 #     return gr.Files.update(file_types=get_corresponding_file_type_by_model_name(selected_model_name))
 
+
 def new_auto_history_filename(username):
     latest_file = get_first_history_name(username)
     if latest_file:
-        with open(
+        with open(
+            os.path.join(HISTORY_DIR, username, latest_file + ".json"),
+            "r",
+            encoding="utf-8",
+        ) as f:
             if len(f.read()) == 0:
                 return latest_file
-    now = i18n("新对话 ") + datetime.datetime.now().strftime(
-    return f
+    now = i18n("新对话 ") + datetime.datetime.now().strftime("%m-%d %H-%M")
+    return f"{now}.json"
+
 
 def get_history_filepath(username):
     dirname = os.path.join(HISTORY_DIR, username)
@@ -687,20 +775,28 @@ def get_history_filepath(username):
     latest_file = os.path.join(dirname, latest_file)
     return latest_file
 
+
 def beautify_err_msg(err_msg):
-    if "insufficient_quota" in
-        return i18n(
+    if "insufficient_quota" in err_msg:
+        return i18n(
+            "剩余配额不足,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98#you-exceeded-your-current-quota-please-check-your-plan-and-billing-details)"
+        )
     if "The model `gpt-4` does not exist" in err_msg:
-        return i18n(
+        return i18n(
+            "你没有权限访问 GPT4,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues/843)"
+        )
     if "Resource not found" in err_msg:
         return i18n("请查看 config_example.json,配置 Azure OpenAI")
     return err_msg
 
+
 def auth_from_conf(username, password):
     try:
         with open("config.json", encoding="utf-8") as f:
             conf = json.load(f)
-        usernames, passwords = [i[0] for i in conf["users"]], [
+        usernames, passwords = [i[0] for i in conf["users"]], [
+            i[1] for i in conf["users"]
+        ]
         if username in usernames:
             if passwords[usernames.index(username)] == password:
                 return True
@@ -708,6 +804,7 @@ def auth_from_conf(username, password):
     except:
         return False
 
+
 def get_file_hash(file_src=None, file_paths=None):
     if file_src:
         file_paths = [x.name for x in file_src]
@@ -721,12 +818,14 @@ def get_file_hash(file_src=None, file_paths=None):
 
     return md5_hash.hexdigest()
 
+
 def myprint(**args):
     print(args)
 
+
 def replace_special_symbols(string, replace_string=" "):
     # 定义正则表达式,匹配所有特殊符号
-    pattern = r
+    pattern = r"[!@#$%^&*()<>?/\|}{~:]"
 
     new_string = re.sub(pattern, replace_string, string)