Source code for langchain_core.language_models.fake_chat_models

"""Fake ChatModel for testing purposes."""

import asyncio
import re
import time
from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, Union, cast

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig


class FakeMessagesListChatModel(BaseChatModel):
    """Fake ChatModel for testing purposes.

    Replays a fixed list of pre-built messages, one per invocation, and wraps
    around to the first message after the last one has been returned.
    """

    responses: list[BaseMessage]
    """List of responses to **cycle** through in order."""
    sleep: Optional[float] = None
    """Sleep time in seconds between responses."""
    i: int = 0
    """Internally incremented after every model invocation."""

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the next canned response and advance the internal cursor.

        Args:
            messages: The prompt messages; not inspected by this fake.
            stop: Unused.
            run_manager: Unused.
            **kwargs: Unused.

        Returns:
            A ``ChatResult`` holding a single generation wrapping
            ``responses[i]``.

        Raises:
            IndexError: If ``responses`` is empty.
        """
        # Honour the configured delay; previously the field was declared but
        # never consulted here.
        if self.sleep is not None:
            time.sleep(self.sleep)
        response = self.responses[self.i]
        # Advance the cursor, wrapping back to the start after the last item.
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        generation = ChatGeneration(message=response)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model type."""
        return "fake-messages-list-chat-model"
class FakeListChatModelError(Exception):
    """Error raised by ``FakeListChatModel`` to simulate a failure mid-stream."""
class FakeListChatModel(SimpleChatModel):
    """Fake ChatModel for testing purposes.

    Replays a fixed list of strings, one per invocation, and wraps around to
    the first string after the last one has been returned. When streaming,
    each response is emitted one character per chunk.
    """

    responses: list[str]
    """List of responses to **cycle** through in order."""
    sleep: Optional[float] = None
    """Sleep time in seconds before each response (or each streamed chunk)."""
    i: int = 0
    """Internally incremented after every model invocation."""
    error_on_chunk_number: Optional[int] = None
    """If set, streaming raises ``FakeListChatModelError`` instead of yielding
    the chunk at this zero-based index."""

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model type."""
        return "fake-list-chat-model"

    def _next_response(self) -> str:
        """Return ``responses[i]`` and advance the cursor, wrapping at the end.

        Raises:
            IndexError: If ``responses`` is empty.
        """
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return response

    def _call(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the next response in the cycle.

        The arguments are accepted for interface compatibility and are not
        inspected.
        """
        # Apply the configured delay so non-streaming calls honour ``sleep``
        # just like the streaming paths do.
        if self.sleep is not None:
            time.sleep(self.sleep)
        return self._next_response()

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: Union[list[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield the next response one character at a time.

        Raises:
            FakeListChatModelError: When the chunk index equals
                ``error_on_chunk_number``.
        """
        response = self._next_response()
        for i_c, c in enumerate(response):
            if self.sleep is not None:
                time.sleep(self.sleep)
            if (
                self.error_on_chunk_number is not None
                and i_c == self.error_on_chunk_number
            ):
                raise FakeListChatModelError

            yield ChatGenerationChunk(message=AIMessageChunk(content=c))

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: Union[list[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``; uses ``asyncio.sleep`` for delays.

        Raises:
            FakeListChatModelError: When the chunk index equals
                ``error_on_chunk_number``.
        """
        response = self._next_response()
        for i_c, c in enumerate(response):
            if self.sleep is not None:
                await asyncio.sleep(self.sleep)
            if (
                self.error_on_chunk_number is not None
                and i_c == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Parameters that identify this model instance."""
        return {"responses": self.responses}

    # manually override batch to preserve batch ordering with no concurrency
    def batch(
        self,
        inputs: list[Any],
        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> list[BaseMessage]:
        """Invoke the model on each input sequentially, in input order.

        Args:
            inputs: The inputs to invoke the model on, one call per item.
            config: A single config applied to every call, or a list of
                configs paired positionally with ``inputs``.
            return_exceptions: Accepted for interface compatibility; this
                implementation does not consult it.
            **kwargs: Forwarded to ``invoke``.

        Returns:
            The results of each ``invoke`` call, in input order.
        """
        if isinstance(config, list):
            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
        return [self.invoke(m, config, **kwargs) for m in inputs]

    async def abatch(
        self,
        inputs: list[Any],
        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> list[BaseMessage]:
        """Async counterpart of ``batch``; awaits each call one after another.

        Args:
            inputs: The inputs to invoke the model on, one call per item.
            config: A single config applied to every call, or a list of
                configs paired positionally with ``inputs``.
            return_exceptions: Accepted for interface compatibility; this
                implementation does not consult it.
            **kwargs: Forwarded to ``ainvoke``.

        Returns:
            The results of each ``ainvoke`` call, in input order.
        """
        if isinstance(config, list):
            # do Not use an async iterator here because need explicit ordering
            return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
        # do Not use an async iterator here because need explicit ordering
        return [await self.ainvoke(m, config, **kwargs) for m in inputs]
class FakeChatModel(SimpleChatModel):
    """Minimal chat model stand-in for tests; always answers "fake response"."""

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model type."""
        return "fake-chat-model"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Parameters that identify this model instance."""
        return dict(key="fake")

    def _call(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the constant reply, regardless of the arguments."""
        return "fake response"

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the constant reply wrapped in a single-generation result."""
        reply = AIMessage(content="fake response")
        return ChatResult(generations=[ChatGeneration(message=reply)])
class GenericFakeChatModel(BaseChatModel):
    """Generic fake chat model that can be used to test the chat model interface.

    * Chat model should be usable in both sync and async tests
    * Invokes on_llm_new_token to allow for testing of callback related code for new
      tokens.
    * Includes logic to break messages into message chunk to facilitate testing of
      streaming.
    """

    messages: Iterator[Union[AIMessage, str]]
    """Get an iterator over messages.

    This can be expanded to accept other types like Callables / dicts / strings
    to make the interface more generic if needed.

    Note: if you want to pass a list, you can use `iter` to convert it to an iterator.

    Streaming is supported: `_stream` delegates to `_generate` and then breaks
    the resulting message into chunks.
    """

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the next message from the ``messages`` iterator.

        A plain string is wrapped in an ``AIMessage``; an ``AIMessage`` is
        passed through unchanged. The call arguments are not inspected.
        """
        message = next(self.messages)
        message_ = AIMessage(content=message) if isinstance(message, str) else message
        generation = ChatGeneration(message=message_)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        Generates a full message via ``_generate`` and then emits it in pieces:
        first the text content split on single whitespace characters (the
        whitespace itself is emitted as its own chunk), then one or more chunks
        per ``additional_kwargs`` entry. String values inside a
        ``function_call`` entry are further split on commas.
        ``run_manager.on_llm_new_token`` is called for every chunk before it is
        yielded.

        Raises:
            ValueError: If ``_generate`` does not return a ``ChatResult``, the
                generated message is not an ``AIMessage``, or its non-empty
                content is not a string.
        """
        chat_result = self._generate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        if not isinstance(chat_result, ChatResult):
            msg = (
                f"Expected generate to return a ChatResult, "
                f"but got {type(chat_result)} instead."
            )
            raise ValueError(msg)

        message = chat_result.generations[0].message

        if not isinstance(message, AIMessage):
            msg = (
                f"Expected invoke to return an AIMessage, "
                f"but got {type(message)} instead."
            )
            raise ValueError(msg)

        def kwargs_chunk(additional_kwargs: dict[str, Any]) -> ChatGenerationChunk:
            """Build a content-less chunk and report it to the run manager."""
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(
                    id=message.id,
                    content="",
                    additional_kwargs=additional_kwargs,
                )
            )
            if run_manager:
                run_manager.on_llm_new_token(
                    "",
                    chunk=chunk,  # No token for function call
                )
            return chunk

        content = message.content

        if content:
            # Explicit check instead of ``assert`` so the validation is still
            # performed when Python runs with assertions disabled (-O).
            if not isinstance(content, str):
                msg = "Expected content to be a string."
                raise ValueError(msg)
            # Use a regular expression to split on whitespace with a capture group
            # so that we can preserve the whitespace in the output.
            content_chunks = cast(list[str], re.split(r"(\s)", content))

            for token in content_chunks:
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=token, id=message.id)
                )
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
                yield chunk

        if message.additional_kwargs:
            for key, value in message.additional_kwargs.items():
                # We should further break down the additional kwargs into chunks
                # Special case for function call
                if key == "function_call":
                    for fkey, fvalue in value.items():
                        if isinstance(fvalue, str):
                            # Break function call by `,`
                            fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
                            for fvalue_chunk in fvalue_chunks:
                                yield kwargs_chunk(
                                    {"function_call": {fkey: fvalue_chunk}}
                                )
                        else:
                            yield kwargs_chunk({"function_call": {fkey: fvalue}})
                else:
                    yield kwargs_chunk({key: value})

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model type."""
        return "generic-fake-chat-model"
class ParrotFakeChatModel(BaseChatModel):
    """Fake chat model that parrots the last input message back.

    * Chat model should be usable in both sync and async tests
    """

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model type."""
        return "parrot-fake-chat-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the final message of ``messages`` as the only generation."""
        last_message = messages[-1]
        echoed = ChatGeneration(message=last_message)
        return ChatResult(generations=[echoed])