Tuchuanhuhuhu committed
Commit 76a432f · 1 Parent(s): c9a9fba

feat: 保存更多参数 (save more parameters)

Files changed (2)
  1. modules/models/base_model.py +195 -93
  2. modules/utils.py +165 -66
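
In short: `save_file` previously persisted only the system prompt, the message history, and the chatbot pairs; this commit passes the model object itself into `save_file`, so the per-conversation generation parameters are saved alongside the history and restored on load. A minimal sketch of the payload the new `save_file` writes (the key names are taken from the diff below; the values shown are illustrative only):

# Illustrative example of the history JSON written after this commit.
# Keys mirror the json_s dict in modules/utils.py; the values are made up.
example_payload = {
    "system": "You are a helpful assistant.",
    "history": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    "chatbot": [["Hi", "Hello!"]],
    "single_turn": False,
    "temperature": 1.0,
    "top_p": None,
    "n_choices": 1,
    "stop_sequence": None,
    "max_generation_token": None,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "logit_bias": None,
    "user_identifier": "user",
    "metadata": {},
}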
modules/models/base_model.py CHANGED
@@ -70,13 +70,13 @@ class CallbackToIterator:


 def get_action_description(text):
-    match = re.search('```(.*?)```', text, re.S)
+    match = re.search("```(.*?)```", text, re.S)
     json_text = match.group(1)
     # Convert the JSON into a Python dict
     json_dict = json.loads(json_text)
     # Extract the values of 'action' and 'action_input'
-    action_name = json_dict['action']
-    action_input = json_dict['action_input']
+    action_name = json_dict["action"]
+    action_input = json_dict["action_input"]
     if action_name != "Final Answer":
         return f'<!-- S O PREFIX --><p class="agent-prefix">{action_name}: {action_input}\n</p><!-- E O PREFIX -->'
     else:
@@ -84,7 +84,6 @@ def get_action_description(text):


 class ChuanhuCallbackHandler(BaseCallbackHandler):
-
     def __init__(self, callback) -> None:
         """Initialize callback handler."""
         self.callback = callback
@@ -124,7 +123,12 @@ class ChuanhuCallbackHandler(BaseCallbackHandler):
         """Run on new LLM token. Only available when streaming is enabled."""
         self.callback(token)

-    def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        **kwargs: Any,
+    ) -> Any:
         """Run when a chat model starts running."""
         pass

@@ -228,24 +232,26 @@ class BaseLLMModel:
         self.need_api_key = False
         self.single_turn = False
         self.history_file_path = get_first_history_name(user)
+        self.user_name = user

         self.temperature = temperature
         self.top_p = top_p
         self.n_choices = n_choices
         self.stop_sequence = stop
-        self.max_generation_token = None
+        self.max_generation_token = max_generation_token
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.logit_bias = logit_bias
         self.user_identifier = user

+        self.metadata = {}
+
     def get_answer_stream_iter(self):
-        """stream predict, need to be implemented
-        conversations are stored in self.history, with the most recent question, in OpenAI format
-        should return a generator, each time give the next word (str) in the answer
+        """Implement stream prediction.
+        Conversations are stored in self.history, with the most recent question in OpenAI format.
+        Should return a generator that yields the next word (str) in the answer.
         """
-        logging.warning(
-            "stream predict not implemented, using at once predict instead")
+        logging.warning("Stream prediction is not implemented. Using at once prediction instead.")
         response, _ = self.get_answer_at_once()
         yield response

@@ -256,8 +262,7 @@ class BaseLLMModel:
         the answer (str)
         total token count (int)
         """
-        logging.warning(
-            "at once predict not implemented, using stream predict instead")
+        logging.warning("at once predict not implemented, using stream predict instead")
         response_iter = self.get_answer_stream_iter()
         count = 0
         for response in response_iter:
@@ -291,7 +296,9 @@ class BaseLLMModel:
         stream_iter = self.get_answer_stream_iter()

         if display_append:
-            display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
+            display_append = (
+                '\n\n<hr class="append-display no-in-raw" />' + display_append
+            )
         partial_text = ""
         token_increment = 1
         for partial_text in stream_iter:
@@ -322,11 +329,9 @@ class BaseLLMModel:
             self.history[-2] = construct_user(fake_input)
         chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
         if fake_input is not None:
-            self.all_token_counts[-1] += count_token(
-                construct_assistant(ai_reply))
+            self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
         else:
-            self.all_token_counts[-1] = total_token_count - \
-                sum(self.all_token_counts)
+            self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
         status_text = self.token_message()
         return chatbot, status_text

@@ -349,46 +354,80 @@ class BaseLLMModel:
         from langchain.prompts import PromptTemplate
         from langchain.chat_models import ChatOpenAI
         from langchain.callbacks import StdOutCallbackHandler
-        prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
-        PROMPT = PromptTemplate(
-            template=prompt_template, input_variables=["text"])
+
+        prompt_template = (
+            "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN "
+            + language
+            + ":"
+        )
+        PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
         llm = ChatOpenAI()
         chain = load_summarize_chain(
-            llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
-        summary = chain({"input_documents": list(index.docstore.__dict__[
-            "_dict"].values())}, return_only_outputs=True)["output_text"]
+            llm,
+            chain_type="map_reduce",
+            return_intermediate_steps=True,
+            map_prompt=PROMPT,
+            combine_prompt=PROMPT,
+        )
+        summary = chain(
+            {"input_documents": list(index.docstore.__dict__["_dict"].values())},
+            return_only_outputs=True,
+        )["output_text"]
         print(i18n("总结") + f": {summary}")
-        chatbot.append([i18n("上传了")+str(len(files))+"个文件", summary])
+        chatbot.append([i18n("上传了") + str(len(files)) + "个文件", summary])
         return chatbot, status

-    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot, load_from_cache_if_possible=True):
+    def prepare_inputs(
+        self,
+        real_inputs,
+        use_websearch,
+        files,
+        reply_language,
+        chatbot,
+        load_from_cache_if_possible=True,
+    ):
         display_append = []
         limited_context = False
         if type(real_inputs) == list:
-            fake_inputs = real_inputs[0]['text']
+            fake_inputs = real_inputs[0]["text"]
         else:
             fake_inputs = real_inputs
         if files:
             from langchain.embeddings.huggingface import HuggingFaceEmbeddings
             from langchain.vectorstores.base import VectorStoreRetriever
+
             limited_context = True
             msg = "加载索引中……"
             logging.info(msg)
-            index = construct_index(self.api_key, file_src=files, load_from_cache_if_possible=load_from_cache_if_possible)
+            index = construct_index(
+                self.api_key,
+                file_src=files,
+                load_from_cache_if_possible=load_from_cache_if_possible,
+            )
             assert index is not None, "获取索引失败"
             msg = "索引获取成功,生成回答中……"
             logging.info(msg)
             with retrieve_proxy():
-                retriever = VectorStoreRetriever(vectorstore=index, search_type="similarity", search_kwargs={"k": 6})
+                retriever = VectorStoreRetriever(
+                    vectorstore=index, search_type="similarity", search_kwargs={"k": 6}
+                )
                 # retriever = VectorStoreRetriever(vectorstore=index, search_type="similarity_score_threshold", search_kwargs={
                 #     "k": 6, "score_threshold": 0.2})
                 try:
-                    relevant_documents = retriever.get_relevant_documents(
-                        fake_inputs)
+                    relevant_documents = retriever.get_relevant_documents(fake_inputs)
                 except AssertionError:
-                    return self.prepare_inputs(fake_inputs, use_websearch, files, reply_language, chatbot, load_from_cache_if_possible=False)
-            reference_results = [[d.page_content.strip("�"), os.path.basename(
-                d.metadata["source"])] for d in relevant_documents]
+                    return self.prepare_inputs(
+                        fake_inputs,
+                        use_websearch,
+                        files,
+                        reply_language,
+                        chatbot,
+                        load_from_cache_if_possible=False,
+                    )
+            reference_results = [
+                [d.page_content.strip("�"), os.path.basename(d.metadata["source"])]
+                for d in relevant_documents
+            ]
             reference_results = add_source_numbers(reference_results)
             display_append = add_details(reference_results)
             display_append = "\n\n" + "".join(display_append)
@@ -415,16 +454,17 @@ class BaseLLMModel:
             reference_results = []
             for idx, result in enumerate(search_results):
                 logging.debug(f"搜索结果{idx + 1}:{result}")
-                domain_name = urllib3.util.parse_url(result['href']).host
-                reference_results.append([result['body'], result['href']])
+                domain_name = urllib3.util.parse_url(result["href"]).host
+                reference_results.append([result["body"], result["href"]])
                 display_append.append(
                     # f"{idx+1}. [{domain_name}]({result['href']})\n"
                     f"<a href=\"{result['href']}\" target=\"_blank\">{idx+1}.&nbsp;{result['title']}</a>"
                 )
             reference_results = add_source_numbers(reference_results)
             # display_append = "<ol>\n\n" + "".join(display_append) + "</ol>"
-            display_append = '<div class = "source-a">' + \
-                "".join(display_append) + '</div>'
+            display_append = (
+                '<div class = "source-a">' + "".join(display_append) + "</div>"
+            )
         if type(real_inputs) == list:
             real_inputs[0]["text"] = (
                 replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
@@ -453,33 +493,54 @@ class BaseLLMModel:
         reply_language="中文",
         should_check_token_count=True,
     ): # repetition_penalty, top_k
-
         status_text = "开始生成回答……"
         if type(inputs) == list:
-            logging.info(
-                "用户" + f"{self.user_identifier}" + "的输入为:" +
-                colorama.Fore.BLUE + "(" + str(len(inputs)-1) + " images) " + f"{inputs[0]['text']}" + colorama.Style.RESET_ALL
+            logging.info(
+                "用户"
+                + f"{self.user_name}"
+                + "的输入为:"
+                + colorama.Fore.BLUE
+                + "("
+                + str(len(inputs) - 1)
+                + " images) "
+                + f"{inputs[0]['text']}"
+                + colorama.Style.RESET_ALL
             )
         else:
             logging.info(
-                "用户" + f"{self.user_identifier}" + "的输入为:" +
-                colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
+                "用户"
+                + f"{self.user_name}"
+                + "的输入为:"
+                + colorama.Fore.BLUE
+                + f"{inputs}"
+                + colorama.Style.RESET_ALL
             )
         if should_check_token_count:
             if type(inputs) == list:
-                yield chatbot + [(inputs[0]['text'], "")], status_text
+                yield chatbot + [(inputs[0]["text"], "")], status_text
             else:
                 yield chatbot + [(inputs, "")], status_text
         if reply_language == "跟随问题语言(不稳定)":
             reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."

-        limited_context, fake_inputs, display_append, inputs, chatbot = self.prepare_inputs(
-            real_inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language, chatbot=chatbot)
+        (
+            limited_context,
+            fake_inputs,
+            display_append,
+            inputs,
+            chatbot,
+        ) = self.prepare_inputs(
+            real_inputs=inputs,
+            use_websearch=use_websearch,
+            files=files,
+            reply_language=reply_language,
+            chatbot=chatbot,
+        )
         yield chatbot + [(fake_inputs, "")], status_text

         if (
-            self.need_api_key and
-            self.api_key is None
+            self.need_api_key
+            and self.api_key is None
             and not shared.state.multi_api_key
         ):
             status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
@@ -684,11 +745,16 @@ class BaseLLMModel:
         self.history = []
         self.all_token_counts = []
         self.interrupted = False
-        self.history_file_path = new_auto_history_filename(self.user_identifier)
+        self.history_file_path = new_auto_history_filename(self.user_name)
         history_name = self.history_file_path[:-5]
-        choices = [history_name] + get_history_names(self.user_identifier)
+        choices = [history_name] + get_history_names(self.user_name)
         system_prompt = self.system_prompt if remain_system_prompt else ""
-        return [], self.token_message([0]), gr.Radio.update(choices=choices, value=history_name), system_prompt
+        return (
+            [],
+            self.token_message([0]),
+            gr.Radio.update(choices=choices, value=history_name),
+            system_prompt,
+        )

     def delete_first_conversation(self):
         if self.history:
@@ -719,7 +785,12 @@ class BaseLLMModel:
         token_sum = 0
         for i in range(len(token_lst)):
             token_sum += sum(token_lst[: i + 1])
-        return i18n("Token 计数: ") + f"{sum(token_lst)}" + i18n(",本次对话累计消耗了 ") + f"{token_sum} tokens"
+        return (
+            i18n("Token 计数: ")
+            + f"{sum(token_lst)}"
+            + i18n(",本次对话累计消耗了 ")
+            + f"{token_sum} tokens"
+        )

     def rename_chat_history(self, filename, chatbot):
         if filename == "":
@@ -729,78 +800,103 @@ class BaseLLMModel:
         self.delete_chat_history(self.history_file_path)
         # Check for duplicate filenames
         repeat_file_index = 2
-        full_path = os.path.join(HISTORY_DIR, self.user_identifier, filename)
+        full_path = os.path.join(HISTORY_DIR, self.user_name, filename)
         while os.path.exists(full_path):
-            full_path = os.path.join(HISTORY_DIR, self.user_identifier, f"{repeat_file_index}_{filename}")
+            full_path = os.path.join(
+                HISTORY_DIR, self.user_name, f"{repeat_file_index}_{filename}"
+            )
             repeat_file_index += 1
         filename = os.path.basename(full_path)

         self.history_file_path = filename
-        save_file(filename, self.system_prompt, self.history, chatbot, self.user_identifier)
-        return init_history_list(self.user_identifier)
+        save_file(filename, self, chatbot)
+        return init_history_list(self.user_name)

-    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, single_turn_checkbox):
+    def auto_name_chat_history(
+        self, name_chat_method, user_question, chatbot, single_turn_checkbox
+    ):
         if len(self.history) == 2 and not single_turn_checkbox:
             user_question = self.history[0]["content"]
             if type(user_question) == list:
                 user_question = user_question[0]["text"]
             filename = replace_special_symbols(user_question)[:16] + ".json"
-            return self.rename_chat_history(filename, chatbot, self.user_identifier)
+            return self.rename_chat_history(filename, chatbot, self.user_name)
         else:
             return gr.update()

     def auto_save(self, chatbot):
-        save_file(self.history_file_path, self.system_prompt,
-                  self.history, chatbot, self.user_identifier)
+        save_file(self.history_file_path, self, chatbot)

     def export_markdown(self, filename, chatbot):
         if filename == "":
             return
         if not filename.endswith(".md"):
             filename += ".md"
-        save_file(filename, self.system_prompt, self.history, chatbot, self.user_identifier)
+        save_file(filename, self, chatbot)

     def load_chat_history(self, new_history_file_path=None):
-        logging.debug(f"{self.user_identifier} 加载对话历史中……")
+        logging.debug(f"{self.user_name} 加载对话历史中……")
        if new_history_file_path is not None:
             if type(new_history_file_path) != str:
-                # copy file from new_history_file_path.name to os.path.join(HISTORY_DIR, self.user_identifier)
+                # copy file from new_history_file_path.name to os.path.join(HISTORY_DIR, self.user_name)
                 new_history_file_path = new_history_file_path.name
-                shutil.copyfile(new_history_file_path, os.path.join(
-                    HISTORY_DIR, self.user_identifier, os.path.basename(new_history_file_path)))
+                shutil.copyfile(
+                    new_history_file_path,
+                    os.path.join(
+                        HISTORY_DIR,
+                        self.user_name,
+                        os.path.basename(new_history_file_path),
+                    ),
+                )
                 self.history_file_path = os.path.basename(new_history_file_path)
             else:
                 self.history_file_path = new_history_file_path
         try:
             if self.history_file_path == os.path.basename(self.history_file_path):
                 history_file_path = os.path.join(
-                    HISTORY_DIR, self.user_identifier, self.history_file_path)
+                    HISTORY_DIR, self.user_name, self.history_file_path
+                )
             else:
                 history_file_path = self.history_file_path
             if not self.history_file_path.endswith(".json"):
                 history_file_path += ".json"
             with open(history_file_path, "r", encoding="utf-8") as f:
-                json_s = json.load(f)
+                saved_json = json.load(f)
             try:
-                if type(json_s["history"][0]) == str:
+                if type(saved_json["history"][0]) == str:
                     logging.info("历史记录格式为旧版,正在转换……")
                     new_history = []
-                    for index, item in enumerate(json_s["history"]):
+                    for index, item in enumerate(saved_json["history"]):
                         if index % 2 == 0:
                             new_history.append(construct_user(item))
                         else:
                             new_history.append(construct_assistant(item))
-                    json_s["history"] = new_history
+                    saved_json["history"] = new_history
                     logging.info(new_history)
             except:
                 pass
-            if len(json_s["chatbot"]) < len(json_s["history"])//2:
+            if len(saved_json["chatbot"]) < len(saved_json["history"]) // 2:
                 logging.info("Trimming corrupted history...")
-                json_s["history"] = json_s["history"][-len(json_s["chatbot"]):]
-                logging.info(f"Trimmed history: {json_s['history']}")
-            logging.debug(f"{self.user_identifier} 加载对话历史完毕")
-            self.history = json_s["history"]
-            return os.path.basename(self.history_file_path), json_s["system"], json_s["chatbot"]
+                saved_json["history"] = saved_json["history"][-len(saved_json["chatbot"]) :]
+                logging.info(f"Trimmed history: {saved_json['history']}")
+            logging.debug(f"{self.user_name} 加载对话历史完毕")
+            self.history = saved_json["history"]
+            self.single_turn = saved_json.get("single_turn", False)
+            self.temperature = saved_json.get("temperature", 1.0)
+            self.top_p = saved_json.get("top_p", None)
+            self.n_choices = saved_json.get("n_choices", 1)
+            self.stop_sequence = saved_json.get("stop_sequence", None)
+            self.max_generation_token = saved_json.get("max_generation_token", None)
+            self.presence_penalty = saved_json.get("presence_penalty", 0)
+            self.frequency_penalty = saved_json.get("frequency_penalty", 0)
+            self.logit_bias = saved_json.get("logit_bias", None)
+            self.user_identifier = saved_json.get("user_identifier", self.user_name)
+            self.metadata = saved_json.get("metadata", {})
+            return (
+                os.path.basename(self.history_file_path),
+                saved_json["system"],
+                saved_json["chatbot"],
+            )
         except:
             # No chat history, or failed to parse it
             logging.info(f"没有找到对话历史记录 {self.history_file_path}")
@@ -814,23 +910,28 @@ class BaseLLMModel:
         if not filename.endswith(".json"):
             filename += ".json"
         if filename == os.path.basename(filename):
-            history_file_path = os.path.join(HISTORY_DIR, self.user_identifier, filename)
+            history_file_path = os.path.join(
+                HISTORY_DIR, self.user_name, filename
+            )
         else:
             history_file_path = filename
         md_history_file_path = history_file_path[:-5] + ".md"
         try:
             os.remove(history_file_path)
             os.remove(md_history_file_path)
-            return i18n("删除对话历史成功"), get_history_list(self.user_identifier), []
+            return i18n("删除对话历史成功"), get_history_list(self.user_name), []
         except:
             logging.info(f"删除对话历史失败 {history_file_path}")
-            return i18n("对话历史")+filename+i18n("已经被删除啦"), get_history_list(self.user_identifier), []
+            return (
+                i18n("对话历史") + filename + i18n("已经被删除啦"),
+                get_history_list(self.user_name),
+                [],
+            )

     def auto_load(self):
-        filepath = get_history_filepath(self.user_identifier)
+        filepath = get_history_filepath(self.user_name)
         if not filepath:
-            self.history_file_path = new_auto_history_filename(
-                self.user_identifier)
+            self.history_file_path = new_auto_history_filename(self.user_name)
         else:
             self.history_file_path = filepath
         filename, system_prompt, chatbot = self.load_chat_history()
@@ -838,18 +939,15 @@ class BaseLLMModel:
         return filename, system_prompt, chatbot

     def like(self):
-        """like the last response, implement if needed
-        """
+        """like the last response, implement if needed"""
         return gr.update()

     def dislike(self):
-        """dislike the last response, implement if needed
-        """
+        """dislike the last response, implement if needed"""
         return gr.update()

     def deinitialize(self):
-        """deinitialize the model, implement if needed
-        """
+        """deinitialize the model, implement if needed"""
         pass


@@ -874,7 +972,8 @@ class Base_Chat_Langchain_Client(BaseLLMModel):

     def get_answer_at_once(self):
         assert isinstance(
-            self.model, BaseChatModel), "model is not instance of LangChain BaseChatModel"
+            self.model, BaseChatModel
+        ), "model is not instance of LangChain BaseChatModel"
         history = self._get_langchain_style_history()
         response = self.model.generate(history)
         return response.content, sum(response.content)
@@ -882,13 +981,16 @@ class Base_Chat_Langchain_Client(BaseLLMModel):
     def get_answer_stream_iter(self):
         it = CallbackToIterator()
         assert isinstance(
-            self.model, BaseChatModel), "model is not instance of LangChain BaseChatModel"
+            self.model, BaseChatModel
+        ), "model is not instance of LangChain BaseChatModel"
         history = self._get_langchain_style_history()

         def thread_func():
-            self.model(messages=history, callbacks=[
-                ChuanhuCallbackHandler(it.callback)])
+            self.model(
+                messages=history, callbacks=[ChuanhuCallbackHandler(it.callback)]
+            )
             it.finish()
+
         t = Thread(target=thread_func)
         t.start()
         partial_text = ""
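
Backward-compatibility note for the file above: on the load side, every newly persisted field is read with `saved_json.get(key, default)`, so history files written before this commit (which lack these keys) still load cleanly. A self-contained sketch of that pattern (`load_saved_params` is a hypothetical helper, not part of the codebase; the defaults mirror the ones used in `load_chat_history`):

import json

def load_saved_params(path):
    # Older files simply lack the new keys and fall back to the defaults,
    # exactly as load_chat_history does above with saved_json.get(...).
    defaults = {
        "single_turn": False,
        "temperature": 1.0,
        "top_p": None,
        "max_generation_token": None,
        "metadata": {},
    }
    with open(path, "r", encoding="utf-8") as f:
        saved_json = json.load(f)
    return {key: saved_json.get(key, default) for key, default in defaults.items()}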
modules/utils.py CHANGED
@@ -31,97 +31,127 @@ if TYPE_CHECKING:
     headers: List[str]
     data: List[List[str | int | bool]]

+
 def predict(current_model, *args):
     iter = current_model.predict(*args)
     for i in iter:
         yield i

+
 def billing_info(current_model):
     return current_model.billing_info()

+
 def set_key(current_model, *args):
     return current_model.set_key(*args)

+
 def load_chat_history(current_model, *args):
     return current_model.load_chat_history(*args)

+
 def delete_chat_history(current_model, *args):
     return current_model.delete_chat_history(*args)

+
 def interrupt(current_model, *args):
     return current_model.interrupt(*args)

+
 def reset(current_model, *args):
     return current_model.reset(*args)

+
 def retry(current_model, *args):
     iter = current_model.retry(*args)
     for i in iter:
         yield i

+
 def delete_first_conversation(current_model, *args):
     return current_model.delete_first_conversation(*args)

+
 def delete_last_conversation(current_model, *args):
     return current_model.delete_last_conversation(*args)

+
 def set_system_prompt(current_model, *args):
     return current_model.set_system_prompt(*args)

+
 def rename_chat_history(current_model, *args):
     return current_model.rename_chat_history(*args)

+
 def auto_name_chat_history(current_model, *args):
     return current_model.auto_name_chat_history(*args)

+
 def export_markdown(current_model, *args):
     return current_model.export_markdown(*args)

+
 def upload_chat_history(current_model, *args):
     return current_model.load_chat_history(*args)

+
 def set_token_upper_limit(current_model, *args):
     return current_model.set_token_upper_limit(*args)

+
 def set_temperature(current_model, *args):
     current_model.set_temperature(*args)

+
 def set_top_p(current_model, *args):
     current_model.set_top_p(*args)

+
 def set_n_choices(current_model, *args):
     current_model.set_n_choices(*args)

+
 def set_stop_sequence(current_model, *args):
     current_model.set_stop_sequence(*args)

+
 def set_max_tokens(current_model, *args):
     current_model.set_max_tokens(*args)

+
 def set_presence_penalty(current_model, *args):
     current_model.set_presence_penalty(*args)

+
 def set_frequency_penalty(current_model, *args):
     current_model.set_frequency_penalty(*args)

+
 def set_logit_bias(current_model, *args):
     current_model.set_logit_bias(*args)

+
 def set_user_identifier(current_model, *args):
     current_model.set_user_identifier(*args)

+
 def set_single_turn(current_model, *args):
     current_model.set_single_turn(*args)

+
 def handle_file_upload(current_model, *args):
     return current_model.handle_file_upload(*args)

+
 def handle_summarize_index(current_model, *args):
     return current_model.summarize_index(*args)

+
 def like(current_model, *args):
     return current_model.like(*args)

+
 def dislike(current_model, *args):
     return current_model.dislike(*args)

@@ -134,7 +164,7 @@ def count_token(input_str):
     return length


-def markdown_to_html_with_syntax_highlight(md_str): # deprecated
+def markdown_to_html_with_syntax_highlight(md_str):  # deprecated
     def replacer(match):
         lang = match.group(1) or "text"
         code = match.group(2)
@@ -156,7 +186,7 @@ def markdown_to_html_with_syntax_highlight(md_str): # deprecated
     return html_str


-def normalize_markdown(md_text: str) -> str: # deprecated
+def normalize_markdown(md_text: str) -> str:  # deprecated
     lines = md_text.split("\n")
     normalized_lines = []
     inside_list = False
@@ -180,7 +210,7 @@ def normalize_markdown(md_text: str) -> str: # deprecated
     return "\n".join(normalized_lines)


-def convert_mdtext(md_text): # deprecated
+def convert_mdtext(md_text):  # deprecated
     code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
     inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
     code_blocks = code_block_pattern.findall(md_text)
@@ -209,16 +239,22 @@ def clip_rawtext(chat_message, need_escape=True):
     # first, clip hr line
     hr_pattern = r'\n\n<hr class="append-display no-in-raw" />(.*?)'
     hr_match = re.search(hr_pattern, chat_message, re.DOTALL)
-    message_clipped = chat_message[:hr_match.start()] if hr_match else chat_message
+    message_clipped = chat_message[: hr_match.start()] if hr_match else chat_message
     # second, avoid agent-prefix being escaped
-    agent_prefix_pattern = r'(<!-- S O PREFIX --><p class="agent-prefix">.*?<\/p><!-- E O PREFIX -->)'
+    agent_prefix_pattern = (
+        r'(<!-- S O PREFIX --><p class="agent-prefix">.*?<\/p><!-- E O PREFIX -->)'
+    )
     # agent_matches = re.findall(agent_prefix_pattern, message_clipped)
     agent_parts = re.split(agent_prefix_pattern, message_clipped, flags=re.DOTALL)
     final_message = ""
     for i, part in enumerate(agent_parts):
         if i % 2 == 0:
             if part != "" and part != "\n":
-                final_message += f'<pre class="fake-pre">{escape_markdown(part)}</pre>' if need_escape else f'<pre class="fake-pre">{part}</pre>'
+                final_message += (
+                    f'<pre class="fake-pre">{escape_markdown(part)}</pre>'
+                    if need_escape
+                    else f'<pre class="fake-pre">{part}</pre>'
+                )
             else:
                 final_message += part
     return final_message
@@ -248,51 +284,53 @@ def convert_bot_before_marked(chat_message):
     md = f'<div class="md-message">\n\n{result}\n</div>'
     return raw + md

+
 def convert_user_before_marked(chat_message):
     if '<div class="user-message">' in chat_message:
         return chat_message
     else:
         return f'<div class="user-message">{escape_markdown(chat_message)}</div>'

+
 def escape_markdown(text):
     """
     Escape Markdown special characters to HTML-safe equivalents.
     """
     escape_chars = {
         # ' ': '&nbsp;',
-        '_': '&#95;',
-        '*': '&#42;',
-        '[': '&#91;',
-        ']': '&#93;',
-        '(': '&#40;',
-        ')': '&#41;',
-        '{': '&#123;',
-        '}': '&#125;',
-        '#': '&#35;',
-        '+': '&#43;',
-        '-': '&#45;',
-        '.': '&#46;',
-        '!': '&#33;',
-        '`': '&#96;',
-        '>': '&#62;',
-        '<': '&#60;',
-        '|': '&#124;',
-        '$': '&#36;',
-        ':': '&#58;',
-        '\n': '<br>',
+        "_": "&#95;",
+        "*": "&#42;",
+        "[": "&#91;",
+        "]": "&#93;",
+        "(": "&#40;",
+        ")": "&#41;",
+        "{": "&#123;",
+        "}": "&#125;",
+        "#": "&#35;",
+        "+": "&#43;",
+        "-": "&#45;",
+        ".": "&#46;",
+        "!": "&#33;",
+        "`": "&#96;",
+        ">": "&#62;",
+        "<": "&#60;",
+        "|": "&#124;",
+        "$": "&#36;",
+        ":": "&#58;",
+        "\n": "<br>",
     }
-    text = text.replace(' ', '&nbsp;&nbsp;&nbsp;&nbsp;')
-    return ''.join(escape_chars.get(c, c) for c in text)
+    text = text.replace(" ", "&nbsp;&nbsp;&nbsp;&nbsp;")
+    return "".join(escape_chars.get(c, c) for c in text)


-def convert_asis(userinput): # deprecated
+def convert_asis(userinput):  # deprecated
     return (
         f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
         + ALREADY_CONVERTED_MARK
     )


-def detect_converted_mark(userinput): # deprecated
+def detect_converted_mark(userinput):  # deprecated
     try:
         if userinput.endswith(ALREADY_CONVERTED_MARK):
             return True
@@ -302,7 +340,7 @@ def detect_converted_mark(userinput): # deprecated
         return True


-def detect_language(code): # deprecated
+def detect_language(code):  # deprecated
     if code.startswith("\n"):
         first_line = ""
     else:
@@ -328,7 +366,10 @@ def construct_assistant(text):
     return construct_text("assistant", text)


-def save_file(filename, system, history, chatbot, user_name):
+def save_file(filename, model, chatbot):
+    system = model.system_prompt
+    history = model.history
+    user_name = model.user_name
     os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
     if filename is None:
         filename = new_auto_history_filename(user_name)
@@ -339,22 +380,38 @@ def save_file(filename, system, history, chatbot, user_name):
     if filename == ".json":
         raise Exception("文件名不能为空")

-    json_s = {"system": system, "history": history, "chatbot": chatbot}
-    repeat_file_index = 2
+    json_s = {
+        "system": system,
+        "history": history,
+        "chatbot": chatbot,
+        "single_turn": model.single_turn,
+        "temperature": model.temperature,
+        "top_p": model.top_p,
+        "n_choices": model.n_choices,
+        "stop_sequence": model.stop_sequence,
+        "max_generation_token": model.max_generation_token,
+        "presence_penalty": model.presence_penalty,
+        "frequency_penalty": model.frequency_penalty,
+        "logit_bias": model.logit_bias,
+        "user_identifier": model.user_identifier,
+        "metadata": model.metadata
+    }
     if not filename == os.path.basename(filename):
         history_file_path = filename
     else:
         history_file_path = os.path.join(HISTORY_DIR, user_name, filename)

-    with open(history_file_path, "w", encoding='utf-8') as f:
-        json.dump(json_s, f, ensure_ascii=False)
+    with open(history_file_path, "w", encoding="utf-8") as f:
+        json.dump(json_s, f, ensure_ascii=False, indent=4)

     filename = os.path.basename(filename)
     filename_md = filename[:-5] + ".md"
     md_s = f"system: \n- {system} \n"
     for data in history:
         md_s += f"\n{data['role']}: \n- {data['content']} \n"
-    with open(os.path.join(HISTORY_DIR, user_name, filename_md), "w", encoding="utf8") as f:
+    with open(
+        os.path.join(HISTORY_DIR, user_name, filename_md), "w", encoding="utf8"
+    ) as f:
         f.write(md_s)
     return os.path.join(HISTORY_DIR, user_name, filename)

@@ -362,8 +419,12 @@ def save_file(filename, system, history, chatbot, user_name):
 def sorted_by_pinyin(list):
     return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])

+
 def sorted_by_last_modified_time(list, dir):
-    return sorted(list, key=lambda char: os.path.getctime(os.path.join(dir, char)), reverse=True)
+    return sorted(
+        list, key=lambda char: os.path.getctime(os.path.join(dir, char)), reverse=True
+    )
+

 def get_file_names_by_type(dir, filetypes=[".json"]):
     logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes}")
@@ -373,6 +434,7 @@ def get_file_names_by_type(dir, filetypes=[".json"]):
     logging.debug(f"files are:{files}")
     return files

+
 def get_file_names_by_pinyin(dir, filetypes=[".json"]):
     files = get_file_names_by_type(dir, filetypes)
     if files != [""]:
@@ -380,10 +442,12 @@ def get_file_names_by_pinyin(dir, filetypes=[".json"]):
     logging.debug(f"files are:{files}")
     return files

+
 def get_file_names_dropdown_by_pinyin(dir, filetypes=[".json"]):
     files = get_file_names_by_pinyin(dir, filetypes)
     return gr.Dropdown.update(choices=files)

+
 def get_file_names_by_last_modified_time(dir, filetypes=[".json"]):
     files = get_file_names_by_type(dir, filetypes)
     if files != [""]:
@@ -397,21 +461,29 @@ def get_history_names(user_name=""):
     if user_name == "" and hide_history_when_not_logged_in:
         return []
     else:
-        history_files = get_file_names_by_last_modified_time(os.path.join(HISTORY_DIR, user_name))
-        history_files = [f[:f.rfind(".")] for f in history_files]
+        history_files = get_file_names_by_last_modified_time(
+            os.path.join(HISTORY_DIR, user_name)
+        )
+        history_files = [f[: f.rfind(".")] for f in history_files]
         return history_files

+
 def get_first_history_name(user_name=""):
     history_names = get_history_names(user_name)
     return history_names[0] if history_names else None

+
 def get_history_list(user_name=""):
     history_names = get_history_names(user_name)
     return gr.Radio.update(choices=history_names)

+
 def init_history_list(user_name=""):
     history_names = get_history_names(user_name)
-    return gr.Radio.update(choices=history_names, value=history_names[0] if history_names else "")
+    return gr.Radio.update(
+        choices=history_names, value=history_names[0] if history_names else ""
+    )
+

 def filter_history(user_name, keyword):
     history_names = get_history_names(user_name)
@@ -421,6 +493,7 @@ def filter_history(user_name, keyword):
     except:
         return gr.update(choices=history_names)

+
 def load_template(filename, mode=0):
     logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
     lines = []
@@ -441,15 +514,14 @@ def load_template(filename, mode=0):
         return {row[0]: row[1] for row in lines}
     else:
         choices = sorted_by_pinyin([row[0] for row in lines])
-        return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
-            choices=choices
-        )
+        return {row[0]: row[1] for row in lines}, gr.Dropdown.update(choices=choices)


 def get_template_names():
     logging.debug("获取模板文件名列表")
     return get_file_names_by_pinyin(TEMPLATES_DIR, filetypes=[".csv", "json"])

+
 def get_template_dropdown():
     logging.debug("获取模板下拉菜单")
     template_names = get_template_names()
@@ -524,9 +596,7 @@ def get_geoip():
         if "error" in data.keys():
             logging.warning(f"无法获取IP地址信息。\n{data}")
             if data["reason"] == "RateLimited":
-                return (
-                    i18n("您的IP区域:未知。")
-                )
+                return i18n("您的IP区域:未知。")
             else:
                 return i18n("获取IP地理位置失败。原因:") + f"{data['reason']}" + i18n("。你仍然可以使用聊天功能。")
         else:
@@ -590,29 +660,36 @@ def update_chuanhu():
     if update_status == "success":
         logging.info("Successfully updated, restart needed")
         status = '<span id="update-status" class="hideK">success</span>'
-        return gr.Markdown.update(value=i18n("更新成功,请重启本程序")+status)
+        return gr.Markdown.update(value=i18n("更新成功,请重启本程序") + status)
     else:
         status = '<span id="update-status" class="hideK">failure</span>'
-        return gr.Markdown.update(value=i18n("更新失败,请尝试[手动更新](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程#手动更新)")+status)
+        return gr.Markdown.update(
+            value=i18n(
+                "更新失败,请尝试[手动更新](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程#手动更新)"
+            )
+            + status
+        )


-def add_source_numbers(lst, source_name = "Source", use_source = True):
+def add_source_numbers(lst, source_name="Source", use_source=True):
     if use_source:
-        return [f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}' for idx, item in enumerate(lst)]
+        return [
+            f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}'
+            for idx, item in enumerate(lst)
+        ]
     else:
         return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)]

+
 def add_details(lst):
     nodes = []
     for index, txt in enumerate(lst):
         brief = txt[:25].replace("\n", "")
-        nodes.append(
-            f"<details><summary>{brief}...</summary><p>{txt}</p></details>"
-        )
+        nodes.append(f"<details><summary>{brief}...</summary><p>{txt}</p></details>")
     return nodes


-def sheet_to_string(sheet, sheet_name = None):
+def sheet_to_string(sheet, sheet_name=None):
     result = []
     for index, row in sheet.iterrows():
         row_string = ""
@@ -623,59 +700,70 @@ def sheet_to_string(sheet, sheet_name = None):
         result.append(row_string)
     return result

+
 def excel_to_string(file_path):
     # Read every worksheet in the Excel file
-    excel_file = pd.read_excel(file_path, engine='openpyxl', sheet_name=None)
+    excel_file = pd.read_excel(file_path, engine="openpyxl", sheet_name=None)

     # Initialize the result
     result = []

     # Iterate over each worksheet
     for sheet_name, sheet_data in excel_file.items():
-
         # Process the current worksheet and append it to the result
         result += sheet_to_string(sheet_data, sheet_name=sheet_name)

-
     return result

+
 def get_last_day_of_month(any_day):
     # The day 28 exists in every month. 4 days later, it's always next month
     next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
     # subtracting the number of the current day brings us back one month
     return next_month - datetime.timedelta(days=next_month.day)

+
 def get_model_source(model_name, alternative_source):
     if model_name == "gpt2-medium":
         return "https://huggingface.co/gpt2-medium"

+
 def refresh_ui_elements_on_load(current_model, selected_model_name, user_name):
     current_model.set_user_identifier(user_name)
     return toggle_like_btn_visibility(selected_model_name), *current_model.auto_load()

+
 def toggle_like_btn_visibility(selected_model_name):
     if selected_model_name == "xmchat":
         return gr.update(visible=True)
     else:
         return gr.update(visible=False)

+
 def get_corresponding_file_type_by_model_name(selected_model_name):
     if selected_model_name in ["xmchat", "GPT4 Vision"]:
         return ["image"]
     else:
         return [".pdf", ".docx", ".pptx", ".epub", ".xlsx", ".txt", "text"]

+
 # def toggle_file_type(selected_model_name):
 #     return gr.Files.update(file_types=get_corresponding_file_type_by_model_name(selected_model_name))

+
 def new_auto_history_filename(username):
     latest_file = get_first_history_name(username)
     if latest_file:
-        with open(os.path.join(HISTORY_DIR, username, latest_file + ".json"), 'r', encoding="utf-8") as f:
+        with open(
+            os.path.join(HISTORY_DIR, username, latest_file + ".json"),
+            "r",
+            encoding="utf-8",
+        ) as f:
             if len(f.read()) == 0:
                 return latest_file
-    now = i18n("新对话 ") + datetime.datetime.now().strftime('%m-%d %H-%M')
-    return f'{now}.json'
+    now = i18n("新对话 ") + datetime.datetime.now().strftime("%m-%d %H-%M")
+    return f"{now}.json"
+

 def get_history_filepath(username):
     dirname = os.path.join(HISTORY_DIR, username)
@@ -687,20 +775,28 @@ def get_history_filepath(username):
         latest_file = os.path.join(dirname, latest_file)
     return latest_file

+
 def beautify_err_msg(err_msg):
-    if "insufficient_quota" in err_msg:
-        return i18n("剩余配额不足,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98#you-exceeded-your-current-quota-please-check-your-plan-and-billing-details)")
+    if "insufficient_quota" in err_msg:
+        return i18n(
+            "剩余配额不足,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98#you-exceeded-your-current-quota-please-check-your-plan-and-billing-details)"
+        )
     if "The model `gpt-4` does not exist" in err_msg:
-        return i18n("你没有权限访问 GPT4,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues/843)")
+        return i18n(
+            "你没有权限访问 GPT4,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues/843)"
+        )
     if "Resource not found" in err_msg:
         return i18n("请查看 config_example.json,配置 Azure OpenAI")
     return err_msg

+
 def auth_from_conf(username, password):
     try:
         with open("config.json", encoding="utf-8") as f:
             conf = json.load(f)
-        usernames, passwords = [i[0] for i in conf["users"]], [i[1] for i in conf["users"]]
+        usernames, passwords = [i[0] for i in conf["users"]], [
+            i[1] for i in conf["users"]
+        ]
         if username in usernames:
             if passwords[usernames.index(username)] == password:
                 return True
@@ -708,6 +804,7 @@ def auth_from_conf(username, password):
     except:
         return False

+
 def get_file_hash(file_src=None, file_paths=None):
     if file_src:
         file_paths = [x.name for x in file_src]
@@ -721,12 +818,14 @@ def get_file_hash(file_src=None, file_paths=None):

     return md5_hash.hexdigest()

+
 def myprint(**args):
     print(args)

+
 def replace_special_symbols(string, replace_string=" "):
     # Define a regex matching all special symbols
-    pattern = r'[!@#$%^&*()<>?/\|}{~:]'
+    pattern = r"[!@#$%^&*()<>?/\|}{~:]"

     new_string = re.sub(pattern, replace_string, string)
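
For downstream callers, the visible API change is the `save_file` signature: the model now carries its own state, and `save_file` reads what it needs from the object. A hypothetical stand-in showing the attributes the new `save_file` expects (the attribute names come from the diff above; the class itself is illustrative, not part of the codebase):

class StubModel:
    # Minimal hypothetical object exposing everything save_file() now reads.
    system_prompt = "You are a helpful assistant."
    history = []
    user_name = "user"
    single_turn = False
    temperature = 1.0
    top_p = None
    n_choices = 1
    stop_sequence = None
    max_generation_token = None
    presence_penalty = 0
    frequency_penalty = 0
    logit_bias = None
    user_identifier = "user"
    metadata = {}

# Old call shape: save_file(filename, system, history, chatbot, user_name)
# New call shape: save_file(filename, StubModel(), chatbot)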