import time
from typing import Any, Dict, List, Optional, cast

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.utils import guard_import


def import_infino() -> Any:
    """Import the infino client."""
    return guard_import("infinopy").InfinoClient()


def import_tiktoken() -> Any:
    """Import tiktoken for counting tokens for OpenAI models."""
    return guard_import("tiktoken")


def get_num_tokens(string: str, openai_model_name: str) -> int:
    """Calculate num tokens for OpenAI with tiktoken package.

    Official documentation: https://github.com/openai/openai-cookbook/blob/main
                            /examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    tiktoken = import_tiktoken()

    encoding = tiktoken.encoding_for_model(openai_model_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
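
# A minimal usage sketch (illustrative; ``tiktoken`` must be installed and the
# model name must be one it recognizes):
#
#   num_tokens = get_num_tokens("How many tokens is this?", "gpt-3.5-turbo")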


class InfinoCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Infino."""

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        # Set the Infino client.
        self.client = import_infino()
        self.model_id = model_id
        self.model_version = model_version
        self.verbose = verbose
        self.is_chat_openai_model = False
        self.chat_openai_model_name = "gpt-3.5-turbo"
        # Initialize here so that on_llm_end does not raise AttributeError if
        # it fires without a preceding start callback; these are overwritten
        # in on_llm_start / on_chat_model_start.
        self.error = 0
        self.start_time = time.time()

    def _send_to_infino(
        self,
        key: str,
        value: Any,
        is_ts: bool = True,
    ) -> None:
        """Send the key-value to Infino.

        Parameters:
        key (str): the key to send to Infino.
        value (Any): the value to send to Infino.
        is_ts (bool): if True, the value is part of a time series, else it
                      is sent as a log message.
        """
        payload = {
            "date": int(time.time()),
            key: value,
            "labels": {
                "model_id": self.model_id,
                "model_version": self.model_version,
            },
        }
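        # For example, tracking request latency produces a payload like
        # {"date": 1700000000, "latency": 0.42, "labels": {...}}
        # (timestamp and value illustrative).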
        if self.verbose:
            print(f"Tracking {key} with Infino: {payload}")  # noqa: T201

        # Append to Infino time series only if is_ts is True, otherwise
        # append to Infino log.
        if is_ts:
            self.client.append_ts(payload)
        else:
            self.client.append_log(payload)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Log the prompts to Infino, and set start time and error flag."""
        for prompt in prompts:
            self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Log the latency, error, token usage, and response to Infino."""
        # Calculate and track the request latency.
        self.end_time = time.time()
        duration = self.end_time - self.start_time
        self._send_to_infino("latency", duration)

        # Track success or error flag.
        self._send_to_infino("error", self.error)

        # Track prompt response.
        for generations in response.generations:
            for generation in generations:
                self._send_to_infino("prompt_response", generation.text, is_ts=False)

        # Track token usage (for non-chat models). Not every provider reports
        # token usage, so fetch it defensively.
        if (response.llm_output is not None) and isinstance(response.llm_output, dict):
            token_usage = response.llm_output.get("token_usage")
            if token_usage is not None:
                prompt_tokens = token_usage["prompt_tokens"]
                total_tokens = token_usage["total_tokens"]
                completion_tokens = token_usage["completion_tokens"]
                self._send_to_infino("prompt_tokens", prompt_tokens)
                self._send_to_infino("total_tokens", total_tokens)
                self._send_to_infino("completion_tokens", completion_tokens)

        # Track completion token usage (for OpenAI chat models). Iterate over
        # all generation batches, not just the last one.
        if self.is_chat_openai_model:
            messages = " ".join(
                cast(str, cast(ChatGeneration, generation).message.content)
                for generations in response.generations
                for generation in generations
            )
            completion_tokens = get_num_tokens(
                messages, openai_model_name=self.chat_openai_model_name
            )
            self._send_to_infino("completion_tokens", completion_tokens)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Set the error flag."""
        self.error = 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing when LLM chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing when LLM chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Need to log the error."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running."""

        # Currently, prompt-token counting for chat models is only supported
        # for ChatOpenAI. ``serialized["id"]`` is a path-like list of
        # identifiers ending in the class name, e.g. [..., "ChatOpenAI"].
        values = serialized.get("id")
        if values and "ChatOpenAI" in values:
            self.is_chat_openai_model = True

        # Track prompt tokens for ChatOpenAI model.
        if self.is_chat_openai_model:
            invocation_params = kwargs.get("invocation_params")
            if invocation_params:
                model_name = invocation_params.get("model_name")
                if model_name:
                    self.chat_openai_model_name = model_name
                    prompt_tokens = 0
                    for message_list in messages:
                        message_string = " ".join(
                            cast(str, msg.content) for msg in message_list
                        )
                        num_tokens = get_num_tokens(
                            message_string,
                            openai_model_name=self.chat_openai_model_name,
                        )
                        prompt_tokens += num_tokens

                    self._send_to_infino("prompt_tokens", prompt_tokens)

        if self.verbose:
            print(  # noqa: T201
                "on_chat_model_start: is_chat_openai_model="
                f"{self.is_chat_openai_model}, "
                f"chat_openai_model_name={self.chat_openai_model_name}"
            )

        # Send the prompt to infino
        prompt = " ".join(
            cast(str, msg.content) for sublist in messages for msg in sublist
        )
        self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()
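

# A minimal usage sketch, assuming a running Infino server reachable by
# ``infinopy`` and the optional ``langchain-openai`` package (both are
# assumptions, not requirements of this module):
#
#   from langchain_openai import ChatOpenAI
#
#   handler = InfinoCallbackHandler(model_id="my-model", model_version="0.1")
#   llm = ChatOpenAI(callbacks=[handler])
#   llm.invoke("What is the capital of France?")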