yym68686 committed on
Commit
cb6cbda
·
1 Parent(s): 0f410b4

✨ Feature: support Gemini API tool use

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. request.py +56 -7
  3. response.py +21 -13
  4. test/test_vertex copy.py +190 -0
README.md CHANGED
@@ -44,7 +44,7 @@ providers:
44
  tools: true
45
 
46
  - provider: gemini
47
- base_url: https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key} # base_url 支持变量替换,{model} 会被替换为模型名称,{stream} 会被替换为 stream 参数,{api_key} 会被替换为 api_key 参数, 仅供 Gemini 模型使用,必填
48
  api: AIzaSyAN2k6IRdgw
49
  model:
50
  - gemini-1.5-pro
 
44
  tools: true
45
 
46
  - provider: gemini
47
+ base_url: https://generativelanguage.googleapis.com/v1beta # base_url 支持 v1beta/v1, 仅供 Gemini 模型使用,必填
48
  api: AIzaSyAN2k6IRdgw
49
  model:
50
  - gemini-1.5-pro
request.py CHANGED
@@ -39,32 +39,70 @@ async def get_gemini_payload(request, engine, provider):
39
  headers = {
40
  'Content-Type': 'application/json'
41
  }
42
- url = provider['base_url']
43
  model = provider['model'][request.model]
44
  if request.stream:
45
  gemini_stream = "streamGenerateContent"
46
- url = url.format(model=model, stream=gemini_stream, api_key=provider['api'])
 
 
 
 
47
 
48
  messages = []
49
  systemInstruction = None
 
50
  for msg in request.messages:
51
  if msg.role == "assistant":
52
  msg.role = "model"
 
53
  if isinstance(msg.content, list):
54
  content = []
55
  for item in msg.content:
56
  if item.type == "text":
57
  text_message = await get_text_message(msg.role, item.text, engine)
58
- # print("text_message", text_message)
59
  content.append(text_message)
60
  elif item.type == "image_url":
61
  image_message = await get_image_message(item.image_url.url, engine)
62
  content.append(image_message)
63
  else:
64
  content = [{"text": msg.content}]
65
- if msg.role != "system":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  messages.append({"role": msg.role, "parts": content})
67
- if msg.role == "system":
68
  systemInstruction = {"parts": content}
69
 
70
 
@@ -96,7 +134,6 @@ async def get_gemini_payload(request, engine, provider):
96
  'model',
97
  'messages',
98
  'stream',
99
- 'tools',
100
  'tool_choice',
101
  'temperature',
102
  'top_p',
@@ -112,7 +149,19 @@ async def get_gemini_payload(request, engine, provider):
112
 
113
  for field, value in request.model_dump(exclude_unset=True).items():
114
  if field not in miss_fields and value is not None:
115
- payload[field] = value
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  return url, headers, payload
118
 
 
39
  headers = {
40
  'Content-Type': 'application/json'
41
  }
 
42
  model = provider['model'][request.model]
43
  if request.stream:
44
  gemini_stream = "streamGenerateContent"
45
+ url = provider['base_url']
46
+ if url.endswith("v1beta"):
47
+ url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['api'])
48
+ if url.endswith("v1"):
49
+ url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['api'])
50
 
51
  messages = []
52
  systemInstruction = None
53
+ function_arguments = None
54
  for msg in request.messages:
55
  if msg.role == "assistant":
56
  msg.role = "model"
57
+ tool_calls = None
58
  if isinstance(msg.content, list):
59
  content = []
60
  for item in msg.content:
61
  if item.type == "text":
62
  text_message = await get_text_message(msg.role, item.text, engine)
 
63
  content.append(text_message)
64
  elif item.type == "image_url":
65
  image_message = await get_image_message(item.image_url.url, engine)
66
  content.append(image_message)
67
  else:
68
  content = [{"text": msg.content}]
69
+ tool_calls = msg.tool_calls
70
+
71
+ if tool_calls:
72
+ tool_call = tool_calls[0]
73
+ function_arguments = {
74
+ "functionCall": {
75
+ "name": tool_call.function.name,
76
+ "args": json.loads(tool_call.function.arguments)
77
+ }
78
+ }
79
+ messages.append(
80
+ {
81
+ "role": "model",
82
+ "parts": [function_arguments]
83
+ }
84
+ )
85
+ elif msg.role == "tool":
86
+ function_call_name = function_arguments["functionCall"]["name"]
87
+ messages.append(
88
+ {
89
+ "role": "function",
90
+ "parts": [{
91
+ "functionResponse": {
92
+ "name": function_call_name,
93
+ "response": {
94
+ "name": function_call_name,
95
+ "content": {
96
+ "result": msg.content,
97
+ }
98
+ }
99
+ }
100
+ }]
101
+ }
102
+ )
103
+ elif msg.role != "system":
104
  messages.append({"role": msg.role, "parts": content})
105
+ elif msg.role == "system":
106
  systemInstruction = {"parts": content}
107
 
108
 
 
134
  'model',
135
  'messages',
136
  'stream',
 
137
  'tool_choice',
138
  'temperature',
139
  'top_p',
 
149
 
150
  for field, value in request.model_dump(exclude_unset=True).items():
151
  if field not in miss_fields and value is not None:
152
+ if field == "tools":
153
+ payload.update({
154
+ "tools": [{
155
+ "function_declarations": [tool["function"] for tool in value]
156
+ }],
157
+ "tool_config": {
158
+ "function_calling_config": {
159
+ "mode": "AUTO"
160
+ }
161
+ }
162
+ })
163
+ else:
164
+ payload[field] = value
165
 
166
  return url, headers, payload
167
 
response.py CHANGED
@@ -25,7 +25,7 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
25
  if function_call_content:
26
  sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"arguments": function_call_content}}]}
27
  if tools_id and function_call_name:
28
- sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
29
  # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
30
  if role:
31
  sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
@@ -48,6 +48,9 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
48
  error_json = error_str
49
  yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
50
  buffer = ""
 
 
 
51
  async for chunk in response.aiter_text():
52
  buffer += chunk
53
  while "\n" in buffer:
@@ -63,18 +66,23 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
63
  except json.JSONDecodeError:
64
  logger.error(f"无法解析JSON: {line}")
65
 
66
- # # 处理缓冲区中剩余的内容
67
- # if buffer:
68
- # # print(buffer)
69
- # if '\"text\": \"' in buffer:
70
- # try:
71
- # json_data = json.loads(buffer)
72
- # content = json_data.get('text', '')
73
- # content = "\n".join(content.split("\\n"))
74
- # sse_string = await generate_sse_response(timestamp, model, content)
75
- # yield sse_string
76
- # except json.JSONDecodeError:
77
- # print(f"无法解析JSON: {buffer}")
 
 
 
 
 
78
 
79
  async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
80
  redirect_count = 0
 
25
  if function_call_content:
26
  sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"arguments": function_call_content}}]}
27
  if tools_id and function_call_name:
28
+ sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id": tools_id,"type":"function","function":{"name": function_call_name, "arguments":""}}]}
29
  # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
30
  if role:
31
  sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
 
48
  error_json = error_str
49
  yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
50
  buffer = ""
51
+ revicing_function_call = False
52
+ function_full_response = "{"
53
+ need_function_call = False
54
  async for chunk in response.aiter_text():
55
  buffer += chunk
56
  while "\n" in buffer:
 
66
  except json.JSONDecodeError:
67
  logger.error(f"无法解析JSON: {line}")
68
 
69
+ if line and ('\"functionCall\": {' in line or revicing_function_call):
70
+ revicing_function_call = True
71
+ need_function_call = True
72
+ if ']' in line:
73
+ revicing_function_call = False
74
+ continue
75
+
76
+ function_full_response += line
77
+
78
+ if need_function_call:
79
+ function_call = json.loads(function_full_response)
80
+ function_call_name = function_call["functionCall"]["name"]
81
+ sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
82
+ yield sse_string
83
+ function_full_response = json.dumps(function_call["functionCall"]["args"])
84
+ sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
85
+ yield sse_string
86
 
87
  async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
88
  redirect_count = 0
test/test_vertex copy.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import base64
3
+ import time
4
+ import httpx
5
+ from cryptography.hazmat.primitives import hashes
6
+ from cryptography.hazmat.primitives.asymmetric import padding
7
+ from cryptography.hazmat.primitives.serialization import load_pem_private_key
8
+
9
+ # 您的服务账号密钥(请将其保存在安全的地方,不要公开分享)
10
def create_jwt(client_email, private_key):
    """Build and sign a Google OAuth2 service-account JWT assertion (RS256).

    Args:
        client_email: service account email, used as the ``iss`` claim.
        private_key: PEM-encoded RSA private key string from the key file.

    Returns:
        The compact-serialized JWT (``header.payload.signature``) as ``str``,
        valid for one hour from creation.
    """
    b64 = base64.urlsafe_b64encode

    # Fixed JOSE header: RSA-SHA256 signature over a JWT.
    header_bytes = json.dumps({
        "alg": "RS256",
        "typ": "JWT"
    }).encode()

    issued_at = int(time.time())
    claims_bytes = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,  # 1-hour lifetime, the Google-allowed maximum
        "iat": issued_at
    }).encode()

    # JWT uses unpadded base64url, hence the '=' strip.
    parts = [
        b64(header_bytes).rstrip(b'='),
        b64(claims_bytes).rstrip(b'='),
    ]

    # Sign "header.payload" with the service account's RSA key.
    key = load_pem_private_key(private_key.encode(), password=None)
    signature = key.sign(
        b'.'.join(parts),
        padding.PKCS1v15(),
        hashes.SHA256(),
    )
    parts.append(b64(signature).rstrip(b'='))

    return b'.'.join(parts).decode()
44
+
45
def get_access_token(client_email, private_key):
    """Exchange a signed service-account JWT for a short-lived OAuth2 access token.

    Args:
        client_email: service account email for the JWT ``iss`` claim.
        private_key: PEM-encoded RSA private key string.

    Returns:
        The ``access_token`` string issued by Google's token endpoint.

    Raises:
        httpx.HTTPStatusError: if the token endpoint returns an error status.
    """
    assertion = create_jwt(client_email, private_key)

    # RFC 7523 JWT-bearer grant against Google's OAuth2 token endpoint.
    form = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion,
    }
    with httpx.Client() as client:
        response = client.post(
            "https://oauth2.googleapis.com/token",
            data=form,
            headers={'Content-Type': "application/x-www-form-urlencoded"},
        )
        response.raise_for_status()

    return response.json()["access_token"]
59
+
60
def ask_stream(prompt, client_email, private_key, project_id, engine):
    """Send a single-turn chat request to a Gemini model on Vertex AI.

    NOTE(review): despite the name, this is a blocking, non-streaming
    ``generateContent`` call (the local was misleadingly named ``stream``
    in an earlier revision); it returns the fully parsed JSON response.

    Args:
        prompt: the user message text.
        client_email: service account email used to mint the access token.
        private_key: PEM-encoded RSA private key for the service account.
        project_id: Google Cloud project that hosts the Vertex AI endpoint.
        engine: Gemini model id, e.g. ``"gemini-1.5-pro"``.

    Returns:
        The decoded JSON body of the ``generateContent`` response.

    Raises:
        httpx.HTTPStatusError: if the Vertex AI endpoint returns an error status.
    """
    payload = {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "text": prompt
                    }
                ]
            }
        ],
        "system_instruction": {
            "parts": [
                {
                    "text": "You are Gemini, a large language model trained by Google. Respond conversationally"
                }
            ]
        },
        "generationConfig": {
            "temperature": 0.5,
            "max_output_tokens": 256,
            "top_k": 40,
            "top_p": 0.95
        },
        # Two sample function declarations so the model may emit functionCall
        # parts; "AUTO" lets the model decide between text and tool use.
        "tools": [
            {
                "function_declarations": [
                    {
                        "name": "get_search_results",
                        "description": "Search Google to enhance knowledge.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "prompt": {
                                    "type": "string",
                                    "description": "The prompt to search."
                                }
                            },
                            "required": [
                                "prompt"
                            ]
                        }
                    },
                    {
                        "name": "get_url_content",
                        "description": "Get the webpage content of a URL",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "url": {
                                    "type": "string",
                                    "description": "the URL to request"
                                }
                            },
                            "required": [
                                "url"
                            ]
                        }
                    }
                ]
            }
        ],
        "tool_config": {
            "function_calling_config": {
                "mode": "AUTO"
            }
        }
    }

    access_token = get_access_token(client_email, private_key)
    headers = {
        'Authorization': f"Bearer {access_token}",
        'Content-Type': "application/json"
    }

    # Non-streaming REST method on the regional Vertex AI endpoint.
    method = "generateContent"
    url = (
        f"https://us-central1-aiplatform.googleapis.com/v1/projects/{project_id}"
        f"/locations/us-central1/publishers/google/models/{engine}:{method}"
    )
    with httpx.Client() as client:
        response = client.post(
            url,
            json=payload,
            headers=headers,
            timeout=600,  # generous: model responses can be slow
        )
        response.raise_for_status()
        return response.json()
184
+
185
# Usage example. Guarded so that importing this module no longer blocks on
# input() or crashes; previously these statements ran at import time.
if __name__ == "__main__":
    # NOTE(review): SERVICE_ACCOUNT_KEY is not defined anywhere in this file —
    # it was presumably redacted. Load the service-account key JSON into
    # SERVICE_ACCOUNT_KEY (e.g. json.load of the downloaded key file) before
    # running this script, or it raises NameError here.
    client_email = SERVICE_ACCOUNT_KEY["client_email"]
    private_key = SERVICE_ACCOUNT_KEY["private_key"]
    project_id = SERVICE_ACCOUNT_KEY["project_id"]
    engine = "gemini-1.5-pro"
    user_input = input("请输入您的问题: ")
    result = ask_stream(user_input, client_email, private_key, project_id, engine)
    print(json.dumps(result, ensure_ascii=False, indent=2))