Xudong Liu
committed: Add support for 文心一言 (ERNIE Bot), covering its three main models. (#931)
Files changed:
- config_example.json +2 -0
- modules/config.py +5 -0
- modules/models/ERNIE.py +96 -0
- modules/models/base_model.py +3 -0
- modules/models/models.py +3 -0
- modules/presets.py +16 -1
config_example.json
CHANGED
@@ -15,6 +15,8 @@
     "spark_api_key": "", // Your iFLYTEK Spark (讯飞星火大模型) API Key, used for the Spark chat model
     "spark_api_secret": "", // Your iFLYTEK Spark (讯飞星火大模型) API Secret, used for the Spark chat model
     "claude_api_secret":"",// Your Claude API Secret, used for the Claude chat model
+    "ernie_api_key": "",// Your ERNIE Bot (文心一言) API Key from Baidu Cloud, used for the ERNIE chat model
+    "ernie_secret_key": "",// Your ERNIE Bot (文心一言) Secret Key from Baidu Cloud, used for the ERNIE chat model


     //== Azure ==
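The two new keys sit alongside the other provider credentials. Below is a minimal sketch of reading and validating them from a comment-free config.json (the file contents and the check are hypothetical; the application's own loader in modules/config.py is what actually consumes these keys):

import json

# Assumes a config.json without the // comments shown in the example file above.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

ernie_api_key = config.get("ernie_api_key", "")
ernie_secret_key = config.get("ernie_secret_key", "")
if not (ernie_api_key and ernie_secret_key):
    print("ERNIE credentials not set; the ERNIE-Bot models will be unavailable.")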
modules/config.py
CHANGED
@@ -135,6 +135,11 @@ os.environ["SPARK_API_SECRET"] = spark_api_secret
 claude_api_secret = config.get("claude_api_secret", "")
 os.environ["CLAUDE_API_SECRET"] = claude_api_secret

+ernie_api_key = config.get("ernie_api_key", "")
+os.environ["ERNIE_APIKEY"] = ernie_api_key
+ernie_secret_key = config.get("ernie_secret_key", "")
+os.environ["ERNIE_SECRETKEY"] = ernie_secret_key
+
 load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
                        "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
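As with the other providers, the values read from the config dictionary are exported as environment variables, and the model factory later reads them back by name. A minimal sketch of that hand-off in isolation (the config dictionary and its values here are hypothetical):

import os

config = {"ernie_api_key": "ak-example", "ernie_secret_key": "sk-example"}  # hypothetical values

# Export, mirroring the two new os.environ assignments above.
os.environ["ERNIE_APIKEY"] = config.get("ernie_api_key", "")
os.environ["ERNIE_SECRETKEY"] = config.get("ernie_secret_key", "")

# Read back, as modules/models/models.py does when constructing ERNIE_Client.
assert os.getenv("ERNIE_APIKEY") == "ak-example"
assert os.getenv("ERNIE_SECRETKEY") == "sk-example"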
modules/models/ERNIE.py
ADDED
@@ -0,0 +1,96 @@
+import json
+
+import requests
+
+from ..presets import *
+from ..utils import *
+
+from .base_model import BaseLLMModel
+
+
+class ERNIE_Client(BaseLLMModel):
+    def __init__(self, model_name, api_key, secret_key) -> None:
+        super().__init__(model_name=model_name)
+        self.api_key = api_key
+        self.api_secret = secret_key
+        if None in [self.api_secret, self.api_key]:
+            raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")
+
+        # Each supported model name maps to its own chat endpoint; the access token is appended later.
+        if self.model_name == "ERNIE-Bot-turbo":
+            self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token="
+        elif self.model_name == "ERNIE-Bot":
+            self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
+        elif self.model_name == "ERNIE-Bot-4":
+            self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token="
+
+    def get_access_token(self):
+        """
+        Generate an authentication token (access token) from the API Key (AK) and Secret Key (SK).
+        :return: access_token, or None on error
+        """
+        url = "https://aip.baidubce.com/oauth/2.0/token?client_id=" + self.api_key + "&client_secret=" + self.api_secret + "&grant_type=client_credentials"
+
+        payload = json.dumps("")
+        headers = {
+            'Content-Type': 'application/json',
+            'Accept': 'application/json'
+        }
+
+        response = requests.request("POST", url, headers=headers, data=payload)
+
+        return response.json()["access_token"]
+
+    def get_answer_stream_iter(self):
+        url = self.ERNIE_url + self.get_access_token()
+        system_prompt = self.system_prompt
+        history = self.history
+        if system_prompt is not None:
+            history = [construct_system(system_prompt), *history]
+
+        # Remove any messages whose role is "system" from the history.
+        history = [i for i in history if i["role"] != "system"]
+
+        payload = json.dumps({
+            "messages": history,
+            "stream": True
+        })
+        headers = {
+            'Content-Type': 'application/json'
+        }
+
+        response = requests.request("POST", url, headers=headers, data=payload, stream=True)
+
+        if response.status_code == 200:
+            partial_text = ""
+            for line in response.iter_lines():
+                if len(line) == 0:
+                    continue
+                # Each server-sent event line looks like b'data: {...}'; strip the prefix and parse the JSON.
+                line = json.loads(line[5:])
+                partial_text += line['result']
+                yield partial_text
+        else:
+            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
+
+    def get_answer_at_once(self):
+        url = self.ERNIE_url + self.get_access_token()
+        system_prompt = self.system_prompt
+        history = self.history
+        if system_prompt is not None:
+            history = [construct_system(system_prompt), *history]
+
+        # Remove any messages whose role is "system" from the history.
+        history = [i for i in history if i["role"] != "system"]
+
+        payload = json.dumps({
+            "messages": history,
+            "stream": False  # single-shot request, so the full JSON body can be parsed below
+        })
+        headers = {
+            'Content-Type': 'application/json'
+        }
+
+        response = requests.request("POST", url, headers=headers, data=payload)
+
+        if response.status_code == 200:
+            result = response.json()["result"]
+            return str(result), len(result)
+        else:
+            return "获取资源错误", 0
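The class boils down to two HTTP calls against Baidu's endpoints: an OAuth token exchange, then a chat completion. Here is a standalone sketch of the same flow using only requests (placeholder credentials; a non-streaming call to the ERNIE-Bot-turbo endpoint, mirroring get_access_token and get_answer_at_once):

import json
import requests

API_KEY = "your-api-key"        # hypothetical placeholder
SECRET_KEY = "your-secret-key"  # hypothetical placeholder

# Step 1: exchange the API Key / Secret Key for an OAuth access token.
token_resp = requests.post(
    "https://aip.baidubce.com/oauth/2.0/token",
    params={
        "grant_type": "client_credentials",
        "client_id": API_KEY,
        "client_secret": SECRET_KEY,
    },
)
access_token = token_resp.json()["access_token"]

# Step 2: call the chat endpoint with the conversation history.
chat_resp = requests.post(
    "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant",
    params={"access_token": access_token},
    headers={"Content-Type": "application/json"},
    data=json.dumps({"messages": [{"role": "user", "content": "你好"}]}),
)
print(chat_resp.json().get("result"))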
modules/models/base_model.py
CHANGED
@@ -148,6 +148,7 @@ class ModelType(Enum):
     Claude = 14
     Qwen = 15
     OpenAIVision = 16
+    ERNIE = 17

     @classmethod
     def get_type(cls, model_name: str):
@@ -188,6 +189,8 @@ class ModelType(Enum):
             model_type = ModelType.Claude
         elif "qwen" in model_name_lower:
             model_type = ModelType.Qwen
+        elif "ernie" in model_name_lower:
+            model_type = ModelType.ERNIE
         else:
             model_type = ModelType.LLaMA
         return model_type
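The dispatch added in the second hunk reduces to a simple substring check on the lower-cased model name. A stripped-down sketch (the enum below keeps only two members with illustrative values, not the full ModelType):

from enum import Enum


class ModelType(Enum):
    LLaMA = 1   # illustrative value
    ERNIE = 17


def get_type(model_name: str) -> ModelType:
    # Same substring dispatch as above: any name containing "ernie" routes to the new ERNIE type.
    return ModelType.ERNIE if "ernie" in model_name.lower() else ModelType.LLaMA


assert get_type("ERNIE-Bot-turbo") is ModelType.ERNIE
assert get_type("ERNIE-Bot-4") is ModelType.ERNIE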
modules/models/models.py
CHANGED
@@ -128,6 +128,9 @@ def get_model(
     elif model_type == ModelType.Qwen:
         from .Qwen import Qwen_Client
         model = Qwen_Client(model_name, user_name=user_name)
+    elif model_type == ModelType.ERNIE:
+        from .ERNIE import ERNIE_Client
+        model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"), secret_key=os.getenv("ERNIE_SECRETKEY"))
     elif model_type == ModelType.Unknown:
         raise ValueError(f"未知模型: {model_name}")
     logging.info(msg)
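Note that ERNIE_Client receives its credentials through os.getenv rather than through get_model's arguments, so modules/config.py must have exported ERNIE_APIKEY and ERNIE_SECRETKEY before this branch runs; otherwise both constructor arguments are None and ERNIE_Client raises. A small precondition check one could add ahead of the call (a hypothetical helper, not part of the commit):

import os


def ernie_credentials_present() -> bool:
    """Return True only when both ERNIE environment variables have been exported."""
    return bool(os.getenv("ERNIE_APIKEY")) and bool(os.getenv("ERNIE_SECRETKEY"))


if not ernie_credentials_present():
    raise RuntimeError("Set ernie_api_key and ernie_secret_key in config.json before selecting an ERNIE model.")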
modules/presets.py
CHANGED
@@ -74,7 +74,10 @@ ONLINE_MODELS = [
     "讯飞星火大模型V3.0",
     "讯飞星火大模型V2.0",
     "讯飞星火大模型V1.5",
-    "Claude"
+    "Claude",
+    "ERNIE-Bot-turbo",
+    "ERNIE-Bot",
+    "ERNIE-Bot-4",
 ]

 LOCAL_MODELS = [
@@ -146,6 +149,18 @@ MODEL_METADATA = {
         "model_name": "Claude",
         "token_limit": 4096,
     },
+    "ERNIE-Bot-turbo": {
+        "model_name": "ERNIE-Bot-turbo",
+        "token_limit": 1024,
+    },
+    "ERNIE-Bot": {
+        "model_name": "ERNIE-Bot",
+        "token_limit": 1024,
+    },
+    "ERNIE-Bot-4": {
+        "model_name": "ERNIE-Bot-4",
+        "token_limit": 1024,
+    },
 }

 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
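Each ERNIE entry is registered with a token_limit of 1024, which the application can consult when deciding how much history to send per request. A generic standalone sketch of that kind of lookup (the dictionary copies only the new entries; count_tokens is a hypothetical character-count stand-in, not the app's real tokenizer):

MODEL_METADATA = {
    "ERNIE-Bot-turbo": {"model_name": "ERNIE-Bot-turbo", "token_limit": 1024},
    "ERNIE-Bot": {"model_name": "ERNIE-Bot", "token_limit": 1024},
    "ERNIE-Bot-4": {"model_name": "ERNIE-Bot-4", "token_limit": 1024},
}


def count_tokens(message: dict) -> int:
    # Hypothetical stand-in: a crude character-based estimate.
    return len(message.get("content", ""))


def trim_history(history: list[dict], model_name: str) -> list[dict]:
    """Keep the most recent messages that fit within the model's token_limit."""
    limit = MODEL_METADATA[model_name]["token_limit"]
    kept, used = [], 0
    for message in reversed(history):
        used += count_tokens(message)
        if used > limit:
            break
        kept.append(message)
    return list(reversed(kept))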