johnsmith253325 committed
Commit 0fd73b9 · 1 Parent(s): 5dced7c

feat: add Tongyi Qianwen (通义千问 / Qwen) support

modules/models/Qwen.py ADDED
@@ -0,0 +1,57 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+ import logging
+ import colorama
+ from .base_model import BaseLLMModel
+ from ..presets import MODEL_METADATA
+
+
+ class Qwen_Client(BaseLLMModel):
+     def __init__(self, model_name, user_name="") -> None:
+         super().__init__(model_name=model_name, user=user_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(MODEL_METADATA[model_name]["repo_id"], trust_remote_code=True, resume_download=True)
+         self.model = AutoModelForCausalLM.from_pretrained(MODEL_METADATA[model_name]["repo_id"], device_map="auto", trust_remote_code=True, resume_download=True).eval()
+
+     def generation_config(self):
+         return GenerationConfig.from_dict({
+             "chat_format": "chatml",
+             "do_sample": True,
+             "eos_token_id": 151643,
+             "max_length": self.token_upper_limit,
+             "max_new_tokens": 512,
+             "max_window_size": 6144,
+             "pad_token_id": 151643,
+             "top_k": 0,
+             "top_p": self.top_p,
+             "transformers_version": "4.33.2",
+             "trust_remote_code": True,
+             "temperature": self.temperature,
+         })
+
+     def _get_glm_style_input(self):
+         history = [x["content"] for x in self.history]
+         query = history.pop()
+         logging.debug(colorama.Fore.YELLOW +
+                       f"{history}" + colorama.Fore.RESET)
+         assert (
+             len(history) % 2 == 0
+         ), f"History should be even length. current history is: {history}"
+         history = [[history[i], history[i + 1]]
+                    for i in range(0, len(history), 2)]
+         return history, query
+
+     def get_answer_at_once(self):
+         history, query = self._get_glm_style_input()
+         self.model.generation_config = self.generation_config()
+         response, history = self.model.chat(self.tokenizer, query, history=history)
+         return response, len(response)
+
+     def get_answer_stream_iter(self):
+         history, query = self._get_glm_style_input()
+         self.model.generation_config = self.generation_config()
+         for response in self.model.chat_stream(
+             self.tokenizer,
+             query,
+             history,
+         ):
+             yield response
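
Note on the history handling: _get_glm_style_input reshapes the flat self.history content list into (user, assistant) pairs plus the trailing user query, which is the shape Qwen's chat()/chat_stream() expect. A minimal sketch of the reshaping, with hypothetical messages not taken from the commit:

    # Flat content list, as Qwen_Client reads it out of self.history.
    history = ["Hi", "Hello! How can I help?", "What is Qwen?"]
    query = history.pop()                     # -> "What is Qwen?"
    assert len(history) % 2 == 0              # remaining turns must pair up
    pairs = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
    # pairs == [["Hi", "Hello! How can I help?"]]
    # chat() then receives (tokenizer, query, history=pairs).
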
modules/models/base_model.py CHANGED
@@ -146,6 +146,7 @@ class ModelType(Enum):
  Spark = 12
  OpenAIInstruct = 13
  Claude = 14
+ Qwen = 15

  @classmethod
  def get_type(cls, model_name: str):
@@ -181,7 +182,9 @@ class ModelType(Enum):
  elif "星火大模型" in model_name_lower:
      model_type = ModelType.Spark
  elif "claude" in model_name_lower:
-     model_type = ModelType.Claude
+     model_type = ModelType.Claude
+ elif "qwen" in model_name_lower:
+     model_type = ModelType.Qwen
  else:
      model_type = ModelType.LLaMA
  return model_type
@@ -656,14 +659,13 @@ class BaseLLMModel:
  def delete_last_conversation(self, chatbot):
      if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
          msg = "由于包含报错信息,只删除chatbot记录"
-         chatbot.pop()
+         chatbot = chatbot[:-1]
          return chatbot, self.history
      if len(self.history) > 0:
-         self.history.pop()
-         self.history.pop()
+         self.history = self.history[:-2]
      if len(chatbot) > 0:
          msg = "删除了一组chatbot对话"
-         chatbot.pop()
+         chatbot = chatbot[:-1]
      if len(self.all_token_counts) > 0:
          msg = "删除了一组对话的token计数记录"
          self.all_token_counts.pop()
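
The pop()-to-slicing change in delete_last_conversation rebinds the name to a new list instead of mutating the one passed in, and slicing is also tolerant of short or empty lists. A quick illustration with hypothetical values:

    chatbot = [("q1", "a1"), ("q2", "a2")]
    trimmed = chatbot[:-1]            # new list; the original object is untouched
    assert trimmed == [("q1", "a1")] and len(chatbot) == 2
    assert [][:-1] == []              # slicing never raises, unlike [].pop()
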
modules/models/models.py CHANGED
@@ -116,9 +116,12 @@ def get_model(
      from .spark import Spark_Client
      model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
          "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
- elif model_type == ModelType.Claude:
+ elif model_type == ModelType.Claude:
      from .Claude import Claude_Client
      model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
+ elif model_type == ModelType.Qwen:
+     from .Qwen import Qwen_Client
+     model = Qwen_Client(model_name, user_name=user_name)
  elif model_type == ModelType.Unknown:
      raise ValueError(f"未知模型: {model_name}")
  logging.info(msg)
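
With the new elif branch, any model name containing "qwen" (after lower-casing, per the model_name_lower check in get_type) routes to Qwen_Client, so both "Qwen 7B" and "Qwen 14B" resolve without extra configuration. A sketch of the expected dispatch, assuming the module layout above:

    from modules.models.base_model import ModelType

    assert ModelType.get_type("Qwen 7B") == ModelType.Qwen    # "qwen" in "qwen 7b"
    assert ModelType.get_type("Qwen 14B") == ModelType.Qwen
    # Unmatched names still fall through to ModelType.LLaMA.
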
modules/presets.py CHANGED
@@ -87,6 +87,8 @@ LOCAL_MODELS = [
      "StableLM",
      "MOSS",
      "Llama-2-7B-Chat",
+     "Qwen 7B",
+     "Qwen 14B"
  ]

  # Additional metadata for local models
@@ -98,6 +100,12 @@ MODEL_METADATA = {
      "Llama-2-7B-Chat":{
          "repo_id": "TheBloke/Llama-2-7b-Chat-GGUF",
          "filelist": ["llama-2-7b-chat.Q6_K.gguf"],
+     },
+     "Qwen 7B": {
+         "repo_id": "Qwen/Qwen-7B-Chat-Int4",
+     },
+     "Qwen 14B": {
+         "repo_id": "Qwen/Qwen-14B-Chat-Int4",
      }
  }

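Qwen_Client resolves its Hugging Face checkpoint through this table, so the display name shown in the UI must match a MODEL_METADATA key exactly. The lookup mirrors Qwen_Client.__init__ (a sketch, omitting resume_download):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = MODEL_METADATA["Qwen 7B"]["repo_id"]    # "Qwen/Qwen-7B-Chat-Int4"
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto", trust_remote_code=True).eval()
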
requirements_advanced.txt CHANGED
@@ -6,3 +6,7 @@ sentence_transformers
  accelerate
  sentencepiece
  llama-cpp-python
+ transformers_stream_generator
+ einops
+ optimum
+ auto-gptq
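
The four additions back the Qwen path: transformers_stream_generator is required by Qwen's chat_stream(), einops by its trust_remote_code modeling files, and optimum plus auto-gptq to load the GPTQ-quantized *-Int4 checkpoints referenced in MODEL_METADATA. The usual install applies:

    pip install -r requirements_advanced.txt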