vlff李飞飞 committed
Commit dc8d3c6 · 1 Parent(s): fc211c5
Files changed (1):
  1. qwen_agent/llm/qwen_oai.py +25 -32
qwen_agent/llm/qwen_oai.py CHANGED
@@ -124,7 +124,7 @@ _TEXT_COMPLETION_CMD = object()
 #
 def parse_messages(messages, functions):
     if all(m.role != "user" for m in messages):
-        raise Exception( f"Invalid request: Expecting at least one user message.",)
+        raise Exception(f"Invalid request: Expecting at least one user message.", )
     messages = copy.deepcopy(messages)
     default_system = "You are a helpful assistant."
     system = ""
@@ -381,7 +381,7 @@ def predict(
     stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
     if stop_words:
         # TODO: It's a little bit tricky to trim stop words in the stream mode.
-        raise Exception("Invalid request: custom stop words are not yet supported for stream mode.",)
+        raise Exception("Invalid request: custom stop words are not yet supported for stream mode.", )
     response_generator = qmodel.chat_stream(
         tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
     )
@@ -420,35 +420,34 @@ class QwenChatAsOAI(BaseChatModel):
         self.model = model
         super().__init__()
         tokenizer = AutoTokenizer.from_pretrained(
-            self.checkpoint_path,
+            self.model,
             trust_remote_code=True,
             resume_download=True,
         )
         device_map = "cpu"
         # device_map = "auto"
         qmodel = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path,
+            self.model,
             device_map=device_map,
             trust_remote_code=True,
             resume_download=True,
         ).eval()
 
         qmodel.generation_config = GenerationConfig.from_pretrained(
-            self.checkpoint_path,
+            self.model,
             trust_remote_code=True,
             resume_download=True,
         )
 
-
     def _chat_stream(
-        self,
-        messages: List[Dict],
-        stop: Optional[List[str]] = None,
+        self,
+        messages: List[Dict],
+        stop: Optional[List[str]] = None,
     ) -> Iterator[str]:
         _request = ChatCompletionRequest(model=self.model,
-                                         messages=messages,
-                                         stop=stop,
-                                         stream=True)
+                                         messages=messages,
+                                         stop=stop,
+                                         stream=True)
         response = create_chat_completion(_request)
         # TODO: error handling
         for chunk in response:
@@ -456,14 +455,11 @@ class QwenChatAsOAI(BaseChatModel):
             yield chunk.choices[0].delta.content
 
     def _chat_no_stream(
-        self,
-        messages: List[Dict],
-        stop: Optional[List[str]] = None,
+        self,
+        messages: List[Dict],
+        stop: Optional[List[str]] = None,
     ) -> str:
-        _request = ChatCompletionRequest(model=self.model,
-                                         messages=messages,
-                                         stop=stop,
-                                         stream=False)
+        _request = ChatCompletionRequest(model=self.model, messages=messages, stop=stop, stream=False)
         response = create_chat_completion(_request)
         # TODO: error handling
         return response.choices[0].message.content
@@ -472,16 +468,13 @@ class QwenChatAsOAI(BaseChatModel):
                             messages: List[Dict],
                             functions: Optional[List[Dict]] = None) -> Dict:
         if functions:
-            _request = ChatCompletionRequest(model=self.model,
-                                             messages=messages,
-                                             functions=functions)
+            _request = ChatCompletionRequest(model=self.model, messages=messages, functions=functions)
             response = create_chat_completion(_request)
         else:
-            _request = ChatCompletionRequest(model=self.model,
-                                             messages=messages)
+            _request = ChatCompletionRequest(model=self.model, messages=messages)
             response = create_chat_completion(_request)
         # TODO: error handling
-        return response.choices[0].message.dict()
+        return response.choices[0].message.model_dump()
 
 
 class QwenChatAsOAI1(BaseChatModel):
@@ -495,9 +488,9 @@ class QwenChatAsOAI1(BaseChatModel):
         self.model = model
 
     def _chat_stream(
-        self,
-        messages: List[Dict],
-        stop: Optional[List[str]] = None,
+        self,
+        messages: List[Dict],
+        stop: Optional[List[str]] = None,
     ) -> Iterator[str]:
         response = openai.ChatCompletion.create(model=self.model,
                                                 messages=messages,
@@ -509,9 +502,9 @@ class QwenChatAsOAI1(BaseChatModel):
             yield chunk.choices[0].delta.content
 
     def _chat_no_stream(
-        self,
-        messages: List[Dict],
-        stop: Optional[List[str]] = None,
+        self,
+        messages: List[Dict],
+        stop: Optional[List[str]] = None,
     ) -> str:
         response = openai.ChatCompletion.create(model=self.model,
                                                 messages=messages,
@@ -531,4 +524,4 @@ class QwenChatAsOAI1(BaseChatModel):
         response = openai.ChatCompletion.create(model=self.model,
                                                 messages=messages)
         # TODO: error handling
-        return response.choices[0].message
+        return response.choices[0].message
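The substantive change in the constructor hunk is that `QwenChatAsOAI.__init__` now points all three `from_pretrained` calls at `self.model` instead of the removed `self.checkpoint_path`, so a single identifier (a local path or a Hub id) drives the tokenizer, the weights, and the generation config. A minimal standalone sketch of that loading path, assuming the illustrative checkpoint id "Qwen/Qwen-7B-Chat" (the diff itself does not fix a model id):

# Sketch only: mirrors what __init__ does after the change. model_id is an
# assumption; self.model may hold any local path or Hub checkpoint id.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "Qwen/Qwen-7B-Chat"  # illustrative

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    resume_download=True,
)
qmodel = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",  # the diff keeps CPU; device_map="auto" stays commented out
    trust_remote_code=True,
    resume_download=True,
).eval()
qmodel.generation_config = GenerationConfig.from_pretrained(
    model_id,
    trust_remote_code=True,
    resume_download=True,
)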
 
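The function-calling branch now returns `response.choices[0].message.model_dump()` rather than `.dict()`. That tracks Pydantic's v2 rename, where `BaseModel.dict()` is deprecated in favour of `model_dump()`. A small self-contained example of the equivalence; `ChatMessage` and its fields below are illustrative stand-ins, not the actual request/response models in `qwen_oai.py`:

# Illustrative Pydantic model -- a stand-in, not the real class in qwen_oai.py.
from pydantic import BaseModel


class ChatMessage(BaseModel):
    role: str
    content: str


msg = ChatMessage(role="assistant", content="hi")

# Pydantic v1 spelling, deprecated in v2 (emits a DeprecationWarning there):
# msg.dict()

# Pydantic v2 spelling, matching the new return statement in the diff:
print(msg.model_dump())  # -> {'role': 'assistant', 'content': 'hi'}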
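The `QwenChatAsOAI1` hunks go through the legacy pre-1.0 `openai` client (`openai.ChatCompletion.create`); the removed and re-added signature lines appear to be whitespace-only reflows. A hedged sketch of the streaming pattern that `_chat_stream` builds on, with the model name and prompt as placeholder values:

# Legacy openai<1.0 streaming sketch; model name and message are placeholders.
import openai

response = openai.ChatCompletion.create(
    model="qwen",  # illustrative; QwenChatAsOAI1 passes self.model here
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    # Each streamed chunk carries an incremental delta; content may be absent
    # on role-only or final chunks, hence the .get with a default.
    piece = chunk.choices[0].delta.get("content", "")
    if piece:
        print(piece, end="", flush=True)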