Spaces:
Runtime error
Runtime error
langchain-qa-bot
/
docs
/langchain
/libs
/community
/langchain_community
/callbacks
/flyte_callback.py
"""FlyteKit callback handler.""" | |
from __future__ import annotations | |
import logging | |
from copy import deepcopy | |
from typing import TYPE_CHECKING, Any, Dict, List, Tuple | |
from langchain_core.agents import AgentAction, AgentFinish | |
from langchain_core.callbacks import BaseCallbackHandler | |
from langchain_core.outputs import LLMResult | |
from langchain_core.utils import guard_import | |
from langchain_community.callbacks.utils import ( | |
BaseMetadataCallbackHandler, | |
flatten_dict, | |
import_pandas, | |
import_spacy, | |
import_textstat, | |
) | |
if TYPE_CHECKING: | |
import flytekit | |
from flytekitplugins.deck import renderer | |
logger = logging.getLogger(__name__) | |
def import_flytekit() -> Tuple[flytekit, renderer]:
    """Load flytekit and the deck renderer module on demand.

    Both packages are resolved through ``guard_import`` so that they are only
    needed when this callback is actually used.
    NOTE(review): ``guard_import`` presumably raises ``ImportError`` with an
    install hint when a package is missing -- confirm against langchain_core.

    Returns:
        A 2-tuple ``(flytekit, renderer)`` where ``renderer`` is the
        ``renderer`` attribute of the ``flytekitplugins.deck`` package
        (pip name ``flytekitplugins-deck-standard``).
    """
    flytekit_module = guard_import("flytekit")
    deck_module = guard_import(
        "flytekitplugins.deck", pip_name="flytekitplugins-deck-standard"
    )
    return flytekit_module, deck_module.renderer
def analyze_text(
    text: str,
    nlp: Any = None,
    textstat: Any = None,
) -> dict:
    """Compute readability metrics and spaCy visualizations for a text.

    Parameters:
        text (str): The text to analyze.
        nlp (spacy.lang): The spacy language model to use for visualization.
            It is called on ``text`` and the result is handed to
            ``spacy.displacy.render``. Visualizations are skipped when None.
        textstat: Object exposing the textstat metric functions (normally the
            ``textstat`` module). Metrics are skipped when None.

    Returns:
        (dict): When ``textstat`` is given, the key
            ``"text_complexity_metrics"`` holds a dict of all metrics and
            every metric is additionally present as a top-level key. When
            ``nlp`` is given, ``"dependency_tree"`` and ``"entities"`` hold
            the displacy renderings serialized to HTML string. Empty when
            both are None.
    """
    result: Dict[str, Any] = {}

    if textstat is not None:
        # Each name is both the textstat function to call and the result key.
        metric_names = (
            "flesch_reading_ease",
            "flesch_kincaid_grade",
            "smog_index",
            "coleman_liau_index",
            "automated_readability_index",
            "dale_chall_readability_score",
            "difficult_words",
            "linsear_write_formula",
            "gunning_fog",
            "fernandez_huerta",
            "szigriszt_pazos",
            "gutierrez_polini",
            "crawford",
            "gulpease_index",
            "osman",
        )
        metrics = {name: getattr(textstat, name)(text) for name in metric_names}
        # Expose the metrics both grouped and flattened.
        result["text_complexity_metrics"] = metrics
        result.update(metrics)

    if nlp is None:
        return result

    spacy = import_spacy()
    parsed = nlp(text)
    for key, style in (("dependency_tree", "dep"), ("entities", "ent")):
        result[key] = spacy.displacy.render(
            parsed, style=style, jupyter=False, page=True
        )
    return result
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that is used within a Flyte task.

    Every handled event is rendered to HTML and appended to a Flyte ``Deck``
    titled "LangChain Metrics". The counters incremented by the hooks
    (``step``, ``starts``, ``ends``, ``errors``, ...) and
    ``get_custom_callback_meta`` are not defined in this class.
    NOTE(review): they presumably come from ``BaseMetadataCallbackHandler``
    -- confirm.
    """

    def __init__(self) -> None:
        """Initialize callback handler.

        flytekit, the deck renderer plugin and pandas are imported
        unconditionally. textstat and spacy are optional: when importing
        either raises ``ImportError`` a warning is logged and the
        corresponding attribute (``self.textstat`` / ``self.nlp``) stays None.
        """
        flytekit, renderer = import_flytekit()
        self.pandas = import_pandas()

        self.textstat = None
        try:
            self.textstat = import_textstat()
        except ImportError:
            # Implicit concatenation keeps source indentation out of the
            # logged message (a backslash continuation would embed it).
            logger.warning(
                "Textstat library is not installed. "
                "It may result in the inability to log "
                "certain metrics that can be captured with Textstat."
            )

        spacy = None
        try:
            spacy = import_spacy()
        except ImportError:
            logger.warning(
                "Spacy library is not installed. "
                "It may result in the inability to log "
                "certain metrics that can be captured with Spacy."
            )

        super().__init__()

        self.nlp = None
        if spacy:
            try:
                self.nlp = spacy.load("en_core_web_sm")
            except OSError:
                # spacy is installed but the model could not be loaded.
                logger.warning(
                    "FlyteCallbackHandler uses spacy's en_core_web_sm model"
                    " for certain metrics. To download,"
                    " run the following command in your terminal:"
                    " `python -m spacy download en_core_web_sm`"
                )

        # Renderer classes are stored and instantiated once per rendering.
        self.table_renderer = renderer.TableRenderer
        self.markdown_renderer = renderer.MarkdownRenderer
        self.deck = flytekit.Deck(
            "LangChain Metrics",
            self.markdown_renderer().to_html("## LangChain Metrics"),
        )

    def _append_heading(self, heading: str) -> None:
        """Render a markdown snippet and append it to the deck."""
        self.deck.append(self.markdown_renderer().to_html(heading))

    def _append_table(self, heading: str, row: Dict[str, Any]) -> None:
        """Append a markdown heading followed by a one-row table of ``row``."""
        self._append_heading(heading)
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([row])) + "\n"
        )

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_llm_start"}
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())
        resp["prompts"] = list(prompts)

        self._append_table("### LLM Start", resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        Appends a summary table for the call and then, for every generation,
        whichever analysis sections are available: text complexity metrics
        (needs textstat) and dependency tree / entities (need the spaCy
        model). When neither is available the generated text is appended
        instead.
        """
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_llm_end"}
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        self._append_table("### LLM End", resp)

        for generations in response.generations:
            for generation in generations:
                if not (self.nlp or self.textstat):
                    self._append_heading("#### Generated Response")
                    self._append_heading(generation.text)
                    continue

                analysis = analyze_text(
                    generation.text, nlp=self.nlp, textstat=self.textstat
                )

                # analyze_text only returns the sections whose backing
                # library was supplied, so each one is looked up defensively;
                # indexing unconditionally raised KeyError when just one of
                # spaCy / textstat was installed.
                complexity_metrics = analysis.get("text_complexity_metrics")
                if complexity_metrics is not None:
                    self._append_table(
                        "#### Text Complexity Metrics", complexity_metrics
                    )

                dependency_tree = analysis.get("dependency_tree")
                if dependency_tree is not None:
                    self._append_heading("#### Dependency Tree")
                    self.deck.append(dependency_tree)

                entities = analysis.get("entities")
                if entities is not None:
                    self._append_heading("#### Entities")
                    self.deck.append(entities)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_chain_start"}
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())
        # Inputs are collapsed into a single "key=value,..." string.
        resp["inputs"] = ",".join([f"{k}={v}" for k, v in inputs.items()])

        self._append_table("### Chain Start", resp)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        # Outputs are collapsed into a single "key=value,..." string.
        chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
        resp: Dict[str, Any] = {"action": "on_chain_end", "outputs": chain_output}
        resp.update(self.get_custom_callback_meta())

        self._append_table("### Chain End", resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_tool_start", "input_str": input_str}
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self._append_table("### Tool Start", resp)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_tool_end", "output": output}
        resp.update(self.get_custom_callback_meta())

        self._append_table("### Tool End", resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""
        self.step += 1
        self.text_ctr += 1

        resp: Dict[str, Any] = {"action": "on_text", "text": text}
        resp.update(self.get_custom_callback_meta())

        self._append_table("### On Text", resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_finish",
            "output": finish.return_values["output"],
            "log": finish.log,
        }
        resp.update(self.get_custom_callback_meta())

        self._append_table("### Agent Finish", resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        # An agent action is counted as a tool start.
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_action",
            "tool": action.tool,
            "tool_input": action.tool_input,
            "log": action.log,
        }
        resp.update(self.get_custom_callback_meta())

        self._append_table("### Agent Action", resp)