"""Fake ChatModel for testing purposes."""
import asyncio
import re
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
class FakeMessagesListChatModel(BaseChatModel):
    """Fake chat model that replays a fixed list of messages, for testing.

    Every invocation returns the message at index ``i`` of ``responses`` and
    then advances ``i``, wrapping back to ``0`` after the last message.
    """

    responses: List[BaseMessage]
    """List of responses to **cycle** through in order."""
    sleep: Optional[float] = None
    """Sleep time in seconds between responses."""
    i: int = 0
    """Internally incremented after every model invocation."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the next canned message wrapped in a ``ChatResult``.

        Args:
            messages: Input messages; not inspected.
            stop: Stop sequences; not inspected.
            run_manager: Callback manager; not used.
            **kwargs: Extra keyword arguments; not inspected.

        Returns:
            A ``ChatResult`` holding a single generation with the selected
            response.

        Raises:
            IndexError: If ``responses`` has no element at index ``i``
                (for example when ``responses`` is empty).
        """
        # Apply the configured delay so that ``sleep`` actually takes effect.
        if self.sleep is not None:
            time.sleep(self.sleep)

        response = self.responses[self.i]
        # Advance the cursor, wrapping to the start after the last response.
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return ChatResult(generations=[ChatGeneration(message=response)])

    @property
    def _llm_type(self) -> str:
        """Identifier string for this model type."""
        return "fake-messages-list-chat-model"
class FakeListChatModelError(Exception):
    """Raised by ``FakeListChatModel`` streaming at the configured chunk index."""
class FakeListChatModel(SimpleChatModel):
    """Fake chat model that replays a fixed list of strings, for testing.

    Every invocation (plain or streaming) takes the string at index ``i`` of
    ``responses`` and then advances ``i``, wrapping back to ``0`` after the
    last entry. When streaming, the selected string is emitted one character
    per chunk.
    """

    responses: List[str]
    """List of responses to **cycle** through in order."""
    sleep: Optional[float] = None
    """Sleep time in seconds before each response and each streamed chunk."""
    i: int = 0
    """Internally incremented after every model invocation."""
    error_on_chunk_number: Optional[int] = None
    """If set, streaming raises ``FakeListChatModelError`` instead of emitting
    the chunk at this zero-based index."""

    @property
    def _llm_type(self) -> str:
        """Identifier string for this model type."""
        return "fake-list-chat-model"

    def _next_response(self) -> str:
        """Return the current response and advance the cursor, wrapping at the end.

        Raises:
            IndexError: If ``responses`` has no element at index ``i``
                (for example when ``responses`` is empty).
        """
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return response

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the next canned response.

        Args:
            messages: Input messages; not inspected.
            stop: Stop sequences; not inspected.
            run_manager: Callback manager; not used.
            **kwargs: Extra keyword arguments; not inspected.

        Returns:
            The response string selected for this invocation.
        """
        # Apply the configured delay here too, as the streaming paths do.
        if self.sleep is not None:
            time.sleep(self.sleep)
        return self._next_response()

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield the next canned response one character at a time.

        Args:
            messages: Input messages; not inspected.
            stop: Stop sequences; not inspected.
            run_manager: Callback manager; not used.
            **kwargs: Extra keyword arguments; not inspected.

        Yields:
            One ``ChatGenerationChunk`` per character of the response.

        Raises:
            FakeListChatModelError: When the chunk index equals
                ``error_on_chunk_number``.
        """
        response = self._next_response()
        for chunk_index, char in enumerate(response):
            if self.sleep is not None:
                time.sleep(self.sleep)
            # The failure is injected before the matching chunk is emitted.
            if (
                self.error_on_chunk_number is not None
                and chunk_index == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=char))

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously yield the next canned response one character at a time.

        Args:
            messages: Input messages; not inspected.
            stop: Stop sequences; not inspected.
            run_manager: Callback manager; not used.
            **kwargs: Extra keyword arguments; not inspected.

        Yields:
            One ``ChatGenerationChunk`` per character of the response.

        Raises:
            FakeListChatModelError: When the chunk index equals
                ``error_on_chunk_number``.
        """
        response = self._next_response()
        for chunk_index, char in enumerate(response):
            if self.sleep is not None:
                await asyncio.sleep(self.sleep)
            # The failure is injected before the matching chunk is emitted.
            if (
                self.error_on_chunk_number is not None
                and chunk_index == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=char))

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters identifying this model: its list of responses."""
        return {"responses": self.responses}
class FakeChatModel(SimpleChatModel):
    """Chat model stub for tests that always answers ``"fake response"``."""

    @property
    def _llm_type(self) -> str:
        """Identifier string for this model type."""
        return "fake-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Fixed identifying parameters for this model."""
        return {"key": "fake"}

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the fixed reply; every argument is ignored."""
        return "fake response"

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the fixed reply as a single-generation ``ChatResult``.

        Every argument is ignored.
        """
        reply = AIMessage(content="fake response")
        return ChatResult(generations=[ChatGeneration(message=reply)])
class GenericFakeChatModel(BaseChatModel):
    """Generic fake chat model that can be used to test the chat model interface.

    * Chat model should be usable in both sync and async tests
    * Invokes on_llm_new_token to allow for testing of callback related code for new
      tokens.
    * Includes logic to break messages into message chunk to facilitate testing of
      streaming.
    """

    messages: Iterator[Union[AIMessage, str]]
    """Iterator over the messages to return, one per invocation.

    This can be expanded to accept other types like Callables / dicts / strings
    to make the interface more generic if needed.

    Note: if you want to pass a list, you can use `iter` to convert it to an iterator.

    Streaming is implemented by taking the next message and breaking its content
    and additional kwargs into message chunks.
    """

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the next message from the ``messages`` iterator.

        Args:
            messages: Input messages; not inspected.
            stop: Stop sequences; not inspected.
            run_manager: Callback manager; not used here.
            **kwargs: Extra keyword arguments; not inspected.

        Returns:
            A ``ChatResult`` with a single generation. A plain string taken from
            the iterator is wrapped in an ``AIMessage``.

        Raises:
            StopIteration: If the ``messages`` iterator is exhausted.
        """
        next_item = next(self.messages)
        if isinstance(next_item, str):
            message_ = AIMessage(content=next_item)
        else:
            message_ = next_item
        return ChatResult(generations=[ChatGeneration(message=message_)])

    @staticmethod
    def _kwargs_chunk(
        message_id: Optional[str],
        additional_kwargs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForLLMRun],
    ) -> ChatGenerationChunk:
        """Build a content-less chunk carrying ``additional_kwargs`` and report it."""
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                id=message_id, content="", additional_kwargs=additional_kwargs
            )
        )
        if run_manager:
            # Additional kwargs carry no token text, so an empty token is reported.
            run_manager.on_llm_new_token("", chunk=chunk)
        return chunk

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        The next message is obtained via ``_generate``. Its string content is
        split on whitespace (whitespace tokens are kept), then each entry of
        ``additional_kwargs`` is emitted as a content-less chunk. String values
        inside ``"function_call"`` are further split on ``,``.

        Args:
            messages: Input messages, forwarded to ``_generate``.
            stop: Stop sequences, forwarded to ``_generate``.
            run_manager: If given, ``on_llm_new_token`` is called for each chunk
                before it is yielded.
            **kwargs: Extra keyword arguments, forwarded to ``_generate``.

        Yields:
            ``ChatGenerationChunk`` objects sharing the id of the source message.

        Raises:
            ValueError: If ``_generate`` does not return a ``ChatResult``, if the
                generated message is not an ``AIMessage``, or if its non-empty
                content is not a string.
        """
        chat_result = self._generate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        if not isinstance(chat_result, ChatResult):
            raise ValueError(
                f"Expected generate to return a ChatResult, "
                f"but got {type(chat_result)} instead."
            )

        message = chat_result.generations[0].message
        if not isinstance(message, AIMessage):
            raise ValueError(
                f"Expected invoke to return an AIMessage, "
                f"but got {type(message)} instead."
            )

        content = message.content
        if content:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected message content to be a str, "
                    f"but got {type(content)} instead."
                )
            # Use a regular expression to split on whitespace with a capture group
            # so that we can preserve the whitespace in the output.
            for token in cast(List[str], re.split(r"(\s)", content)):
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=token, id=message.id)
                )
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
                yield chunk

        if message.additional_kwargs:
            for key, value in message.additional_kwargs.items():
                if key != "function_call":
                    yield self._kwargs_chunk(message.id, {key: value}, run_manager)
                    continue
                # Special case for function call: emit each field separately and
                # break string fields by `,` (the separators are kept as chunks).
                for fkey, fvalue in value.items():
                    if isinstance(fvalue, str):
                        pieces = cast(List[Any], re.split(r"(,)", fvalue))
                    else:
                        pieces = [fvalue]
                    for piece in pieces:
                        yield self._kwargs_chunk(
                            message.id,
                            {"function_call": {fkey: piece}},
                            run_manager,
                        )

    @property
    def _llm_type(self) -> str:
        """Identifier string for this model type."""
        return "generic-fake-chat-model"
class ParrotFakeChatModel(BaseChatModel):
    """Fake chat model for tests that echoes back the last input message.

    * Chat model should be usable in both sync and async tests
    """

    @property
    def _llm_type(self) -> str:
        """Identifier string for this model type."""
        return "parrot-fake-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the final element of ``messages`` as the only generation.

        ``stop``, ``run_manager`` and ``**kwargs`` are not inspected.
        """
        last_message = messages[-1]
        echoed = ChatGeneration(message=last_message)
        return ChatResult(generations=[echoed])