import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.utils import (
flatten_dict,
)
def save_json(data: dict, file_path: str) -> None:
    """Serialize ``data`` as JSON and write it to ``file_path``.

    The file is opened in text write mode, so any existing content at
    ``file_path`` is replaced.

    Parameters:
        data (dict): JSON-serializable mapping to persist.
        file_path (str): Destination path on the local filesystem.
    """
    with open(file_path, mode="w") as handle:
        json.dump(data, fp=handle)
class SageMakerCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.

    Each logged event is written as a JSON file into a private temporary
    directory and then handed to ``run.log_file`` under the file's stem.
    Every payload also embeds a snapshot of the handler's event counters.

    Parameters:
        run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
    """

    def __init__(self, run: Any) -> None:
        """Initialize callback handler."""
        super().__init__()
        # Object whose ``log_file`` method receives every JSON artifact.
        self.run = run
        # Running event counters; a copy is merged into every logged payload.
        self.metrics = {
            "step": 0,
            "starts": 0,
            "ends": 0,
            "errors": 0,
            "text_ctr": 0,
            "chain_starts": 0,
            "chain_ends": 0,
            "llm_starts": 0,
            "llm_ends": 0,
            "llm_streams": 0,
            "tool_starts": 0,
            "tool_ends": 0,
            "agent_ends": 0,
        }
        # Staging directory for the JSON files passed to ``run.log_file``.
        self.temp_dir = tempfile.mkdtemp()

    def _reset(self) -> None:
        """Set every metric counter back to zero."""
        for key in self.metrics:
            self.metrics[key] = 0

    @staticmethod
    def _join_items(values: Any) -> str:
        """Render chain inputs/outputs as a comma-separated ``k=v`` string.

        Anything exposing ``.items()`` is rendered pair by pair, exactly as
        before. Other payloads fall back to ``str(values)`` instead of raising.
        NOTE(review): runnable-based chains are assumed to be able to pass
        non-dict payloads despite the ``Dict`` hints; confirm for the
        LangChain version in use.
        """
        try:
            pairs = values.items()
        except AttributeError:
            return str(values)
        return ",".join(f"{k}={v}" for k, v in pairs)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts.

        Writes one JSON artifact per prompt. Each holds the flattened
        ``serialized`` description, the current metrics and the prompt text.
        """
        self.metrics["step"] += 1
        self.metrics["llm_starts"] += 1
        self.metrics["starts"] += 1
        llm_starts = self.metrics["llm_starts"]

        resp: Dict[str, Any] = {"action": "on_llm_start"}
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.metrics)

        for idx, prompt in enumerate(prompts):
            # Copy so each prompt's payload is independent of the others.
            prompt_resp = deepcopy(resp)
            prompt_resp["prompt"] = prompt
            self.jsonf(
                prompt_resp,
                self.temp_dir,
                f"llm_start_{llm_starts}_prompt_{idx}",
            )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token; logs the token with the metrics."""
        self.metrics["step"] += 1
        self.metrics["llm_streams"] += 1
        llm_streams = self.metrics["llm_streams"]

        resp: Dict[str, Any] = {"action": "on_llm_new_token", "token": token}
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        Writes one JSON artifact per generation. Each combines the flattened
        ``llm_output``, the current metrics and the flattened generation.
        """
        self.metrics["step"] += 1
        self.metrics["llm_ends"] += 1
        self.metrics["ends"] += 1
        llm_ends = self.metrics["llm_ends"]

        resp: Dict[str, Any] = {"action": "on_llm_end"}
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.metrics)

        # Number generations across *all* prompts. A per-prompt index would
        # give several generations the same file name, so later ones would
        # overwrite earlier ones. For a single prompt the names are unchanged.
        all_generations = (
            generation
            for generations in response.generations
            for generation in generations
        )
        for idx, generation in enumerate(all_generations):
            generation_resp = deepcopy(resp)
            generation_resp.update(flatten_dict(generation.dict()))
            # Log the per-generation payload, not the shared ``resp``, so the
            # generation itself ends up in the artifact.
            self.jsonf(
                generation_resp,
                self.temp_dir,
                f"llm_end_{llm_ends}_generation_{idx}",
            )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors; only the ``step`` and ``errors`` counters change."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Logs the flattened ``serialized`` description, the current metrics and
        the inputs rendered as a ``k=v`` string.
        """
        self.metrics["step"] += 1
        self.metrics["chain_starts"] += 1
        self.metrics["starts"] += 1
        chain_starts = self.metrics["chain_starts"]

        resp: Dict[str, Any] = {"action": "on_chain_start"}
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.metrics)
        resp["inputs"] = self._join_items(inputs)
        self.jsonf(resp, self.temp_dir, f"chain_start_{chain_starts}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running; logs the outputs as a ``k=v`` string."""
        self.metrics["step"] += 1
        self.metrics["chain_ends"] += 1
        self.metrics["ends"] += 1
        chain_ends = self.metrics["chain_ends"]

        resp: Dict[str, Any] = {
            "action": "on_chain_end",
            "outputs": self._join_items(outputs),
        }
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors; only the ``step`` and ``errors`` counters change."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text; logs the text with the current metrics."""
        self.metrics["step"] += 1
        self.metrics["text_ctr"] += 1
        text_ctr = self.metrics["text_ctr"]

        resp: Dict[str, Any] = {"action": "on_text", "text": text}
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Logs the agent's ``"output"`` return value (``None`` if the agent did
        not return that key) together with its log and the current metrics.
        """
        self.metrics["step"] += 1
        self.metrics["agent_ends"] += 1
        self.metrics["ends"] += 1
        agent_ends = self.metrics["agent_ends"]

        resp: Dict[str, Any] = {
            "action": "on_agent_finish",
            # .get: agents are not obliged to use the "output" key.
            "output": finish.return_values.get("output"),
            "log": finish.log,
        }
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action; logs the chosen tool, its input and the log."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1
        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {
            "action": "on_agent_action",
            "tool": action.tool,
            "tool_input": action.tool_input,
            "log": action.log,
        }
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")

    def jsonf(
        self,
        data: Dict[str, Any],
        data_dir: str,
        filename: str,
        is_output: Optional[bool] = True,
    ) -> None:
        """To log the input data as json file artifact.

        Parameters:
            data: JSON-serializable payload to write.
            data_dir: Local directory the ``<filename>.json`` file goes into.
            filename: File stem, also used as the artifact ``name``.
            is_output: Forwarded unchanged to ``run.log_file``.
        """
        # flush_tracker() removes the temp dir. Recreate it (owner-only, the
        # same mode mkdtemp uses) so the handler still works if reused.
        os.makedirs(data_dir, mode=0o700, exist_ok=True)
        file_path = os.path.join(data_dir, f"{filename}.json")
        save_json(data, file_path)
        self.run.log_file(file_path, name=filename, is_output=is_output)

    def flush_tracker(self) -> None:
        """Reset all metric counters to zero and delete the local temporary directory."""
        self._reset()
        shutil.rmtree(self.temp_dir)