import os
import time
import warnings
from typing import Dict, List, Optional

import httpx

from smolagents import ApiModel, ChatMessage, Tool


class GeminiApiModel(ApiModel):
    """
    ApiModel implementation using the Google Gemini API via direct HTTP requests.
    """

    def __init__(
        self,
        model_id: str = "gemini-pro",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the GeminiApiModel.

        Args:
            model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
            api_key (str, optional): Google AI Studio API key. Defaults to the
                GEMINI_API_KEY environment variable.
            **kwargs: Additional keyword arguments passed to the parent ApiModel.
        """
        self.model_id = model_id
        # Prefer an explicitly passed key; fall back to the environment variable.
        self.api_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            warnings.warn(
                "GEMINI_API_KEY not provided via argument or environment variable. API calls will likely fail.",
                UserWarning,
            )
        # The basic generateContent endpoint does not support OpenAI-style role
        # structures or function calling, so messages are flattened to a single
        # text prompt.
        super().__init__(
            model_id=model_id,
            flatten_messages_as_text=True,
            **kwargs,
        )

    def create_client(self):
        """No dedicated client is needed; httpx is called directly per request."""
        return None
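
    # If connection reuse mattered, create_client could instead build and return
    # a shared client for __call__ to use (a sketch, not required by this
    # implementation, which posts with a fresh connection each time):
    #
    #     def create_client(self):
    #         return httpx.Client(timeout=120.0)
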
    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,  # Best-effort support; see generationConfig handling below
        grammar: Optional[str] = None,  # Not supported by the Gemini API
        tools_to_call_from: Optional[List[Tool]] = None,  # Not supported by this basic endpoint
        **kwargs,
    ) -> ChatMessage:
"""
Calls the Google Gemini API with the provided messages.
Args:
messages: A list of message dictionaries (e.g., [{'role': 'user', 'content': '...'}]).
stop_sequences: Optional stop sequences (may not be supported).
grammar: Optional grammar constraint (not supported).
tools_to_call_from: Optional list of tools (not supported).
**kwargs: Additional keyword arguments.
Returns:
A ChatMessage object containing the response.
"""
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY is not set.")
        # The basic generateContent endpoint expects a single text prompt, so the
        # message contents are concatenated.
        prompt = self._messages_to_prompt(messages)
        prompt += (
            "\n\n"
            + "If you have a result from a web search that looks helpful, please use httpx to get the HTML from the URL listed. "
            + "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
        )
        # print(f"--- Gemini API prompt: ---\n{prompt}\n--- End of prompt ---")
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_id}:generateContent?key={self.api_key}"
        headers = {"Content-Type": "application/json"}
        # Construct the payload according to the Gemini API requirements.
        data = {"contents": [{"parts": [{"text": prompt}]}]}
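        # For reference, the serialized request body follows this shape (a sketch
        # of the v1beta generateContent schema; keys match the payload built above
        # and below, values are illustrative):
        #
        # {
        #   "contents": [{"parts": [{"text": "<flattened prompt>"}]}],
        #   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 1024}
        # }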
        # Add generation config if provided via kwargs (optional).
        generation_config = {}
        if "temperature" in kwargs:
            generation_config["temperature"] = kwargs["temperature"]
        if "max_output_tokens" in kwargs:
            generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
        # Add other relevant config parameters here if needed.
        if generation_config:
            data["generationConfig"] = generation_config
        # Handle stop sequences if provided. This is a best-effort addition;
        # check the Gemini API docs for formal 'stopSequences' support.
        if stop_sequences:
            data.setdefault("generationConfig", {})["stopSequences"] = stop_sequences
        raw_response = None
        try:
            response = httpx.post(url, headers=headers, json=data, timeout=120.0)
            time.sleep(6)  # Crude delay to respect free-tier rate limits.
            response.raise_for_status()
            response_json = response.json()
            raw_response = response_json  # Keep the raw response for ChatMessage.raw.
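            # A successful response typically follows this shape (a sketch based
            # on the fields parsed below; other metadata may also be present):
            #
            # {
            #   "candidates": [
            #     {"content": {"parts": [{"text": "..."}], "role": "model"}}
            #   ],
            #   "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 42}
            # }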
            # Parse the response; adjust if the Gemini API structure changes.
            if "candidates" in response_json and response_json["candidates"]:
                part = response_json["candidates"][0]["content"]["parts"][0]
                if "text" in part:
                    content = part["text"]
                    # If a "FINAL ANSWER: " marker is present, keep only what follows it.
                    final_answer_marker = "FINAL ANSWER: "
                    if final_answer_marker in content:
                        content = content.split(final_answer_marker)[-1].strip()
                    # The basic generateContent call does not reliably include
                    # usage data in this response shape, so default the token
                    # counts to 0.
                    self.last_input_token_count = 0
                    self.last_output_token_count = 0
                    # If usage data becomes available in response_json, parse it here:
                    # if response_json.get("usageMetadata"):
                    #     self.last_input_token_count = response_json["usageMetadata"].get("promptTokenCount", 0)
                    #     self.last_output_token_count = response_json["usageMetadata"].get("candidatesTokenCount", 0)
                    return ChatMessage(
                        role="assistant", content=content, raw=raw_response
                    )
            # Handle cases where the expected response structure isn't found.
            error_content = f"Error or unexpected response format: {response_json}"
            return ChatMessage(
                role="assistant", content=error_content, raw=raw_response
            )
        except httpx.RequestError as exc:
            error_content = f"An error occurred while requesting {exc.request.url!r}: {exc}"
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(exc)}
            )
        except httpx.HTTPStatusError as exc:
            error_content = f"Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
            return ChatMessage(
                role="assistant",
                content=error_content,
                raw={
                    "error": str(exc),
                    "status_code": exc.response.status_code,
                    "response_text": exc.response.text,
                },
            )
        except Exception as e:
            error_content = f"An unexpected error occurred: {e}"
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(e)}
            )

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Converts a list of messages into a single string prompt."""
        # Simple concatenation; could be made role-aware if needed. str() guards
        # against 'content' values that are not strings (though they should be).
        return "\n".join([str(msg.get("content", "")) for msg in messages])