JohnSmith9982 commited on
Commit
0405ac0
·
1 Parent(s): 2c3bb3b

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +21 -5
utils.py CHANGED
@@ -49,7 +49,7 @@ def postprocess(
49
  return y
50
 
51
  def count_token(input_str):
52
- encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
53
  length = len(encoding.encode(input_str))
54
  return length
55
 
@@ -144,14 +144,20 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
144
  try:
145
  response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
146
  except requests.exceptions.ConnectTimeout:
147
- status_text = standard_error_msg + error_retrieve_prompt
 
 
 
 
 
 
148
  yield get_return_value()
149
  return
150
 
151
  chatbot.append((parse_text(inputs), ""))
152
  yield get_return_value()
153
 
154
- for chunk in tqdm(response.iter_lines()):
155
  if counter == 0:
156
  counter += 1
157
  continue
@@ -160,7 +166,12 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
160
  if chunk:
161
  chunk = chunk.decode()
162
  chunklength = len(chunk)
163
- chunk = json.loads(chunk[6:])
 
 
 
 
 
164
  # decode each line as response data is in bytes
165
  if chunklength > 6 and "delta" in chunk['choices'][0]:
166
  finish_reason = chunk['choices'][0]['finish_reason']
@@ -169,7 +180,12 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
169
  print("生成完毕")
170
  yield get_return_value()
171
  break
172
- partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
 
 
 
 
 
173
  if token_counter == 0:
174
  history.append(construct_assistant(" " + partial_words))
175
  else:
 
49
  return y
50
 
51
  def count_token(input_str):
52
+ encoding = tiktoken.get_encoding("cl100k_base")
53
  length = len(encoding.encode(input_str))
54
  return length
55
 
 
144
  try:
145
  response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
146
  except requests.exceptions.ConnectTimeout:
147
+ history.pop()
148
+ status_text = standard_error_msg + "连接超时,无法获取对话。" + error_retrieve_prompt
149
+ yield get_return_value()
150
+ return
151
+ except requests.exceptions.ReadTimeout:
152
+ history.pop()
153
+ status_text = standard_error_msg + "读取超时,无法获取对话。" + error_retrieve_prompt
154
  yield get_return_value()
155
  return
156
 
157
  chatbot.append((parse_text(inputs), ""))
158
  yield get_return_value()
159
 
160
+ for chunk in response.iter_lines():
161
  if counter == 0:
162
  counter += 1
163
  continue
 
166
  if chunk:
167
  chunk = chunk.decode()
168
  chunklength = len(chunk)
169
+ try:
170
+ chunk = json.loads(chunk[6:])
171
+ except json.JSONDecodeError:
172
+ status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
173
+ yield get_return_value()
174
+ break
175
  # decode each line as response data is in bytes
176
  if chunklength > 6 and "delta" in chunk['choices'][0]:
177
  finish_reason = chunk['choices'][0]['finish_reason']
 
180
  print("生成完毕")
181
  yield get_return_value()
182
  break
183
+ try:
184
+ partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
185
+ except KeyError:
186
+ status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(previous_token_count)+token_counter+user_token_count)
187
+ yield get_return_value()
188
+ break
189
  if token_counter == 0:
190
  history.append(construct_assistant(" " + partial_words))
191
  else: