# Source code for langchain_community.callbacks.flyte_callback

"""FlyteKit callback handler."""

from __future__ import annotations

import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import

from langchain_community.callbacks.utils import (
    BaseMetadataCallbackHandler,
    flatten_dict,
    import_pandas,
    import_spacy,
    import_textstat,
)

if TYPE_CHECKING:
    import flytekit
    from flytekitplugins.deck import renderer

logger = logging.getLogger(__name__)


def import_flytekit() -> Tuple[flytekit, renderer]:
    """Import flytekit and flytekitplugins-deck-standard.

    Both packages are loaded through ``guard_import``. ``flytekit`` is
    imported first, then ``flytekitplugins.deck`` (installed from the
    ``flytekitplugins-deck-standard`` distribution).

    Returns:
        A ``(flytekit, renderer)`` pair: the ``flytekit`` module and the
        ``renderer`` attribute of ``flytekitplugins.deck``.
    """
    flytekit_module = guard_import("flytekit")
    deck_plugin = guard_import(
        "flytekitplugins.deck", pip_name="flytekitplugins-deck-standard"
    )
    return flytekit_module, deck_plugin.renderer
def analyze_text(
    text: str,
    nlp: Any = None,
    textstat: Any = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        nlp (spacy.lang): Optional spacy language model. When given, the
            text is parsed and rendered with displacy.
        textstat: Optional textstat module. When given, readability
            metrics are computed for the text.

    Returns:
        (dict): Possibly empty.
            With ``textstat``, every metric is stored twice: once nested
            under ``"text_complexity_metrics"`` and once as a top-level key.
            With ``nlp``, ``"dependency_tree"`` and ``"entities"`` hold the
            displacy visualizations serialized to HTML strings.
    """
    # Each key is also the name of the textstat function that computes it.
    # The tuple order fixes the key order, which the callers use as table
    # column order.
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    result: Dict[str, Any] = {}

    if textstat is not None:
        metrics = {name: getattr(textstat, name)(text) for name in metric_names}
        result["text_complexity_metrics"] = metrics
        result.update(metrics)

    if nlp is not None:
        spacy = import_spacy()
        parsed = nlp(text)
        dependency_html = spacy.displacy.render(
            parsed, style="dep", jupyter=False, page=True
        )
        entities_html = spacy.displacy.render(
            parsed, style="ent", jupyter=False, page=True
        )
        result.update({"dependency_tree": dependency_html, "entities": entities_html})

    return result
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that is used within a Flyte task.

    Every handled callback event is rendered to HTML (markdown headings and
    one-row pandas tables) and appended to a single ``flytekit.Deck`` named
    "LangChain Metrics".

    The counters mutated here (``step``, ``starts``, ``ends``, ``errors``,
    ...) and ``get_custom_callback_meta`` are not defined in this class.
    They are presumably provided by ``BaseMetadataCallbackHandler`` (set up
    by ``super().__init__()``); confirm against that base class.
    """

    def __init__(self) -> None:
        """Initialize callback handler.

        flytekit, the deck renderers and pandas are required. textstat and
        spacy (with the ``en_core_web_sm`` model) are optional: if one is
        missing, a warning is logged and the corresponding metrics are
        skipped.
        """
        flytekit, renderer = import_flytekit()
        self.pandas = import_pandas()

        self.textstat = None
        try:
            self.textstat = import_textstat()
        except ImportError:
            logger.warning(
                "Textstat library is not installed. "
                "It may result in the inability to log "
                "certain metrics that can be captured with Textstat."
            )

        spacy = None
        try:
            spacy = import_spacy()
        except ImportError:
            logger.warning(
                "Spacy library is not installed. "
                "It may result in the inability to log "
                "certain metrics that can be captured with Spacy."
            )

        super().__init__()

        self.nlp = None
        if spacy:
            try:
                self.nlp = spacy.load("en_core_web_sm")
            except OSError:
                logger.warning(
                    "FlyteCallbackHandler uses spacy's en_core_web_sm model"
                    " for certain metrics. To download,"
                    " run the following command in your terminal:"
                    " `python -m spacy download en_core_web_sm`"
                )

        self.table_renderer = renderer.TableRenderer
        self.markdown_renderer = renderer.MarkdownRenderer

        self.deck = flytekit.Deck(
            "LangChain Metrics",
            self.markdown_renderer().to_html("## LangChain Metrics"),
        )

    def _append_markdown(self, markdown: str) -> None:
        """Render ``markdown`` to HTML and append it to the deck."""
        self.deck.append(self.markdown_renderer().to_html(markdown))

    def _append_table(self, record: Dict[str, Any], newline: bool = True) -> None:
        """Render ``record`` as a one-row table and append it to the deck.

        Args:
            record: Mapping rendered as a single DataFrame row.
            newline: Whether to append a trailing ``"\\n"`` to the HTML.
        """
        html = self.table_renderer().to_html(self.pandas.DataFrame([record]))
        self.deck.append(html + "\n" if newline else html)

    def _append_text_analysis(self, text: str) -> None:
        """Append the textstat/spacy analysis of ``text`` to the deck.

        ``analyze_text`` emits ``"text_complexity_metrics"`` only when textstat
        is available, and ``"dependency_tree"``/``"entities"`` only when a
        spacy model is available. Each section is therefore rendered only if
        present, so having only one of the two analyzers no longer raises
        ``KeyError``.
        """
        analysis = analyze_text(text, nlp=self.nlp, textstat=self.textstat)

        complexity_metrics = analysis.get("text_complexity_metrics")
        if complexity_metrics is not None:
            self._append_markdown("#### Text Complexity Metrics")
            self._append_table(complexity_metrics)

        dependency_tree = analysis.get("dependency_tree")
        if dependency_tree is not None:
            self._append_markdown("#### Dependency Tree")
            self.deck.append(dependency_tree)

        entities = analysis.get("entities")
        if entities is not None:
            self._append_markdown("#### Entities")
            self.deck.append(entities)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts.

        Renders the serialized LLM config, custom metadata and the prompts.
        """
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_llm_start"}
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())
        resp["prompts"] = list(prompts)

        self._append_markdown("### LLM Start")
        self._append_table(resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token.

        Intentionally a no-op: streamed tokens are not written to the deck.
        """

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        Renders a summary table of ``response.llm_output``. For each
        generation it then renders either the text-analysis results (when
        spacy and/or textstat are available) or the raw generated text.
        """
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_llm_end"}
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### LLM End")
        # The summary table has historically had no trailing newline.
        self._append_table(resp, newline=False)

        for generations in response.generations:
            for generation in generations:
                if self.nlp or self.textstat:
                    self._append_text_analysis(generation.text)
                else:
                    self._append_markdown("#### Generated Response")
                    self._append_markdown(generation.text)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors. Only the step/error counters are updated."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Inputs are rendered as a single comma-separated ``key=value`` string.
        """
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_chain_start"}
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = ",".join(f"{k}={v}" for k, v in inputs.items())
        input_resp = deepcopy(resp)
        input_resp["inputs"] = chain_input

        self._append_markdown("### Chain Start")
        self._append_table(input_resp)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Outputs are rendered as a single comma-separated ``key=value`` string.
        """
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        chain_output = ",".join(f"{k}={v}" for k, v in outputs.items())
        resp: Dict[str, Any] = {"action": "on_chain_end", "outputs": chain_output}
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### Chain End")
        self._append_table(resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors. Only the step/error counters are updated."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_tool_start", "input_str": input_str}
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### Tool Start")
        self._append_table(resp)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_tool_end", "output": output}
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### Tool End")
        self._append_table(resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors. Only the step/error counters are updated."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text and render it to the deck."""
        self.step += 1
        self.text_ctr += 1

        resp: Dict[str, Any] = {"action": "on_text", "text": text}
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### On Text")
        self._append_table(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Expects ``finish.return_values`` to contain an ``"output"`` key
        (a missing key raises ``KeyError``).
        """
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_finish",
            "output": finish.return_values["output"],
            "log": finish.log,
        }
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### Agent Finish")
        self._append_table(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action.

        Counted as a tool start, since an agent action selects a tool.
        """
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_action",
            "tool": action.tool,
            "tool_input": action.tool_input,
            "log": action.log,
        }
        resp.update(self.get_custom_callback_meta())

        self._append_markdown("### Agent Action")
        self._append_table(resp)