qingxu99 committed on
Commit
0785ff2
·
1 Parent(s): 676fe40

微调对话裁剪

Browse files
Files changed (2) hide show
  1. request_llm/bridge_chatgpt.py +1 -1
  2. toolbox.py +10 -7
request_llm/bridge_chatgpt.py CHANGED
@@ -200,7 +200,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
200
  if "reduce the length" in error_msg:
201
  if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
202
  history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
203
- max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # history至少释放二分之一
204
  chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
205
  # history = [] # 清除历史
206
  elif "does not exist" in error_msg:
 
200
  if "reduce the length" in error_msg:
201
  if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
202
  history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
203
+ max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
204
  chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
205
  # history = [] # 清除历史
206
  elif "does not exist" in error_msg:
toolbox.py CHANGED
@@ -555,23 +555,26 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
555
 
556
  def clip_history(inputs, history, tokenizer, max_token_limit):
557
  """
558
- reduce the length of input/history by clipping.
559
  this function search for the longest entries to clip, little by little,
560
- until the number of token of input/history is reduced under threshold.
561
- 通过剪辑来缩短输入/历史记录的长度。
562
  此函数逐渐地搜索最长的条目进行剪辑,
563
- 直到输入/历史记录的标记数量降低到阈值以下。
564
  """
565
  import numpy as np
566
  from request_llm.bridge_all import model_info
567
  def get_token_num(txt):
568
  return len(tokenizer.encode(txt, disallowed_special=()))
569
  input_token_num = get_token_num(inputs)
570
- if input_token_num < max_token_limit * 3 / 4:
571
- # 当输入部分的token占比小于限制的3/4时,在裁剪时把input的余量留出来
 
572
  max_token_limit = max_token_limit - input_token_num
 
 
 
573
  if max_token_limit < 128:
574
- # 余量太小了,直接清除历史
575
  history = []
576
  return history
577
  else:
 
555
 
556
  def clip_history(inputs, history, tokenizer, max_token_limit):
557
  """
558
+ reduce the length of history by clipping.
559
  this function search for the longest entries to clip, little by little,
560
+ until the number of token of history is reduced under threshold.
561
+ 通过裁剪来缩短历史记录的长度。
562
  此函数逐渐地搜索最长的条目进行剪辑,
563
+ 直到历史记录的标记数量降低到阈值以下。
564
  """
565
  import numpy as np
566
  from request_llm.bridge_all import model_info
567
  def get_token_num(txt):
568
  return len(tokenizer.encode(txt, disallowed_special=()))
569
  input_token_num = get_token_num(inputs)
570
+ if input_token_num < max_token_limit * 3 / 4:
571
+ # 当输入部分的token占比小于限制的3/4时,裁剪时
572
+ # 1. 把input的余量留出来
573
  max_token_limit = max_token_limit - input_token_num
574
+ # 2. 把输出用的余量留出来
575
+ max_token_limit = max_token_limit - 128
576
+ # 3. 如果余量太小了,直接清除历史
577
  if max_token_limit < 128:
 
578
  history = []
579
  return history
580
  else: