Xudong Liu commited on
Commit
e17e77b
·
unverified ·
1 Parent(s): cd2e998

增加对文心一言的支持,支持文心一言的三个主要模型。 (#931)

Browse files
config_example.json CHANGED
@@ -15,6 +15,8 @@
15
  "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
16
  "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
17
  "claude_api_secret":"",// 你的 Claude API Secret,用于 Claude 对话模型
 
 
18
 
19
 
20
  //== Azure ==
 
15
  "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
16
  "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
17
  "claude_api_secret":"",// 你的 Claude API Secret,用于 Claude 对话模型
18
+ "ernie_api_key": "",// 你的文心一言在百度云中的API Key,用于文心一言对话模型
19
+ "ernie_secret_key": "",// 你的文心一言在百度云中的Secret Key,用于文心一言对话模型
20
 
21
 
22
  //== Azure ==
modules/config.py CHANGED
@@ -135,6 +135,11 @@ os.environ["SPARK_API_SECRET"] = spark_api_secret
135
  claude_api_secret = config.get("claude_api_secret", "")
136
  os.environ["CLAUDE_API_SECRET"] = claude_api_secret
137
 
 
 
 
 
 
138
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
139
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
140
 
 
135
  claude_api_secret = config.get("claude_api_secret", "")
136
  os.environ["CLAUDE_API_SECRET"] = claude_api_secret
137
 
138
+ ernie_api_key = config.get("ernie_api_key", "")
139
+ os.environ["ERNIE_APIKEY"] = ernie_api_key
140
+ ernie_secret_key = config.get("ernie_secret_key", "")
141
+ os.environ["ERNIE_SECRETKEY"] = ernie_secret_key
142
+
143
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
144
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
145
 
modules/models/ERNIE.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..presets import *
2
+ from ..utils import *
3
+
4
+ from .base_model import BaseLLMModel
5
+
6
+
7
+ class ERNIE_Client(BaseLLMModel):
8
+ def __init__(self, model_name, api_key, secret_key) -> None:
9
+ super().__init__(model_name=model_name)
10
+ self.api_key = api_key
11
+ self.api_secret = secret_key
12
+ if None in [self.api_secret, self.api_key]:
13
+ raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")
14
+
15
+ if self.model_name == "ERNIE-Bot-turbo":
16
+ self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token="
17
+ elif self.model_name == "ERNIE-Bot":
18
+ self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
19
+ elif self.model_name == "ERNIE-Bot-4":
20
+ self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token="
21
+
22
+ def get_access_token(self):
23
+ """
24
+ 使用 AK,SK 生成鉴权签名(Access Token)
25
+ :return: access_token,或是None(如果错误)
26
+ """
27
+ url = "https://aip.baidubce.com/oauth/2.0/token?client_id=" + self.api_key + "&client_secret=" + self.api_secret + "&grant_type=client_credentials"
28
+
29
+ payload = json.dumps("")
30
+ headers = {
31
+ 'Content-Type': 'application/json',
32
+ 'Accept': 'application/json'
33
+ }
34
+
35
+ response = requests.request("POST", url, headers=headers, data=payload)
36
+
37
+ return response.json()["access_token"]
38
+ def get_answer_stream_iter(self):
39
+ url = self.ERNIE_url + self.get_access_token()
40
+ system_prompt = self.system_prompt
41
+ history = self.history
42
+ if system_prompt is not None:
43
+ history = [construct_system(system_prompt), *history]
44
+
45
+ # 去除history中 history的role为system的
46
+ history = [i for i in history if i["role"] != "system"]
47
+
48
+ payload = json.dumps({
49
+ "messages":history,
50
+ "stream": True
51
+ })
52
+ headers = {
53
+ 'Content-Type': 'application/json'
54
+ }
55
+
56
+ response = requests.request("POST", url, headers=headers, data=payload, stream=True)
57
+
58
+ if response.status_code == 200:
59
+ partial_text = ""
60
+ for line in response.iter_lines():
61
+ if len(line) == 0:
62
+ continue
63
+ line = json.loads(line[5:])
64
+ partial_text += line['result']
65
+ yield partial_text
66
+ else:
67
+ yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
68
+
69
+
70
+ def get_answer_at_once(self):
71
+ url = self.ERNIE_url + self.get_access_token()
72
+ system_prompt = self.system_prompt
73
+ history = self.history
74
+ if system_prompt is not None:
75
+ history = [construct_system(system_prompt), *history]
76
+
77
+ # 去除history中 history的role为system的
78
+ history = [i for i in history if i["role"] != "system"]
79
+
80
+ payload = json.dumps({
81
+ "messages": history,
82
+ "stream": True
83
+ })
84
+ headers = {
85
+ 'Content-Type': 'application/json'
86
+ }
87
+
88
+ response = requests.request("POST", url, headers=headers, data=payload, stream=True)
89
+
90
+ if response.status_code == 200:
91
+
92
+ return str(response.json()["result"]),len(response.json()["result"])
93
+ else:
94
+ return "获取资源错误", 0
95
+
96
+
modules/models/base_model.py CHANGED
@@ -148,6 +148,7 @@ class ModelType(Enum):
148
  Claude = 14
149
  Qwen = 15
150
  OpenAIVision = 16
 
151
 
152
  @classmethod
153
  def get_type(cls, model_name: str):
@@ -188,6 +189,8 @@ class ModelType(Enum):
188
  model_type = ModelType.Claude
189
  elif "qwen" in model_name_lower:
190
  model_type = ModelType.Qwen
 
 
191
  else:
192
  model_type = ModelType.LLaMA
193
  return model_type
 
148
  Claude = 14
149
  Qwen = 15
150
  OpenAIVision = 16
151
+ ERNIE = 17
152
 
153
  @classmethod
154
  def get_type(cls, model_name: str):
 
189
  model_type = ModelType.Claude
190
  elif "qwen" in model_name_lower:
191
  model_type = ModelType.Qwen
192
+ elif "ernie" in model_name_lower:
193
+ model_type = ModelType.ERNIE
194
  else:
195
  model_type = ModelType.LLaMA
196
  return model_type
modules/models/models.py CHANGED
@@ -128,6 +128,9 @@ def get_model(
128
  elif model_type == ModelType.Qwen:
129
  from .Qwen import Qwen_Client
130
  model = Qwen_Client(model_name, user_name=user_name)
 
 
 
131
  elif model_type == ModelType.Unknown:
132
  raise ValueError(f"未知模型: {model_name}")
133
  logging.info(msg)
 
128
  elif model_type == ModelType.Qwen:
129
  from .Qwen import Qwen_Client
130
  model = Qwen_Client(model_name, user_name=user_name)
131
+ elif model_type == ModelType.ERNIE:
132
+ from .ERNIE import ERNIE_Client
133
+ model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
134
  elif model_type == ModelType.Unknown:
135
  raise ValueError(f"未知模型: {model_name}")
136
  logging.info(msg)
modules/presets.py CHANGED
@@ -74,7 +74,10 @@ ONLINE_MODELS = [
74
  "讯飞星火大模型V3.0",
75
  "讯飞星火大模型V2.0",
76
  "讯飞星火大模型V1.5",
77
- "Claude"
 
 
 
78
  ]
79
 
80
  LOCAL_MODELS = [
@@ -146,6 +149,18 @@ MODEL_METADATA = {
146
  "model_name": "Claude",
147
  "token_limit": 4096,
148
  },
 
 
 
 
 
 
 
 
 
 
 
 
149
  }
150
 
151
  if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
 
74
  "讯飞星火大模型V3.0",
75
  "讯飞星火大模型V2.0",
76
  "讯飞星火大模型V1.5",
77
+ "Claude",
78
+ "ERNIE-Bot-turbo",
79
+ "ERNIE-Bot",
80
+ "ERNIE-Bot-4",
81
  ]
82
 
83
  LOCAL_MODELS = [
 
149
  "model_name": "Claude",
150
  "token_limit": 4096,
151
  },
152
+ "ERNIE-Bot-turbo": {
153
+ "model_name": "ERNIE-Bot-turbo",
154
+ "token_limit": 1024,
155
+ },
156
+ "ERNIE-Bot": {
157
+ "model_name": "ERNIE-Bot",
158
+ "token_limit": 1024,
159
+ },
160
+ "ERNIE-Bot-4": {
161
+ "model_name": "ERNIE-Bot-4",
162
+ "token_limit": 1024,
163
+ },
164
  }
165
 
166
  if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':