Source code for langchain_community.callbacks.context_callback

"""Callback handler for Context AI"""

import os
from typing import Any, Dict, List
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import


def import_context() -> Any:
    """Import the `getcontext` package and return the objects the handler uses.

    Returns:
        A 6-tuple, in this order: the ``getcontext`` module, ``Credential``
        (from ``getcontext.token``), and ``Conversation``, ``Message``,
        ``MessageRole`` and ``Rating`` (all from
        ``getcontext.generated.models``).

    Raises:
        ImportError: propagated from ``guard_import`` when one of the modules
            cannot be imported.
    """
    # The install hint must name the PyPI distribution, not the import name:
    # the SDK is published as `context-python` and provides the `getcontext`
    # module (this also matches the handler's class docstring).
    pip_name = "context-python"

    getcontext = guard_import("getcontext", pip_name=pip_name)
    credential = guard_import("getcontext.token", pip_name=pip_name).Credential
    # Resolve the generated-models module once and pull every model from it.
    models = guard_import("getcontext.generated.models", pip_name=pip_name)

    return (
        getcontext,
        credential,
        models.Conversation,
        models.Message,
        models.MessageRole,
        models.Rating,
    )
class ContextCallbackHandler(BaseCallbackHandler):
    """Callback Handler that records transcripts to the Context service.

     (https://context.ai).

    Messages seen by the chat model are buffered on the handler and uploaded
    as one conversation when the LLM call ends (standalone use) or when the
    enclosing chain ends (chain use). The buffer and metadata are cleared
    after each successful upload.

    Keyword Args:
        token (optional): The token with which to authenticate requests to
            Context. Visit https://with.context.ai/settings to generate a
            token. If not provided, the value of the `CONTEXT_TOKEN`
            environment variable will be used.

    Raises:
        ImportError: if the `context-python` package is not installed.

    Chat Example:
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain_community.callbacks import ContextCallbackHandler
        >>> from langchain_core.messages import HumanMessage, SystemMessage
        >>> context_callback = ContextCallbackHandler(
        ...     token="<CONTEXT_TOKEN_HERE>",
        ... )
        >>> chat = ChatOpenAI(
        ...     temperature=0,
        ...     headers={"user_id": "123"},
        ...     callbacks=[context_callback],
        ...     openai_api_key="API_KEY_HERE",
        ... )
        >>> messages = [
        ...     SystemMessage(content="You translate English to French."),
        ...     HumanMessage(content="I love programming with LangChain."),
        ... ]
        >>> chat.invoke(messages)

    Chain Example:
        >>> from langchain.chains import LLMChain
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain_community.callbacks import ContextCallbackHandler
        >>> from langchain_core.prompts import (
        ...     ChatPromptTemplate,
        ...     HumanMessagePromptTemplate,
        ...     PromptTemplate,
        ... )
        >>> callback = ContextCallbackHandler(
        ...     token="<CONTEXT_TOKEN_HERE>",
        ... )
        >>> human_message_prompt = HumanMessagePromptTemplate(
        ...     prompt=PromptTemplate(
        ...         template="What is a good name for a company that makes {product}?",
        ...         input_variables=["product"],
        ...     ),
        ... )
        >>> chat_prompt_template = ChatPromptTemplate.from_messages(
        ...     [human_message_prompt]
        ... )
        >>> # Note: the same callback object must be shared between the
        ... # LLM and the chain.
        >>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback])
        >>> chain = LLMChain(
        ...     llm=chat,
        ...     prompt=chat_prompt_template,
        ...     callbacks=[callback]
        ... )
        >>> chain.run("colorful socks")
    """

    def __init__(
        self, token: str = "", verbose: bool = False, **kwargs: Any
    ) -> None:
        """Create the Context API client and reset the transcript buffer.

        Args:
            token: Context API token. When empty, the `CONTEXT_TOKEN`
                environment variable is used; if that is unset too, an empty
                token is passed to the client.
            verbose: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ImportError: if the Context SDK cannot be imported.
        """
        (
            self.context,
            self.credential,
            self.conversation_model,
            self.message_model,
            self.message_role_model,
            self.rating_model,
        ) = import_context()

        token = token or os.environ.get("CONTEXT_TOKEN") or ""
        self.client = self.context.ContextAPI(credential=self.credential(token))

        # Run id of the chain currently executing; None means "no chain", in
        # which case on_llm_end is responsible for uploading the transcript.
        self.chain_run_id: Any = None
        self.llm_model = None
        # Messages buffered since the last upload.
        self.messages: List[Any] = []
        # Conversation-level metadata (currently only the model name).
        self.metadata: Dict[str, str] = {}

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> Any:
        """Run when the chat model is started.

        Records the model name (when supplied via ``invocation_params``) and
        buffers the first prompt in ``messages``; any further prompts in the
        batch are not recorded.
        """
        # `invocation_params` may be missing or explicitly None.
        invocation_params = kwargs.get("invocation_params") or {}
        llm_model = invocation_params.get("model")
        if llm_model is not None:
            self.metadata["model"] = llm_model

        if not messages:
            return

        roles = self.message_role_model
        role_by_type = {
            "human": roles.USER,
            "system": roles.SYSTEM,
            "ai": roles.ASSISTANT,
        }
        for message in messages[0]:
            self.messages.append(
                self.message_model(
                    message=message.content,
                    # Message types outside the table are recorded as SYSTEM.
                    role=role_by_type.get(message.type, roles.SYSTEM),
                )
            )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends.

        Outside of a chain, buffers the first generation as the assistant
        reply and uploads the conversation. Inside a chain this is a no-op:
        on_chain_end records the reply and uploads instead.
        """
        if not response.generations or not response.generations[0]:
            return

        if self.chain_run_id:
            return

        generation = response.generations[0][0]
        self.messages.append(
            self.message_model(
                message=generation.text,
                role=self.message_role_model.ASSISTANT,
            )
        )
        self._log_conversation()

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts; remembers the chain's run id."""
        self.chain_run_id = kwargs.get("run_id", None)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends.

        Buffers ``outputs["text"]`` as the assistant reply when it is present,
        uploads whatever has been buffered, and marks the chain as finished.
        Outputs without a ``"text"`` entry (or that are not a dict) add no
        assistant message instead of raising.
        """
        text = outputs.get("text") if isinstance(outputs, dict) else None
        if text is not None:
            self.messages.append(
                self.message_model(
                    message=text,
                    role=self.message_role_model.ASSISTANT,
                )
            )

        try:
            self._log_conversation()
        finally:
            # The chain is over even if the upload failed; leaving the id set
            # would make later standalone LLM calls skip their own upload.
            self.chain_run_id = None

    def _log_conversation(self) -> None:
        """Log the conversation to the context API.

        Does nothing when no messages are buffered. On success the message
        buffer and metadata are reset; if the upload raises, they are left
        untouched.
        """
        if not self.messages:
            return

        conversation = self.conversation_model(
            messages=self.messages,
            metadata=self.metadata,
        )
        self.client.log.conversation_upsert(body={"conversation": conversation})

        self.messages = []
        self.metadata = {}