|
import os |
|
import time |
|
|
|
import httpx |
|
import warnings |
|
from typing import List, Dict, Optional |
|
from smolagents import ApiModel, ChatMessage |
|
|
|
|
|
class GeminiApiModel(ApiModel):
    """
    ApiModel implementation using the Google Gemini API via direct HTTP requests.

    Talks to the ``generateContent`` REST endpoint with ``httpx`` instead of a
    dedicated SDK, so no persistent client object is created.
    """

    def __init__(
        self,
        model_id: str = "gemini-pro",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the GeminiApiModel.

        Args:
            model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
            api_key (str, optional): Google AI Studio API key. Defaults to GEMINI_API_KEY environment variable.
            **kwargs: Additional keyword arguments passed to the parent ApiModel.
        """
        self.model_id = model_id
        self.api_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            # Warn rather than raise so the object can still be constructed
            # (e.g. for configuration); __call__ raises if the key is missing.
            warnings.warn(
                "GEMINI_API_KEY not provided via argument or environment variable. API calls will likely fail.",
                UserWarning,
            )

        super().__init__(
            model_id=model_id,
            flatten_messages_as_text=True,
            **kwargs,
        )

    def create_client(self):
        """No dedicated client needed as we use httpx directly."""
        return None

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List["Tool"]] = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Calls the Google Gemini API with the provided messages.

        Args:
            messages: A list of message dictionaries (e.g., [{'role': 'user', 'content': '...'}]).
            stop_sequences: Optional stop sequences (forwarded as generationConfig.stopSequences).
            grammar: Optional grammar constraint (not supported).
            tools_to_call_from: Optional list of tools (not supported).
            **kwargs: Additional keyword arguments; "temperature" and
                "max_output_tokens" are forwarded to the generation config.

        Returns:
            A ChatMessage object containing the response, or an error description
            in ``content`` if the request or response parsing failed.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY is not set.")

        prompt = self._messages_to_prompt(messages)
        # Instruction suffix appended to every request (kept verbatim).
        prompt += (
            "\n\n"
            + "If you have a result from a web search that looks helpful, please use httpx to get the HTML from the URL listed."
            + "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
        )

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_id}:generateContent?key={self.api_key}"
        headers = {"Content-Type": "application/json"}
        data = {"contents": [{"parts": [{"text": prompt}]}]}

        # Build the generation config in one place (previously split across
        # two blocks with a redundant membership re-check).
        generation_config = {}
        if "temperature" in kwargs:
            generation_config["temperature"] = kwargs["temperature"]
        if "max_output_tokens" in kwargs:
            generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
        if stop_sequences:
            generation_config["stopSequences"] = stop_sequences
        if generation_config:
            data["generationConfig"] = generation_config

        raw_response = None
        try:
            response = httpx.post(
                url, headers=headers, json=data, timeout=120.0
            )
            # Crude client-side rate limiting between consecutive API calls.
            time.sleep(6)
            response.raise_for_status()
            response_json = response.json()
            raw_response = response_json

            candidates = response_json.get("candidates")
            if candidates:
                # Guarded access: a malformed candidate falls through to the
                # format-error message below instead of raising KeyError.
                # BUGFIX: previously, a part without a "text" key still reached
                # the success return with `content` unbound, producing a
                # NameError that was swallowed by the generic handler.
                parts = candidates[0].get("content", {}).get("parts", [])
                part = parts[0] if parts else {}
                if "text" in part:
                    content = part["text"]

                    # If the model followed the FINAL ANSWER template, keep
                    # only the final answer text.
                    final_answer_marker = "FINAL ANSWER: "
                    if final_answer_marker in content:
                        content = content.split(final_answer_marker)[-1].strip()

                    # usageMetadata is not parsed yet; counts are placeholders.
                    self.last_input_token_count = 0
                    self.last_output_token_count = 0

                    return ChatMessage(
                        role="assistant", content=content, raw=raw_response
                    )

            error_content = f"Error or unexpected response format: {response_json}"
            return ChatMessage(
                role="assistant", content=error_content, raw=raw_response
            )

        except httpx.RequestError as exc:
            # Transport-level failure (DNS, connection, timeout, ...).
            error_content = (
                f"An error occurred while requesting {exc.request.url!r}: {exc}"
            )
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(exc)}
            )
        except httpx.HTTPStatusError as exc:
            # Non-2xx response from the API.
            error_content = f"Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
            return ChatMessage(
                role="assistant",
                content=error_content,
                raw={
                    "error": str(exc),
                    "status_code": exc.response.status_code,
                    "response_text": exc.response.text,
                },
            )
        except Exception as e:
            # Last-resort guard so a parsing surprise still yields a ChatMessage.
            error_content = f"An unexpected error occurred: {e}"
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(e)}
            )

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Converts a list of messages into a single string prompt."""
        return "\n".join([str(msg.get("content", "")) for msg in messages])
|
|