sparkleman committed
Commit 94c4923 · 1 Parent(s): 664ff1c

UPDATE: Remove <think> tag in content & handle EOS token
Files changed (2):
  1. app.py +24 -9
  2. utils.py +27 -2
app.py CHANGED

@@ -44,6 +44,8 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.gzip import GZipMiddleware
+
 
 from api_types import (
     ChatMessage,
@@ -54,7 +56,7 @@ from api_types import (
     ChatCompletionChoice,
     ChatCompletionMessage,
 )
-from utils import cleanMessages, parse_think_response
+from utils import cleanMessages, parse_think_response, remove_nested_think_tags_stack
 
 
 class ModelStorage:
@@ -159,6 +161,7 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
+app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
 
 
 async def runPrefill(
@@ -185,7 +188,6 @@ def generate(
     out,
     model_tokens: List[int],
     model_state,
-    stops=["\n\n"],
     max_tokens=2048,
 ):
     args = PIPELINE_ARGS(
@@ -212,18 +214,29 @@ def generate(
             out, temperature=args.temperature, top_p=args.top_p
         )
 
+        if token == 0 and token in request.stop_tokens:
+            yield {
+                "content": "",
+                "tokens": out_tokens[out_last:],
+                "finish_reason": "stop:token:0",
+                "state": model_state,
+            }
+
+            del out
+            gc.collect()
+            return
+
         out, model_state = MODEL_STORAGE[request.model].model.forward(
             [token], model_state
         )
-        model_tokens.append(token)
 
-        out_tokens.append(token)
+        model_tokens.append(token)
 
         if token in request.stop_tokens:
             yield {
                 "content": "",
                 "tokens": out_tokens[out_last:],
-                "finish_reason": "stop",
+                "finish_reason": f"stop:token:{token}",
                 "state": model_state,
             }
 
@@ -231,6 +244,8 @@ def generate(
             gc.collect()
             return
 
+        out_tokens.append(token)
+
         for xxx in occurrence:
             occurrence[xxx] *= request.penalty_decay
         occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
@@ -243,13 +258,13 @@ def generate(
         output_cache.append(tmp)
         output_cache_str = "".join(output_cache)
 
-        for stop_words in stops:
+        for stop_words in request.stop:
             if stop_words in output_cache_str:
 
                 yield {
                     "content": tmp.replace(stop_words, ""),
                     "tokens": out_tokens[out_last:],
-                    "finish_reason": "stop",
+                    "finish_reason": f"stop:words:{stop_words}",
                     "state": model_state,
                 }
 
@@ -365,7 +380,7 @@ async def chatResponseStream(
     createTimestamp = int(time.time())
 
     prompt = (
-        f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
+        f"{cleanMessages(request.messages,enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
         if request.prompt == None
         else request.prompt.strip()
    )
@@ -415,7 +430,7 @@ async def chatResponseStream(
     buffer.append("<think")
 
     streamConfig = {
-        "isChecking": False,
+        "isChecking": False,  # check whether is <think> tag
         "fullTextCursor": 0,
         "in_think": False,
         "cacheStr": "",
utils.py CHANGED

@@ -24,12 +24,37 @@ def parse_think_response(full_response: str):
     return reasoning_content, content
 
 
-def cleanMessages(messages: List[ChatMessage]):
+def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = False):
     promptStrList = []
 
     for message in messages:
         content = message.content.strip()
         content = re.sub(r"\n+", "\n", content)
-        promptStrList.append(f"{message.role.strip()}: {content}")
+        promptStrList.append(
+            f"{message.role.strip()}: {content if message.role!='Assistant' or not removeThinkingContent else remove_nested_think_tags_stack(content)}"
+        )
 
     return "\n\n".join(promptStrList)
+
+
+def remove_nested_think_tags_stack(text):
+    stack = []
+    result = ""
+    i = 0
+    while i < len(text):
+        if text[i : i + 7] == "<think>":
+            stack.append("<think>")
+            i += 7
+        elif text[i : i + 8] == "</think>":
+            if stack and stack[-1] == "<think>":
+                stack.pop()
+                i += 8
+            else:
+                result += text[i : i + 8]
+                i += 8
+        elif not stack:
+            result += text[i]
+            i += 1
+        else:
+            i += 1
+    return result
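
For reference, a quick check of the new helper's behaviour, assuming the project root is on the import path; the sample string is illustrative only:

from utils import remove_nested_think_tags_stack

# Balanced <think>...</think> pairs (including nested ones) are dropped,
# while an unmatched closing tag is kept verbatim.
text = "Answer: 42 <think>step 1 <think>sub</think> step 2</think> done </think>"
print(remove_nested_think_tags_stack(text))
# -> "Answer: 42  done </think>"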