"""Kafka-based chat message history by using confluent-kafka-python.
confluent-kafka-python is under Apache 2.0 license.
https://github.com/confluentinc/confluent-kafka-python
"""
from __future__ import annotations
import json
import logging
import time
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Sequence
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
if TYPE_CHECKING:
from confluent_kafka import TopicPartition
from confluent_kafka.admin import AdminClient
logger = logging.getLogger(__name__)
BOOTSTRAP_SERVERS_CONFIG = "bootstrap.servers"
DEFAULT_TTL_MS = 604800000 # 7 days
DEFAULT_REPLICATION_FACTOR = 1
DEFAULT_PARTITION = 3
[docs]
class ConsumeStartPosition(Enum):
    """Where the Kafka consumer starts when reading chat history messages.

    Members:
        LAST_CONSUMED: Continue from the last consumed offset.
        EARLIEST: Start consuming from the beginning.
        LATEST: Start consuming from the latest offset.
    """

    # Resume after the last consumed offset.
    LAST_CONSUMED = 1
    # Re-read the history from its first message.
    EARLIEST = 2
    # Skip what is already there and start at the latest offset.
    LATEST = 3
[docs]
def ensure_topic_exists(
    admin_client: AdminClient,
    topic_name: str,
    replication_factor: int,
    partition: int,
    ttl_ms: int,
) -> int:
    """Make sure ``topic_name`` exists and report how many partitions it has.

    An existing topic is left exactly as it is: its configuration is not
    changed and its current partition count is returned. A missing topic is
    created with the requested settings.

    Args:
        admin_client: Admin client used to list and create topics.
        topic_name: Name of the topic to look up or create.
        replication_factor: Replication factor used when creating the topic.
        partition: Number of partitions used when creating the topic.
        ttl_ms: Value written to the topic's ``retention.ms`` setting when
            creating the topic.

    Returns:
        The partition count of the existing topic, or ``partition`` when the
        topic was created by this call.

    Raises:
        Exception: Any error raised while listing or creating topics is
            logged and re-raised.
    """
    from confluent_kafka.admin import NewTopic

    # Step 1: an already existing topic short-circuits the whole function.
    try:
        known_topics = admin_client.list_topics().topics
        if topic_name in known_topics:
            num_partitions = len(known_topics[topic_name].partitions)
            logger.info(
                f"Topic {topic_name} already exists with {num_partitions} partitions"
            )
            return num_partitions
    except Exception as e:
        logger.error(f"Failed to list topics: {e}")
        raise e

    # Step 2: the topic is missing, so describe it and ask the cluster to create it.
    topic_spec = NewTopic(
        topic_name,
        num_partitions=partition,
        replication_factor=replication_factor,
        config={"retention.ms": str(ttl_ms)},
    )
    try:
        pending = admin_client.create_topics([topic_spec])
        for creation in pending.values():
            # Waits for the creation to finish; the result itself is None.
            creation.result()
            logger.info(f"Topic {topic_name} created")
    except Exception as e:
        logger.error(f"Failed to create topic {topic_name}: {e}")
        raise e

    return partition
[docs]
class KafkaChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in Kafka.

    Setup:
        Install ``confluent-kafka-python``.

        .. code-block:: bash

            pip install confluent_kafka

    Instantiate:
        .. code-block:: python

            from langchain_community.chat_message_histories import KafkaChatMessageHistory

            history = KafkaChatMessageHistory(
                session_id="your_session_id",
                bootstrap_servers="host:port",
            )

    Add and retrieve messages:
        .. code-block:: python

            # Add messages
            history.add_messages([message1, message2, message3, ...])

            # Retrieve messages
            message_batch_0 = history.messages

            # retrieve messages after message_batch_0
            message_batch_1 = history.messages

            # Reset to beginning and retrieve messages
            messages_from_beginning = history.messages_from_beginning()

    Retrieving messages is stateful. Internally, it uses Kafka consumer to read.
    The consumed offset is maintained persistently.

    To retrieve messages, you can use the following methods:

    - `messages`:
        continue consuming chat messages from last one.
    - `messages_from_beginning`:
        reset the consumer to the beginning of the chat history and return
        messages. Optional parameters:

        1. `max_message_count`: maximum number of messages to return.
        2. `max_time_sec`: maximum time in seconds to wait for messages.
    - `messages_from_latest`:
        reset to end of the chat history and try consuming messages.
        Optional parameters same as above.
    - `messages_from_last_consumed`:
        continuing from the last consumed message, similar to `messages`.
        Optional parameters same as above.

    `max_message_count` and `max_time_sec` are used to avoid blocking indefinitely
    when retrieving messages. As a result, the method to retrieve messages may not
    return all messages. Change `max_message_count` and `max_time_sec` to retrieve
    all history messages.
    """  # noqa: E501

    def __init__(
        self,
        session_id: str,
        bootstrap_servers: str,
        ttl_ms: int = DEFAULT_TTL_MS,
        replication_factor: int = DEFAULT_REPLICATION_FACTOR,
        partition: int = DEFAULT_PARTITION,
    ):
        """Connect to the cluster and make sure the session topic exists.

        Args:
            session_id: The ID for single chat session. It is used as Kafka topic name.
            bootstrap_servers:
                Comma-separated host/port pairs to establish connection to Kafka cluster
                https://kafka.apache.org/documentation.html#adminclientconfigs_bootstrap.servers
            ttl_ms:
                Time-to-live (milliseconds) for automatic expiration of entries.
                Default 7 days. -1 for no expiration.
                It translates to https://kafka.apache.org/documentation.html#topicconfigs_retention.ms
            replication_factor: The replication factor for the topic. Default 1.
            partition: The number of partitions for the topic. Default 3.

        Raises:
            ImportError: If the ``confluent_kafka`` package is not installed.
        """
        try:
            from confluent_kafka import Producer
            from confluent_kafka.admin import AdminClient
        except ImportError as e:
            # ModuleNotFoundError is a subclass of ImportError, so one clause
            # covers both; chain the cause so the real failure stays visible.
            raise ImportError(
                "Could not import confluent_kafka package. "
                "Please install it with `pip install confluent_kafka`."
            ) from e

        self.session_id = session_id
        self.bootstrap_servers = bootstrap_servers
        self.admin_client = AdminClient({BOOTSTRAP_SERVERS_CONFIG: bootstrap_servers})
        # An already existing topic keeps its own settings, so the stored
        # partition count may differ from the requested `partition`.
        self.num_partitions = ensure_topic_exists(
            self.admin_client, session_id, replication_factor, partition, ttl_ms
        )
        self.producer = Producer({BOOTSTRAP_SERVERS_CONFIG: bootstrap_servers})

    def add_messages(
        self,
        messages: Sequence[BaseMessage],
        flush_timeout_seconds: float = 5.0,
    ) -> None:
        """Add messages to the chat history by producing to the Kafka topic.

        Args:
            messages: Messages to append, in order.
            flush_timeout_seconds: Maximum time in seconds to wait for the
                producer to flush. If messages are still pending afterwards a
                warning is logged and the method returns.

        Raises:
            Exception: Any error raised while serializing or producing is
                logged and re-raised.
        """

        def report_delivery(err: object, _msg: object) -> None:
            # Without a delivery callback, broker-side failures would never
            # be surfaced to the application.
            if err is not None:
                logger.error(f"Failed to deliver message to Kafka: {err}")

        try:
            for message in messages:
                self.producer.produce(
                    topic=self.session_id,
                    # Kafka only orders records within a partition. Keying
                    # every record with the session id makes the default
                    # partitioner route them all to the same partition, so
                    # the history is read back in the order it was written.
                    key=self.session_id,
                    value=json.dumps(message_to_dict(message)),
                    on_delivery=report_delivery,
                )
            message_remaining = self.producer.flush(flush_timeout_seconds)
            if message_remaining > 0:
                logger.warning(f"{message_remaining} messages are still in-flight.")
        except Exception as e:
            logger.error(f"Failed to add messages to Kafka: {e}")
            raise

    def __read_messages(
        self,
        consume_start_pos: ConsumeStartPosition,
        max_message_count: Optional[int],
        max_time_sec: Optional[float],
    ) -> List[BaseMessage]:
        """Retrieve messages from Kafka topic for the session.

        Please note this method is stateful. Internally, it uses Kafka consumer
        to consume messages, and maintains the consumed offset.

        Reading stops as soon as either limit is reached. If both limits are
        ``None`` the loop has no exit condition and only ends on an exception.

        Args:
            consume_start_pos: Start position for Kafka consumer.
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.

        Raises:
            Exception: Any error raised while consuming or decoding is logged
                and re-raised; the consumer is closed in every case.
        """
        from confluent_kafka import OFFSET_BEGINNING, OFFSET_END, Consumer

        consumer_config = {
            BOOTSTRAP_SERVERS_CONFIG: self.bootstrap_servers,
            # The session id doubles as the consumer group id.
            "group.id": self.session_id,
            "auto.offset.reset": (
                "latest"
                if consume_start_pos == ConsumeStartPosition.LATEST
                else "earliest"
            ),
        }

        # Start positions that override the group's stored offset on assignment.
        # LAST_CONSUMED is absent on purpose: it keeps the stored offset.
        forced_offsets = {
            ConsumeStartPosition.EARLIEST: OFFSET_BEGINNING,
            ConsumeStartPosition.LATEST: OFFSET_END,
        }
        forced_offset = forced_offsets.get(consume_start_pos)

        def seek_on_assign(
            assigned_consumer: Consumer, assigned_partitions: list[TopicPartition]
        ) -> None:
            # Point every assigned partition at the forced offset.
            for p in assigned_partitions:
                p.offset = forced_offset
            assigned_consumer.assign(assigned_partitions)

        messages: List[dict] = []
        consumer = Consumer(consumer_config)
        try:
            if forced_offset is None:
                consumer.subscribe([self.session_id])
            else:
                consumer.subscribe([self.session_id], on_assign=seek_on_assign)

            # A monotonic clock keeps the time budget immune to wall-clock jumps.
            deadline = (
                None if max_time_sec is None else time.monotonic() + max_time_sec
            )
            while True:
                if max_message_count is not None and len(messages) >= max_message_count:
                    break

                poll_timeout = 1.0
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # Never block in poll() past the caller's time limit.
                    poll_timeout = min(poll_timeout, remaining)

                message = consumer.poll(timeout=poll_timeout)
                if message is None:  # poll timeout
                    continue
                if message.error() is not None:  # error
                    logger.error(f"Consumer error: {message.error()}")
                    continue
                if message.value() is None:  # empty value
                    logger.warning("Empty message value")
                    continue
                messages.append(json.loads(message.value()))
        except Exception as e:
            logger.error(f"Failed to consume messages from Kafka: {e}")
            raise
        finally:
            consumer.close()

        return messages_from_dict(messages)

    def messages_from_beginning(
        self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
    ) -> List[BaseMessage]:
        """Retrieve messages from Kafka topic from the beginning.

        This method resets the consumer to the beginning and consumes messages.

        Args:
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.
        """
        return self.__read_messages(
            consume_start_pos=ConsumeStartPosition.EARLIEST,
            max_message_count=max_message_count,
            max_time_sec=max_time_sec,
        )

    def messages_from_latest(
        self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
    ) -> List[BaseMessage]:
        """Reset to the end offset. Try to consume messages if available.

        Args:
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.
        """
        return self.__read_messages(
            consume_start_pos=ConsumeStartPosition.LATEST,
            max_message_count=max_message_count,
            max_time_sec=max_time_sec,
        )

    def messages_from_last_consumed(
        self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
    ) -> List[BaseMessage]:
        """Retrieve messages from Kafka topic from the last consumed message.

        Please note this method is stateful. Internally, it uses Kafka consumer
        to consume messages, and maintains the commit offset.

        Args:
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.
        """
        return self.__read_messages(
            consume_start_pos=ConsumeStartPosition.LAST_CONSUMED,
            max_message_count=max_message_count,
            max_time_sec=max_time_sec,
        )

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """
        Retrieve the messages for the session, from Kafka topic continuously
        from last consumed message. This method is stateful and maintains
        consumed(committed) offset based on consumer group.
        Alternatively, use messages_from_last_consumed() with specified parameters.
        Use messages_from_beginning() to read from the earliest message.
        Use messages_from_latest() to read from the latest message.
        """
        return self.messages_from_last_consumed()

    def clear(self) -> None:
        """Clear the chat history by deleting the Kafka topic.

        Raises:
            Exception: Any error raised while deleting the topic is logged
                and re-raised.
        """
        try:
            futures = self.admin_client.delete_topics([self.session_id])
            for f in futures.values():
                f.result()  # result is None
                logger.info(f"Topic {self.session_id} deleted")
        except Exception as e:
            logger.error(f"Failed to delete topic {self.session_id}: {e}")
            raise

    def close(self) -> None:
        """Release the resources.

        Nothing to be released at this moment.
        """
        pass