import os
import time
import warnings
from typing import Dict, List, Optional

import httpx
from smolagents import ApiModel, ChatMessage, Tool


class GeminiApiModel(ApiModel):
    """
    ApiModel implementation using the Google Gemini API via direct HTTP requests.
    """

    def __init__(
        self,
        model_id: str = "gemini-pro",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the GeminiApiModel.

        Args:
            model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
            api_key (str, optional): Google AI Studio API key. Defaults to GEMINI_API_KEY environment variable.
            **kwargs: Additional keyword arguments passed to the parent ApiModel.
        """
        self.model_id = model_id
        # Prefer explicitly passed key, fallback to environment variable
        self.api_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            warnings.warn(
                "GEMINI_API_KEY not provided via argument or environment variable. API calls will likely fail.",
                UserWarning,
            )
        # This implementation does not map chat roles onto Gemini's structured
        # "contents" format; it simply flattens all messages into one text prompt.
        super().__init__(
            model_id=model_id,
            flatten_messages_as_text=True,  # Flatten messages to a single text prompt
            **kwargs,
        )

    def create_client(self):
        """No dedicated client needed as we use httpx directly."""
        return None  # Or potentially return httpx client if reused

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,  # Forwarded via generationConfig.stopSequences
        grammar: Optional[str] = None,  # Not supported by the Gemini REST API; ignored
        tools_to_call_from: Optional[List[Tool]] = None,  # Function calling is not used by this implementation
        **kwargs,
    ) -> ChatMessage:
        """
        Calls the Google Gemini API with the provided messages.

        Args:
            messages: A list of message dictionaries (e.g., [{'role': 'user', 'content': '...'}]).
            stop_sequences: Optional stop sequences, forwarded via generationConfig.
            grammar: Optional grammar constraint (ignored; not supported).
            tools_to_call_from: Optional list of tools (ignored; not used here).
            **kwargs: Additional keyword arguments.

        Returns:
            A ChatMessage object containing the response.
        """
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY is not set.")

        # The basic generateContent endpoint accepts a plain text prompt, so
        # flatten the message history into a single string.
        prompt = self._messages_to_prompt(messages)
        prompt += (
            "\n\n"
            + "If you have a result from a web search that looks helpful, please use httpx to get the HTML from the URL listed.\n"
            + "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use commas in your number or units such as $ or percent signs unless specified otherwise. If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether each element of the list is a number or a string."
        )
        # print(f"--- Gemini API prompt: ---\n{prompt}\n--- End of prompt ---")

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_id}:generateContent?key={self.api_key}"
        headers = {"Content-Type": "application/json"}
        # Construct the payload according to Gemini API requirements
        data = {"contents": [{"parts": [{"text": prompt}]}]}
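        # Optional generation parameters go into a nested "generationConfig"
        # object, e.g. {"contents": [...], "generationConfig": {"temperature": 0.2}}.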

        # Add generation config if provided via kwargs (optional)
        generation_config = {}
        if "temperature" in kwargs:
            generation_config["temperature"] = kwargs["temperature"]
        if "max_output_tokens" in kwargs:
            generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
        # Add other relevant config parameters here if needed

        if generation_config:
            data["generationConfig"] = generation_config

        # The Gemini API supports stop sequences via generationConfig.stopSequences.
        if stop_sequences:
            data.setdefault("generationConfig", {})["stopSequences"] = stop_sequences

        raw_response = None
        try:
            response = httpx.post(url, headers=headers, json=data, timeout=120.0)
            time.sleep(6)  # Fixed delay between calls as crude client-side rate limiting
            response.raise_for_status()
            response_json = response.json()
            raw_response = response_json  # Store raw response

            # Parse the response - adjust based on actual Gemini API structure
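            # Typical successful response shape (abridged):
            # {
            #   "candidates": [{"content": {"parts": [{"text": "..."}], "role": "model"}}],
            #   "usageMetadata": {"promptTokenCount": ..., "candidatesTokenCount": ...}
            # }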
            if "candidates" in response_json and response_json["candidates"]:
                part = response_json["candidates"][0]["content"]["parts"][0]
                if "text" in part:
                    content = part["text"]
                    # Check for "FINAL ANSWER: " and extract the rest of the string
                    final_answer_marker = "FINAL ANSWER: "
                    if final_answer_marker in content:
                        content = content.split(final_answer_marker)[-1].strip()

                    # The generateContent response normally includes usageMetadata
                    # with token counts; fall back to 0 if it is absent.
                    usage = response_json.get("usageMetadata", {})
                    self.last_input_token_count = usage.get("promptTokenCount", 0)
                    self.last_output_token_count = usage.get("candidatesTokenCount", 0)

                    return ChatMessage(
                        role="assistant", content=content, raw=raw_response
                    )

            # Handle cases where the expected response structure isn't found
            error_content = f"Error or unexpected response format: {response_json}"
            return ChatMessage(
                role="assistant", content=error_content, raw=raw_response
            )

        except httpx.RequestError as exc:
            error_content = (
                f"An error occurred while requesting {exc.request.url!r}: {exc}"
            )
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(exc)}
            )
        except httpx.HTTPStatusError as exc:
            error_content = f"Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
            return ChatMessage(
                role="assistant",
                content=error_content,
                raw={
                    "error": str(exc),
                    "status_code": exc.response.status_code,
                    "response_text": exc.response.text,
                },
            )
        except Exception as e:
            error_content = f"An unexpected error occurred: {e}"
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(e)}
            )

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Concatenates message contents into a single string prompt."""
        # Simple concatenation; a role-aware format could be substituted if needed.
        # str() guards against non-string content values.
        return "\n".join(str(msg.get("content", "")) for msg in messages)
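

# Minimal usage sketch (assumes GEMINI_API_KEY is set in the environment and
# that the key has access to the default model id above):
if __name__ == "__main__":
    model = GeminiApiModel()
    reply = model([{"role": "user", "content": "What is 2 + 2?"}])
    print(reply.content)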