✨ Feature: support Gemini API tool use
Browse files- README.md +1 -1
- request.py +56 -7
- response.py +21 -13
- test/test_vertex copy.py +190 -0
README.md
CHANGED
@@ -44,7 +44,7 @@ providers:
|
|
44 |
tools: true
|
45 |
|
46 |
- provider: gemini
|
47 |
-
base_url: https://generativelanguage.googleapis.com/
|
48 |
api: YOUR_GEMINI_API_KEY
|
49 |
model:
|
50 |
- gemini-1.5-pro
|
|
|
44 |
tools: true
|
45 |
|
46 |
- provider: gemini
|
47 |
+
base_url: https://generativelanguage.googleapis.com/v1beta # base_url 支持 v1beta/v1, 仅供 Gemini 模型使用,必填
|
48 |
api: YOUR_GEMINI_API_KEY
|
49 |
model:
|
50 |
- gemini-1.5-pro
|
request.py
CHANGED
@@ -39,32 +39,70 @@ async def get_gemini_payload(request, engine, provider):
|
|
39 |
headers = {
|
40 |
'Content-Type': 'application/json'
|
41 |
}
|
42 |
-
url = provider['base_url']
|
43 |
model = provider['model'][request.model]
|
44 |
if request.stream:
|
45 |
gemini_stream = "streamGenerateContent"
|
46 |
-
url =
|
|
|
|
|
|
|
|
|
47 |
|
48 |
messages = []
|
49 |
systemInstruction = None
|
|
|
50 |
for msg in request.messages:
|
51 |
if msg.role == "assistant":
|
52 |
msg.role = "model"
|
|
|
53 |
if isinstance(msg.content, list):
|
54 |
content = []
|
55 |
for item in msg.content:
|
56 |
if item.type == "text":
|
57 |
text_message = await get_text_message(msg.role, item.text, engine)
|
58 |
-
# print("text_message", text_message)
|
59 |
content.append(text_message)
|
60 |
elif item.type == "image_url":
|
61 |
image_message = await get_image_message(item.image_url.url, engine)
|
62 |
content.append(image_message)
|
63 |
else:
|
64 |
content = [{"text": msg.content}]
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
messages.append({"role": msg.role, "parts": content})
|
67 |
-
|
68 |
systemInstruction = {"parts": content}
|
69 |
|
70 |
|
@@ -96,7 +134,6 @@ async def get_gemini_payload(request, engine, provider):
|
|
96 |
'model',
|
97 |
'messages',
|
98 |
'stream',
|
99 |
-
'tools',
|
100 |
'tool_choice',
|
101 |
'temperature',
|
102 |
'top_p',
|
@@ -112,7 +149,19 @@ async def get_gemini_payload(request, engine, provider):
|
|
112 |
|
113 |
for field, value in request.model_dump(exclude_unset=True).items():
|
114 |
if field not in miss_fields and value is not None:
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
return url, headers, payload
|
118 |
|
|
|
39 |
headers = {
|
40 |
'Content-Type': 'application/json'
|
41 |
}
|
|
|
42 |
model = provider['model'][request.model]
|
43 |
if request.stream:
|
44 |
gemini_stream = "streamGenerateContent"
|
45 |
+
url = provider['base_url']
|
46 |
+
if url.endswith("v1beta"):
|
47 |
+
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['api'])
|
48 |
+
if url.endswith("v1"):
|
49 |
+
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['api'])
|
50 |
|
51 |
messages = []
|
52 |
systemInstruction = None
|
53 |
+
function_arguments = None
|
54 |
for msg in request.messages:
|
55 |
if msg.role == "assistant":
|
56 |
msg.role = "model"
|
57 |
+
tool_calls = None
|
58 |
if isinstance(msg.content, list):
|
59 |
content = []
|
60 |
for item in msg.content:
|
61 |
if item.type == "text":
|
62 |
text_message = await get_text_message(msg.role, item.text, engine)
|
|
|
63 |
content.append(text_message)
|
64 |
elif item.type == "image_url":
|
65 |
image_message = await get_image_message(item.image_url.url, engine)
|
66 |
content.append(image_message)
|
67 |
else:
|
68 |
content = [{"text": msg.content}]
|
69 |
+
tool_calls = msg.tool_calls
|
70 |
+
|
71 |
+
if tool_calls:
|
72 |
+
tool_call = tool_calls[0]
|
73 |
+
function_arguments = {
|
74 |
+
"functionCall": {
|
75 |
+
"name": tool_call.function.name,
|
76 |
+
"args": json.loads(tool_call.function.arguments)
|
77 |
+
}
|
78 |
+
}
|
79 |
+
messages.append(
|
80 |
+
{
|
81 |
+
"role": "model",
|
82 |
+
"parts": [function_arguments]
|
83 |
+
}
|
84 |
+
)
|
85 |
+
elif msg.role == "tool":
|
86 |
+
function_call_name = function_arguments["functionCall"]["name"]
|
87 |
+
messages.append(
|
88 |
+
{
|
89 |
+
"role": "function",
|
90 |
+
"parts": [{
|
91 |
+
"functionResponse": {
|
92 |
+
"name": function_call_name,
|
93 |
+
"response": {
|
94 |
+
"name": function_call_name,
|
95 |
+
"content": {
|
96 |
+
"result": msg.content,
|
97 |
+
}
|
98 |
+
}
|
99 |
+
}
|
100 |
+
}]
|
101 |
+
}
|
102 |
+
)
|
103 |
+
elif msg.role != "system":
|
104 |
messages.append({"role": msg.role, "parts": content})
|
105 |
+
elif msg.role == "system":
|
106 |
systemInstruction = {"parts": content}
|
107 |
|
108 |
|
|
|
134 |
'model',
|
135 |
'messages',
|
136 |
'stream',
|
|
|
137 |
'tool_choice',
|
138 |
'temperature',
|
139 |
'top_p',
|
|
|
149 |
|
150 |
for field, value in request.model_dump(exclude_unset=True).items():
|
151 |
if field not in miss_fields and value is not None:
|
152 |
+
if field == "tools":
|
153 |
+
payload.update({
|
154 |
+
"tools": [{
|
155 |
+
"function_declarations": [tool["function"] for tool in value]
|
156 |
+
}],
|
157 |
+
"tool_config": {
|
158 |
+
"function_calling_config": {
|
159 |
+
"mode": "AUTO"
|
160 |
+
}
|
161 |
+
}
|
162 |
+
})
|
163 |
+
else:
|
164 |
+
payload[field] = value
|
165 |
|
166 |
return url, headers, payload
|
167 |
|
response.py
CHANGED
@@ -25,7 +25,7 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
25 |
if function_call_content:
|
26 |
sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"arguments": function_call_content}}]}
|
27 |
if tools_id and function_call_name:
|
28 |
-
sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
|
29 |
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
|
30 |
if role:
|
31 |
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
@@ -48,6 +48,9 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
48 |
error_json = error_str
|
49 |
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
|
50 |
buffer = ""
|
|
|
|
|
|
|
51 |
async for chunk in response.aiter_text():
|
52 |
buffer += chunk
|
53 |
while "\n" in buffer:
|
@@ -63,18 +66,23 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
63 |
except json.JSONDecodeError:
|
64 |
logger.error(f"无法解析JSON: {line}")
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
|
80 |
redirect_count = 0
|
|
|
25 |
if function_call_content:
|
26 |
sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"arguments": function_call_content}}]}
|
27 |
if tools_id and function_call_name:
|
28 |
+
sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id": tools_id,"type":"function","function":{"name": function_call_name, "arguments":""}}]}
|
29 |
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
|
30 |
if role:
|
31 |
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
|
|
48 |
error_json = error_str
|
49 |
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
|
50 |
buffer = ""
|
51 |
+
revicing_function_call = False
|
52 |
+
function_full_response = "{"
|
53 |
+
need_function_call = False
|
54 |
async for chunk in response.aiter_text():
|
55 |
buffer += chunk
|
56 |
while "\n" in buffer:
|
|
|
66 |
except json.JSONDecodeError:
|
67 |
logger.error(f"无法解析JSON: {line}")
|
68 |
|
69 |
+
if line and ('\"functionCall\": {' in line or revicing_function_call):
|
70 |
+
revicing_function_call = True
|
71 |
+
need_function_call = True
|
72 |
+
if ']' in line:
|
73 |
+
revicing_function_call = False
|
74 |
+
continue
|
75 |
+
|
76 |
+
function_full_response += line
|
77 |
+
|
78 |
+
if need_function_call:
|
79 |
+
function_call = json.loads(function_full_response)
|
80 |
+
function_call_name = function_call["functionCall"]["name"]
|
81 |
+
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
|
82 |
+
yield sse_string
|
83 |
+
function_full_response = json.dumps(function_call["functionCall"]["args"])
|
84 |
+
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
|
85 |
+
yield sse_string
|
86 |
|
87 |
async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
|
88 |
redirect_count = 0
|
test/test_vertex copy.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64
import json
import os
import time

import httpx
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.serialization import load_pem_private_key
8 |
+
|
9 |
+
# Your service-account key (keep it private; never commit or share it publicly)
|
def create_jwt(client_email, private_key):
    """Build and sign an RS256 JWT for Google's OAuth2 token endpoint.

    Args:
        client_email: Service-account ``client_email`` (used as the issuer).
        private_key: PEM-encoded RSA private key string from the key file.

    Returns:
        The compact serialized JWT (``header.payload.signature``) as a str.
    """
    def _b64url(raw: bytes) -> bytes:
        # URL-safe base64 without '=' padding, as the JWT spec requires.
        return base64.urlsafe_b64encode(raw).rstrip(b'=')

    issued_at = int(time.time())
    header_bytes = json.dumps({
        "alg": "RS256",
        "typ": "JWT"
    }).encode()
    claims_bytes = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,  # token assertion valid for one hour
        "iat": issued_at
    }).encode()

    segments = [_b64url(header_bytes), _b64url(claims_bytes)]

    # Sign "header.payload" with the service account's RSA key (PKCS#1 v1.5 / SHA-256).
    signing_input = b'.'.join(segments)
    rsa_key = load_pem_private_key(private_key.encode(), password=None)
    signature = rsa_key.sign(
        signing_input,
        padding.PKCS1v15(),
        hashes.SHA256()
    )

    segments.append(_b64url(signature))
    return b'.'.join(segments).decode()
44 |
+
|
def get_access_token(client_email, private_key):
    """Exchange a signed service-account JWT for an OAuth2 access token.

    Raises:
        httpx.HTTPStatusError: if the token endpoint returns an error status.
    """
    assertion = create_jwt(client_email, private_key)

    token_request = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion
    }
    with httpx.Client() as http:
        token_response = http.post(
            "https://oauth2.googleapis.com/token",
            data=token_request,
            headers={'Content-Type': "application/x-www-form-urlencoded"}
        )
        token_response.raise_for_status()
        return token_response.json()["access_token"]
59 |
+
|
60 |
+
def ask_stream(prompt, client_email, private_key, project_id, engine):
|
61 |
+
payload = {
|
62 |
+
"contents": [
|
63 |
+
{
|
64 |
+
"role": "user",
|
65 |
+
"parts": [
|
66 |
+
{
|
67 |
+
"text": prompt
|
68 |
+
}
|
69 |
+
]
|
70 |
+
}
|
71 |
+
],
|
72 |
+
"system_instruction": {
|
73 |
+
"parts": [
|
74 |
+
{
|
75 |
+
"text": "You are Gemini, a large language model trained by Google. Respond conversationally"
|
76 |
+
}
|
77 |
+
]
|
78 |
+
},
|
79 |
+
# "safety_settings": [
|
80 |
+
# {
|
81 |
+
# "category": "HARM_CATEGORY_HARASSMENT",
|
82 |
+
# "threshold": "BLOCK_NONE"
|
83 |
+
# },
|
84 |
+
# {
|
85 |
+
# "category": "HARM_CATEGORY_HATE_SPEECH",
|
86 |
+
# "threshold": "BLOCK_NONE"
|
87 |
+
# },
|
88 |
+
# {
|
89 |
+
# "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
90 |
+
# "threshold": "BLOCK_NONE"
|
91 |
+
# },
|
92 |
+
# {
|
93 |
+
# "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
94 |
+
# "threshold": "BLOCK_NONE"
|
95 |
+
# }
|
96 |
+
# ],
|
97 |
+
"generationConfig": {
|
98 |
+
"temperature": 0.5,
|
99 |
+
"max_output_tokens": 256,
|
100 |
+
"top_k": 40,
|
101 |
+
"top_p": 0.95
|
102 |
+
},
|
103 |
+
"tools": [
|
104 |
+
{
|
105 |
+
"function_declarations": [
|
106 |
+
{
|
107 |
+
"name": "get_search_results",
|
108 |
+
"description": "Search Google to enhance knowledge.",
|
109 |
+
"parameters": {
|
110 |
+
"type": "object",
|
111 |
+
"properties": {
|
112 |
+
"prompt": {
|
113 |
+
"type": "string",
|
114 |
+
"description": "The prompt to search."
|
115 |
+
}
|
116 |
+
},
|
117 |
+
"required": [
|
118 |
+
"prompt"
|
119 |
+
]
|
120 |
+
}
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"name": "get_url_content",
|
124 |
+
"description": "Get the webpage content of a URL",
|
125 |
+
"parameters": {
|
126 |
+
"type": "object",
|
127 |
+
"properties": {
|
128 |
+
"url": {
|
129 |
+
"type": "string",
|
130 |
+
"description": "the URL to request"
|
131 |
+
}
|
132 |
+
},
|
133 |
+
"required": [
|
134 |
+
"url"
|
135 |
+
]
|
136 |
+
}
|
137 |
+
}
|
138 |
+
]
|
139 |
+
}
|
140 |
+
],
|
141 |
+
"tool_config": {
|
142 |
+
"function_calling_config": {
|
143 |
+
"mode": "AUTO"
|
144 |
+
}
|
145 |
+
}
|
146 |
+
}
|
147 |
+
# payload = {
|
148 |
+
# "contents": [
|
149 |
+
# {
|
150 |
+
# "role": "user",
|
151 |
+
# "parts": [
|
152 |
+
# {
|
153 |
+
# "text": prompt
|
154 |
+
# }
|
155 |
+
# ]
|
156 |
+
# },
|
157 |
+
# ],
|
158 |
+
# "generationConfig": {
|
159 |
+
# "temperature": 0.2,
|
160 |
+
# "maxOutputTokens": 256,
|
161 |
+
# "topK": 40,
|
162 |
+
# "topP": 0.95
|
163 |
+
# }
|
164 |
+
# }
|
165 |
+
|
166 |
+
access_token = get_access_token(client_email, private_key)
|
167 |
+
headers = {
|
168 |
+
'Authorization': f"Bearer {access_token}",
|
169 |
+
'Content-Type': "application/json"
|
170 |
+
}
|
171 |
+
|
172 |
+
MODEL_ID = engine
|
173 |
+
PROJECT_ID = project_id
|
174 |
+
stream = "generateContent"
|
175 |
+
with httpx.Client() as client:
|
176 |
+
response = client.post(
|
177 |
+
f"https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}",
|
178 |
+
json=payload,
|
179 |
+
headers=headers,
|
180 |
+
timeout=600,
|
181 |
+
)
|
182 |
+
response.raise_for_status()
|
183 |
+
return response.json()
|
184 |
+
|
# Usage example.
#
# Fix: the original referenced SERVICE_ACCOUNT_KEY, which is never defined in
# this file (the hard-coded key was removed), so the script crashed with a
# NameError. The key is now loaded from a JSON file whose path comes from the
# GOOGLE_SERVICE_ACCOUNT_KEY environment variable (default:
# "service_account.json"), so real credentials never live in the repository.
# The __main__ guard keeps the prompt/network code from running on import.
if __name__ == "__main__":
    key_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY", "service_account.json")
    with open(key_path, encoding="utf-8") as key_file:
        SERVICE_ACCOUNT_KEY = json.load(key_file)
    client_email = SERVICE_ACCOUNT_KEY["client_email"]
    private_key = SERVICE_ACCOUNT_KEY["private_key"]
    project_id = SERVICE_ACCOUNT_KEY["project_id"]
    engine = "gemini-1.5-pro"
    user_input = input("请输入您的问题: ")
    result = ask_stream(user_input, client_email, private_key, project_id, engine)
    print(json.dumps(result, ensure_ascii=False, indent=2))