# Final_Assignment_Template / gemini_model.py
import os
import time
import httpx
import warnings
from typing import List, Dict, Optional
from smolagents import ApiModel, ChatMessage, Tool


class GeminiApiModel(ApiModel):
"""
ApiModel implementation using the Google Gemini API via direct HTTP requests.
"""
def __init__(
self,
model_id: str = "gemini-pro",
api_key: Optional[str] = None,
**kwargs,
):
"""
Initializes the GeminiApiModel.
Args:
model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
api_key (str, optional): Google AI Studio API key. Defaults to GEMINI_API_KEY environment variable.
**kwargs: Additional keyword arguments passed to the parent ApiModel.
"""
self.model_id = model_id
# Prefer explicitly passed key, fallback to environment variable
        self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
if not self.api_key:
warnings.warn(
"GEMINI_API_KEY not provided via argument or environment variable. API calls will likely fail.",
UserWarning,
)
        # For simplicity, flatten the conversation into a single text prompt
        # rather than mapping roles onto Gemini's structured "contents" format.
super().__init__(
model_id=model_id,
flatten_messages_as_text=True, # Flatten messages to a single text prompt
**kwargs,
)
def create_client(self):
"""No dedicated client needed as we use httpx directly."""
        return None  # We call httpx directly; a shared client could be returned instead (see sketch below).
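    # If connection reuse mattered, create_client could instead return a
    # persistent client. A sketch, using the same base URL and timeout that
    # __call__ uses below:
    #
    #     def create_client(self):
    #         return httpx.Client(
    #             base_url="https://generativelanguage.googleapis.com",
    #             timeout=120.0,
    #         )
    #
    # __call__ would then use self.client.post(...) rather than httpx.post(...).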
def __call__(
self,
messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,  # Forwarded to generationConfig.stopSequences
        grammar: Optional[str] = None,  # Not supported by the Gemini API; ignored
        tools_to_call_from: Optional[List[Tool]] = None,  # Tool calling is not implemented in this wrapper; ignored
**kwargs,
) -> ChatMessage:
"""
Calls the Google Gemini API with the provided messages.
Args:
messages: A list of message dictionaries (e.g., [{'role': 'user', 'content': '...'}]).
            stop_sequences: Optional stop sequences, forwarded to generationConfig.
            grammar: Optional grammar constraint (not supported; ignored).
            tools_to_call_from: Optional list of tools (not implemented; ignored).
**kwargs: Additional keyword arguments.
Returns:
A ChatMessage object containing the response.
"""
if not self.api_key:
raise ValueError("GEMINI_API_KEY is not set.")
        # The basic generateContent endpoint takes a single text prompt, so
        # flatten the message list into one string and append the answer-format
        # instructions.
        prompt = self._messages_to_prompt(messages)
        prompt += (
            "\n\n"
            "If you have a result from a web search that looks helpful, "
            "please use httpx to get the HTML from the URL listed.\n"
            "You are a general AI assistant. I will ask you a question. "
            "Report your thoughts, and finish your answer with the following "
            "template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER "
            "should be a number OR as few words as possible OR a comma "
            "separated list of numbers and/or strings. If you are asked for "
            "a number, don't use commas to write your number nor units such "
            "as $ or percent signs unless specified otherwise. If you are "
            "asked for a string, don't use articles or abbreviations (e.g. "
            "for cities), and write digits in plain text unless specified "
            "otherwise. If you are asked for a comma separated list, apply "
            "the above rules depending on whether each element of the list "
            "is a number or a string."
        )
# print(f"--- Gemini API prompt: ---\n{prompt}\n--- End of prompt ---")
url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_id}:generateContent?key={self.api_key}"
headers = {"Content-Type": "application/json"}
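        # Note: the API also accepts the key in an "x-goog-api-key" request
        # header, which keeps it out of logged URLs; the query-string form is
        # used here for simplicity.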
# Construct the payload according to Gemini API requirements
data = {"contents": [{"parts": [{"text": prompt}]}]}
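        # For reference, a successful generateContent response has roughly
        # this shape (only the fields parsed below are shown):
        #
        #     {
        #       "candidates": [
        #         {"content": {"parts": [{"text": "..."}], "role": "model"}}
        #       ],
        #       "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}
        #     }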
# Add generation config if provided via kwargs (optional)
generation_config = {}
if "temperature" in kwargs:
generation_config["temperature"] = kwargs["temperature"]
if "max_output_tokens" in kwargs:
generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
# Add other relevant config parameters here if needed
if generation_config:
data["generationConfig"] = generation_config
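        # e.g. model(messages, temperature=0.2, max_output_tokens=1024) yields
        # {"generationConfig": {"temperature": 0.2, "maxOutputTokens": 1024}}.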
        # Stop sequences are supported by the API via generationConfig.stopSequences.
        if stop_sequences:
            data.setdefault("generationConfig", {})["stopSequences"] = stop_sequences
raw_response = None
try:
            response = httpx.post(url, headers=headers, json=data, timeout=120.0)
            time.sleep(6)  # Crude throttle to stay within free-tier rate limits
response.raise_for_status()
response_json = response.json()
raw_response = response_json # Store raw response
# Parse the response - adjust based on actual Gemini API structure
if "candidates" in response_json and response_json["candidates"]:
part = response_json["candidates"][0]["content"]["parts"][0]
if "text" in part:
content = part["text"]
# Check for "FINAL ANSWER: " and extract the rest of the string
final_answer_marker = "FINAL ANSWER: "
if final_answer_marker in content:
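                        # e.g. "...reasoning...\nFINAL ANSWER: 42" -> "42"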
content = content.split(final_answer_marker)[-1].strip()
                    # Parse token usage from usageMetadata when the API returns
                    # it; fall back to 0 if the block is absent.
                    usage = response_json.get("usageMetadata", {})
                    self.last_input_token_count = usage.get("promptTokenCount", 0)
                    self.last_output_token_count = usage.get("candidatesTokenCount", 0)
return ChatMessage(
role="assistant", content=content, raw=raw_response
)
# Handle cases where the expected response structure isn't found
error_content = f"Error or unexpected response format: {response_json}"
return ChatMessage(
role="assistant", content=error_content, raw=raw_response
)
except httpx.RequestError as exc:
error_content = (
f"An error occurred while requesting {exc.request.url!r}: {exc}"
)
return ChatMessage(
role="assistant", content=error_content, raw={"error": str(exc)}
)
except httpx.HTTPStatusError as exc:
error_content = f"Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
return ChatMessage(
role="assistant",
content=error_content,
raw={
"error": str(exc),
"status_code": exc.response.status_code,
"response_text": exc.response.text,
},
)
except Exception as e:
error_content = f"An unexpected error occurred: {e}"
return ChatMessage(
role="assistant", content=error_content, raw={"error": str(e)}
)
    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Converts a list of messages into a single text prompt."""
        # Handles both plain strings and lists of {"type": "text", ...} parts.
        lines = []
        for msg in messages:
            content = msg.get("content", "")
            if isinstance(content, list):
                content = " ".join(str(p.get("text", "")) for p in content)
            lines.append(str(content))
        return "\n".join(lines)
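

# A minimal usage sketch: running this module directly performs one real
# generateContent call (assumes GEMINI_API_KEY is set in the environment).
if __name__ == "__main__":
    model = GeminiApiModel()
    reply = model([{"role": "user", "content": "What is 2 + 2?"}])
    print(reply.content)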