Spaces:
Runtime error
Runtime error
langchain-qa-bot
/
docs
/langchain
/libs
/community
/langchain_community
/callbacks
/arize_callback.py
from datetime import datetime | |
from typing import Any, Dict, List, Optional | |
from langchain_core.agents import AgentAction, AgentFinish | |
from langchain_core.callbacks import BaseCallbackHandler | |
from langchain_core.outputs import LLMResult | |
from langchain_community.callbacks.utils import import_pandas | |
class ArizeCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Arize.

    Prompts are recorded in ``on_llm_start``; in ``on_llm_end`` every
    generation is paired with its prompt, both texts are embedded, and one
    single-row DataFrame per generation is sent through the Arize client
    together with the token-usage counters.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        SPACE_KEY: Optional[str] = None,
        API_KEY: Optional[str] = None,
    ) -> None:
        """Initialize callback handler.

        Args:
            model_id: Model identifier forwarded to ``arize_client.log``.
            model_version: Model version forwarded to ``arize_client.log``.
            SPACE_KEY: Arize space key passed to the Arize ``Client``.
            API_KEY: Arize API key passed to the Arize ``Client``.

        Raises:
            ValueError: If either key is still the literal placeholder
                ``"SPACE_KEY"`` / ``"API_KEY"``.
        """
        super().__init__()

        # Fail fast on placeholder credentials, before the embedding model is
        # loaded and before a client is constructed with unusable keys.
        if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
            raise ValueError("❌ CHANGE SPACE AND API KEYS")

        self.model_id = model_id
        self.model_version = model_version
        self.space_key = SPACE_KEY
        self.api_key = API_KEY
        self.prompt_records: List[str] = []
        self.response_records: List[str] = []
        self.prediction_ids: List[str] = []
        self.pred_timestamps: List[int] = []
        self.response_embeddings: List[float] = []
        self.prompt_embeddings: List[float] = []
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        # Index into ``prompt_records`` of the next prompt awaiting a response.
        self.step = 0

        # Imported lazily so the module can be imported without ``arize``.
        from arize.pandas.embeddings import EmbeddingGenerator, UseCases
        from arize.pandas.logger import Client

        self.generator = EmbeddingGenerator.from_use_case(
            use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
            model_name="distilbert-base-uncased",
            tokenizer_max_length=512,
            batch_size=256,
        )
        self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
        print("✅ Arize client setup done! Now you can start using Arize!")  # noqa: T201

    def _embed(self, pd: Any, text: str) -> Any:
        """Embed one text (newlines replaced by spaces) with ``self.generator``.

        ``pd`` is the pandas module returned by ``import_pandas``. The result
        is wrapped in a ``pd.Series``; callers read element ``0`` as the
        vector for ``text``.
        """
        return pd.Series(
            self.generator.generate_embeddings(
                text_col=pd.Series(text.replace("\n", " "))
            ).reset_index(drop=True)
        )

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record every prompt (with newlines removed) for later pairing."""
        for prompt in prompts:
            self.prompt_records.append(prompt.replace("\n", ""))

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Embed each prompt/response pair and log it to Arize.

        Token counters are refreshed from ``response.llm_output["token_usage"]``
        when present and reset to 0 otherwise. One row is logged per
        generation, and the outcome of each log call is printed.
        """
        pd = import_pandas()
        from arize.utils.types import (
            EmbeddingColumnNames,
            Environments,
            ModelTypes,
            Schema,
        )

        # ``llm_output`` may be None and ``token_usage`` may be absent or None;
        # every such case falls back to zero counts.
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        self.prompt_tokens = token_usage.get("prompt_tokens", 0)
        self.total_tokens = token_usage.get("total_tokens", 0)
        self.completion_tokens = token_usage.get("completion_tokens", 0)

        columns = [
            "prediction_ts",
            "response",
            "prompt",
            "response_vector",
            "prompt_vector",
            "prompt_token",
            "completion_token",
            "total_token",
        ]

        # The schema does not depend on the row being logged, so build it once.
        schema = Schema(
            timestamp_column_name="prediction_ts",
            tag_column_names=[
                "prompt_token",
                "completion_token",
                "total_token",
            ],
            prompt_column_names=EmbeddingColumnNames(
                vector_column_name="prompt_vector", data_column_name="prompt"
            ),
            response_column_names=EmbeddingColumnNames(
                vector_column_name="response_vector", data_column_name="response"
            ),
        )

        # Per LangChain's LLMResult contract the outer list has one entry per
        # input prompt and the inner list holds that prompt's candidates, so
        # the prompt cursor advances once per outer entry, not per candidate.
        for generations in response.generations:
            prompt = self.prompt_records[self.step]
            self.step += 1
            if not generations:
                continue
            prompt_embedding = self._embed(pd, prompt)

            for generation in generations:
                response_text = generation.text.replace("\n", " ")
                response_embedding = self._embed(pd, response_text)
                pred_timestamp = datetime.now().timestamp()

                # Value order must mirror ``columns`` exactly.
                row = [
                    pred_timestamp,
                    response_text,
                    prompt,
                    response_embedding[0],
                    prompt_embedding[0],
                    self.prompt_tokens,
                    self.completion_tokens,
                    self.total_tokens,
                ]
                df = pd.DataFrame([row], columns=columns)

                response_from_arize = self.arize_client.log(
                    dataframe=df,
                    schema=schema,
                    model_id=self.model_id,
                    model_version=self.model_version,
                    model_type=ModelTypes.GENERATIVE_LLM,
                    environment=Environments.PRODUCTION,
                )
                if response_from_arize.status_code == 200:
                    print("✅ Successfully logged data to Arize!")  # noqa: T201
                else:
                    print(f'❌ Logging failed "{response_from_arize.text}"')  # noqa: T201

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing."""
        pass

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass