Source code for langchain_community.chat_message_histories.kafka

"""Kafka-based chat message history by using confluent-kafka-python.
confluent-kafka-python is under Apache 2.0 license.
https://github.com/confluentinc/confluent-kafka-python
"""

from __future__ import annotations

import json
import logging
import time
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict

if TYPE_CHECKING:
    from confluent_kafka import TopicPartition
    from confluent_kafka.admin import AdminClient

logger = logging.getLogger(__name__)

BOOTSTRAP_SERVERS_CONFIG = "bootstrap.servers"

DEFAULT_TTL_MS = 604800000  # 7 days
DEFAULT_REPLICATION_FACTOR = 1
DEFAULT_PARTITION = 3


class ConsumeStartPosition(Enum):
    """Consume start position for Kafka consumer to get chat history messages.

    LAST_CONSUMED: Continue from the last consumed offset.
    EARLIEST: Start consuming from the beginning.
    LATEST: Start consuming from the latest offset.
    """

    LAST_CONSUMED = 1
    EARLIEST = 2
    LATEST = 3
def ensure_topic_exists(
    admin_client: AdminClient,
    topic_name: str,
    replication_factor: int,
    partition: int,
    ttl_ms: int,
) -> int:
    """Create topic if it doesn't exist, and return the number of partitions.

    If the topic already exists, we don't change the topic configuration.

    Args:
        admin_client: Admin client used to list and create topics.
        topic_name: Name of the topic to look up or create.
        replication_factor: Replication factor used only when creating the topic.
        partition: Number of partitions used only when creating the topic.
        ttl_ms: Value written to the topic's ``retention.ms`` config on creation.

    Returns:
        The partition count of the existing topic, or ``partition`` when the
        topic was just created.

    Raises:
        Exception: Whatever the admin client raises while listing or creating
            topics; the error is logged and re-raised unchanged.
    """
    from confluent_kafka.admin import NewTopic

    try:
        topic_metadata = admin_client.list_topics().topics
    except Exception as e:
        logger.error("Failed to list topics: %s", e)
        raise

    if topic_name in topic_metadata:
        # Existing topic: report its real partition count, leave config alone.
        num_partitions = len(topic_metadata[topic_name].partitions)
        logger.info(
            "Topic %s already exists with %s partitions", topic_name, num_partitions
        )
        return num_partitions

    topics = [
        NewTopic(
            topic_name,
            num_partitions=partition,
            replication_factor=replication_factor,
            config={"retention.ms": str(ttl_ms)},
        )
    ]
    try:
        futures = admin_client.create_topics(topics)
        # Block on every per-topic future so a failed creation surfaces here.
        for future in futures.values():
            future.result()  # result is None
        logger.info("Topic %s created", topic_name)
    except Exception as e:
        logger.error("Failed to create topic %s: %s", topic_name, e)
        raise

    return partition
class KafkaChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in Kafka.

    Setup:
        Install ``confluent-kafka-python``.

        .. code-block:: bash

            pip install confluent_kafka

    Instantiate:
        .. code-block:: python

            from langchain_community.chat_message_histories import KafkaChatMessageHistory

            history = KafkaChatMessageHistory(
                session_id="your_session_id",
                bootstrap_servers="host:port",
            )

    Add and retrieve messages:
        .. code-block:: python

            # Add messages
            history.add_messages([message1, message2, message3, ...])

            # Retrieve messages
            message_batch_0 = history.messages

            # retrieve messages after message_batch_0
            message_batch_1 = history.messages

            # Reset to beginning and retrieve messages
            messages_from_beginning = history.messages_from_beginning()

    Retrieving messages is stateful. Internally, it uses Kafka consumer to read.
    The consumed offset is maintained persistently.

    To retrieve messages, you can use the following methods:
    - `messages`:
        continue consuming chat messages from last one.
    - `messages_from_beginning`:
        reset the consumer to the beginning of the chat history and return messages.
        Optional parameters:
        1. `max_message_count`: maximum number of messages to return.
        2. `max_time_sec`: maximum time in seconds to wait for messages.
    - `messages_from_latest`:
        reset to end of the chat history and try consuming messages.
        Optional parameters same as above.
    - `messages_from_last_consumed`:
        continuing from the last consumed message, similar to `messages`.
        Optional parameters same as above.

    `max_message_count` and `max_time_sec` are used to avoid blocking indefinitely
    when retrieving messages. As a result, the method to retrieve messages may not
    return all messages. Change `max_message_count` and `max_time_sec` to retrieve
    all history messages.
    """  # noqa: E501

    def __init__(
        self,
        session_id: str,
        bootstrap_servers: str,
        ttl_ms: int = DEFAULT_TTL_MS,
        replication_factor: int = DEFAULT_REPLICATION_FACTOR,
        partition: int = DEFAULT_PARTITION,
    ):
        """
        Args:
            session_id: The ID for single chat session. It is used as Kafka topic name.
            bootstrap_servers:
                Comma-separated host/port pairs to establish connection to Kafka cluster
                https://kafka.apache.org/documentation.html#adminclientconfigs_bootstrap.servers
            ttl_ms:
                Time-to-live (milliseconds) for automatic expiration of entries.
                Default 7 days. -1 for no expiration.
                It translates to https://kafka.apache.org/documentation.html#topicconfigs_retention.ms
            replication_factor: The replication factor for the topic. Default 1.
            partition: The number of partitions for the topic. Default 3.

        Raises:
            ImportError: If the ``confluent_kafka`` package is not installed.
        """
        try:
            from confluent_kafka import Producer
            from confluent_kafka.admin import AdminClient
        except ImportError as e:
            # ModuleNotFoundError is a subclass of ImportError; chain the cause
            # so the underlying import failure stays visible.
            raise ImportError(
                "Could not import confluent_kafka package. "
                "Please install it with `pip install confluent_kafka`."
            ) from e

        self.session_id = session_id
        self.bootstrap_servers = bootstrap_servers

        self.admin_client = AdminClient({BOOTSTRAP_SERVERS_CONFIG: bootstrap_servers})
        # The topic is created on first use; an existing topic keeps its config.
        self.num_partitions = ensure_topic_exists(
            self.admin_client, session_id, replication_factor, partition, ttl_ms
        )
        self.producer = Producer({BOOTSTRAP_SERVERS_CONFIG: bootstrap_servers})

    def add_messages(
        self,
        messages: Sequence[BaseMessage],
        flush_timeout_seconds: float = 5.0,
    ) -> None:
        """Add messages to the chat history by producing to the Kafka topic.

        Each message is serialized to JSON and produced, without a key, to the
        topic named after the session ID. The producer is then flushed.

        Args:
            messages: Messages to append to the history.
            flush_timeout_seconds: Maximum time in seconds to wait for the flush.
                If messages are still pending afterwards, only a warning is
                logged.

        Raises:
            Exception: Whatever the producer or serialization raises; the error
                is logged and re-raised unchanged.
        """
        # NOTE(review): no message key is set, so on a multi-partition topic the
        # read-back order across partitions is presumably not guaranteed — confirm.
        try:
            for message in messages:
                self.producer.produce(
                    topic=self.session_id,
                    value=json.dumps(message_to_dict(message)),
                )
            message_remaining = self.producer.flush(flush_timeout_seconds)
            if message_remaining > 0:
                logger.warning(
                    "%s messages are still in-flight.", message_remaining
                )
        except Exception as e:
            logger.error("Failed to add messages to Kafka: %s", e)
            raise

    def __read_messages(
        self,
        consume_start_pos: ConsumeStartPosition,
        max_message_count: Optional[int],
        max_time_sec: Optional[float],
    ) -> List[BaseMessage]:
        """Retrieve messages from Kafka topic for the session.

        Please note this method is stateful. Internally, it uses Kafka consumer
        to consume messages, and maintains the consumed offset. The consumer
        group ID is the session ID.

        Args:
            consume_start_pos: Start position for Kafka consumer.
            max_message_count: Maximum number of messages to consume.
                ``None`` disables the count limit.
            max_time_sec: Time limit in seconds to consume messages.
                ``None`` disables the time limit. With both limits ``None``
                the loop never terminates.

        Returns:
            List of messages.

        Raises:
            Exception: Whatever the consumer or JSON decoding raises; the error
                is logged and re-raised unchanged.
        """
        from confluent_kafka import OFFSET_BEGINNING, OFFSET_END, Consumer

        # NOTE(review): no explicit commit is issued in this method; offset
        # persistence presumably relies on the client's auto-commit defaults —
        # confirm.
        consumer_config = {
            BOOTSTRAP_SERVERS_CONFIG: self.bootstrap_servers,
            "group.id": self.session_id,
            "auto.offset.reset": "latest"
            if consume_start_pos == ConsumeStartPosition.LATEST
            else "earliest",
        }

        # Offset to force on partition assignment; None means "keep whatever
        # position the consumer group already has".
        forced_offset = {
            ConsumeStartPosition.EARLIEST: OFFSET_BEGINNING,
            ConsumeStartPosition.LATEST: OFFSET_END,
        }.get(consume_start_pos)

        def assign_at_forced_offset(
            assigned_consumer: Consumer, assigned_partitions: list[TopicPartition]
        ) -> None:
            for p in assigned_partitions:
                p.offset = forced_offset
            assigned_consumer.assign(assigned_partitions)

        poll_interval_sec = 1.0
        messages: List[dict] = []
        consumer = Consumer(consumer_config)
        try:
            if forced_offset is None:
                consumer.subscribe([self.session_id])
            else:
                consumer.subscribe(
                    [self.session_id], on_assign=assign_at_forced_offset
                )

            # A monotonic clock keeps the time budget immune to wall-clock jumps.
            deadline = (
                None if max_time_sec is None else time.monotonic() + max_time_sec
            )
            while True:
                if max_message_count is not None and len(messages) >= max_message_count:
                    break

                poll_timeout = poll_interval_sec
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # Never block past the caller's time budget.
                    poll_timeout = min(poll_interval_sec, remaining)

                message = consumer.poll(timeout=poll_timeout)
                if message is None:  # poll timeout
                    continue
                if message.error() is not None:  # error
                    logger.error("Consumer error: %s", message.error())
                    continue
                if message.value() is None:  # empty value
                    logger.warning("Empty message value")
                    continue
                messages.append(json.loads(message.value()))
        except Exception as e:
            logger.error("Failed to consume messages from Kafka: %s", e)
            raise
        finally:
            consumer.close()

        return messages_from_dict(messages)

    def messages_from_beginning(
        self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
    ) -> List[BaseMessage]:
        """Retrieve messages from Kafka topic from the beginning.
        This method resets the consumer to the beginning and consumes messages.

        Args:
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.
        """
        return self.__read_messages(
            consume_start_pos=ConsumeStartPosition.EARLIEST,
            max_message_count=max_message_count,
            max_time_sec=max_time_sec,
        )

    def messages_from_latest(
        self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
    ) -> List[BaseMessage]:
        """Reset to the end offset. Try to consume messages if available.

        Args:
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.
        """
        return self.__read_messages(
            consume_start_pos=ConsumeStartPosition.LATEST,
            max_message_count=max_message_count,
            max_time_sec=max_time_sec,
        )

    def messages_from_last_consumed(
        self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
    ) -> List[BaseMessage]:
        """Retrieve messages from Kafka topic from the last consumed message.
        Please note this method is stateful. Internally, it uses Kafka consumer
        to consume messages, and maintains the commit offset.

        Args:
            max_message_count: Maximum number of messages to consume.
            max_time_sec: Time limit in seconds to consume messages.

        Returns:
            List of messages.
        """
        return self.__read_messages(
            consume_start_pos=ConsumeStartPosition.LAST_CONSUMED,
            max_message_count=max_message_count,
            max_time_sec=max_time_sec,
        )

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """
        Retrieve the messages for the session, from Kafka topic continuously
        from last consumed message. This method is stateful and maintains
        consumed(committed) offset based on consumer group.
        Alternatively, use messages_from_last_consumed() with specified parameters.
        Use messages_from_beginning() to read from the earliest message.
        Use messages_from_latest() to read from the latest message.
        """
        return self.messages_from_last_consumed()

    def clear(self) -> None:
        """Clear the chat history by deleting the Kafka topic.

        Raises:
            Exception: Whatever the admin client raises while deleting the
                topic; the error is logged and re-raised unchanged.
        """
        try:
            futures = self.admin_client.delete_topics([self.session_id])
            # Block on every per-topic future so a failed deletion surfaces here.
            for future in futures.values():
                future.result()  # result is None
            logger.info("Topic %s deleted", self.session_id)
        except Exception as e:
            logger.error("Failed to delete topic %s: %s", self.session_id, e)
            raise

    def close(self) -> None:
        """Release the resources.
        Nothing to be released at this moment.
        """
        pass