Tuchuanhuhuhu commited on
Commit
007cc3d
·
1 Parent(s): 5d96469

feat: 更换模型时保持对话上下文

Browse files
Files changed (2) hide show
  1. ChuanhuChatbot.py +6 -6
  2. modules/models/models.py +7 -4
ChuanhuChatbot.py CHANGED
@@ -70,7 +70,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
70
  uploadFileBtn = gr.UploadButton(
71
  interactive=True, label="", file_types=[".json"], elem_id="gr-history-upload-btn")
72
  historyRefreshBtn = gr.Button("", elem_id="gr-history-refresh-btn")
73
-
74
 
75
  with gr.Row(elem_id="chuanhu-history-body"):
76
  with gr.Column(scale=6, elem_id="history-select-wrap"):
@@ -372,11 +372,11 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
372
  label=i18n("实时传输回答"), value=True, visible=ENABLE_STREAMING_OPTION, elem_classes="switch-checkbox"
373
  )
374
  name_chat_method = gr.Dropdown(
375
- label=i18n("对话命名方式"),
376
- choices=HISTORY_NAME_METHODS,
377
  multiselect=False,
378
  interactive=True,
379
- value=HISTORY_NAME_METHODS[chat_name_method_index],
380
  )
381
  single_turn_checkbox = gr.Checkbox(label=i18n(
382
  "单轮对话"), value=False, elem_classes="switch-checkbox", elem_id="gr-single-session-cb", visible=False)
@@ -638,12 +638,12 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
638
  keyTxt.submit(**get_usage_args)
639
  single_turn_checkbox.change(
640
  set_single_turn, [current_model, single_turn_checkbox], None)
641
- model_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider, top_p_slider, systemPromptTxt, user_name], [
642
  current_model, status_display, chatbot, lora_select_dropdown, user_api_key, keyTxt], show_progress=True, api_name="get_model")
643
  model_select_dropdown.change(toggle_like_btn_visibility, [model_select_dropdown], [
644
  like_dislike_area], show_progress=False)
645
  lora_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider,
646
- top_p_slider, systemPromptTxt, user_name], [current_model, status_display, chatbot], show_progress=True)
647
 
648
  # Template
649
  systemPromptTxt.change(set_system_prompt, [
 
70
  uploadFileBtn = gr.UploadButton(
71
  interactive=True, label="", file_types=[".json"], elem_id="gr-history-upload-btn")
72
  historyRefreshBtn = gr.Button("", elem_id="gr-history-refresh-btn")
73
+
74
 
75
  with gr.Row(elem_id="chuanhu-history-body"):
76
  with gr.Column(scale=6, elem_id="history-select-wrap"):
 
372
  label=i18n("实时传输回答"), value=True, visible=ENABLE_STREAMING_OPTION, elem_classes="switch-checkbox"
373
  )
374
  name_chat_method = gr.Dropdown(
375
+ label=i18n("对话命名方式"),
376
+ choices=HISTORY_NAME_METHODS,
377
  multiselect=False,
378
  interactive=True,
379
+ value=HISTORY_NAME_METHODS[chat_name_method_index],
380
  )
381
  single_turn_checkbox = gr.Checkbox(label=i18n(
382
  "单轮对话"), value=False, elem_classes="switch-checkbox", elem_id="gr-single-session-cb", visible=False)
 
638
  keyTxt.submit(**get_usage_args)
639
  single_turn_checkbox.change(
640
  set_single_turn, [current_model, single_turn_checkbox], None)
641
+ model_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider, top_p_slider, systemPromptTxt, user_name, current_model], [
642
  current_model, status_display, chatbot, lora_select_dropdown, user_api_key, keyTxt], show_progress=True, api_name="get_model")
643
  model_select_dropdown.change(toggle_like_btn_visibility, [model_select_dropdown], [
644
  like_dislike_area], show_progress=False)
645
  lora_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider,
646
+ top_p_slider, systemPromptTxt, user_name, current_model], [current_model, status_display, chatbot], show_progress=True)
647
 
648
  # Template
649
  systemPromptTxt.change(set_system_prompt, [
modules/models/models.py CHANGED
@@ -228,7 +228,7 @@ class OpenAIClient(BaseLLMModel):
228
  ret = super().set_key(new_access_key)
229
  self._refresh_header()
230
  return ret
231
-
232
  def _single_query_at_once(self, history, temperature=1.0):
233
  timeout = TIMEOUT_ALL
234
  headers = {
@@ -255,7 +255,7 @@ class OpenAIClient(BaseLLMModel):
255
 
256
  return response
257
 
258
-
259
  def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
260
  if len(self.history) == 2 and not single_turn_checkbox:
261
  user_question = self.history[0]["content"]
@@ -601,7 +601,8 @@ def get_model(
601
  temperature=None,
602
  top_p=None,
603
  system_prompt=None,
604
- user_name=""
 
605
  ) -> BaseLLMModel:
606
  msg = i18n("模型设置为了:") + f" {model_name}"
607
  model_type = ModelType.get_type(model_name)
@@ -671,7 +672,7 @@ def get_model(
671
  access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
672
  model = Google_PaLM_Client(model_name, access_key, user_name=user_name)
673
  elif model_type == ModelType.LangchainChat:
674
- from .azure import Azure_OpenAI_Client
675
  model = Azure_OpenAI_Client(model_name, user_name=user_name)
676
  elif model_type == ModelType.Midjourney:
677
  from .midjourney import Midjourney_Client
@@ -688,6 +689,8 @@ def get_model(
688
  traceback.print_exc()
689
  msg = f"{STANDARD_ERROR_MSG}: {e}"
690
  presudo_key = hide_middle_chars(access_key)
 
 
691
  if dont_change_lora_selector:
692
  return model, msg, chatbot, gr.update(), access_key, presudo_key
693
  else:
 
228
  ret = super().set_key(new_access_key)
229
  self._refresh_header()
230
  return ret
231
+
232
  def _single_query_at_once(self, history, temperature=1.0):
233
  timeout = TIMEOUT_ALL
234
  headers = {
 
255
 
256
  return response
257
 
258
+
259
  def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
260
  if len(self.history) == 2 and not single_turn_checkbox:
261
  user_question = self.history[0]["content"]
 
601
  temperature=None,
602
  top_p=None,
603
  system_prompt=None,
604
+ user_name="",
605
+ original_model = None
606
  ) -> BaseLLMModel:
607
  msg = i18n("模型设置为了:") + f" {model_name}"
608
  model_type = ModelType.get_type(model_name)
 
672
  access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
673
  model = Google_PaLM_Client(model_name, access_key, user_name=user_name)
674
  elif model_type == ModelType.LangchainChat:
675
+ from .Azure import Azure_OpenAI_Client
676
  model = Azure_OpenAI_Client(model_name, user_name=user_name)
677
  elif model_type == ModelType.Midjourney:
678
  from .midjourney import Midjourney_Client
 
689
  traceback.print_exc()
690
  msg = f"{STANDARD_ERROR_MSG}: {e}"
691
  presudo_key = hide_middle_chars(access_key)
692
+ if original_model is not None:
693
+ model.history = original_model.history
694
  if dont_change_lora_selector:
695
  return model, msg, chatbot, gr.update(), access_key, presudo_key
696
  else: