# Source code for langchain_community.storage.cassandra

from __future__ import annotations

import asyncio
from asyncio import InvalidStateError, Task
from typing import (
    TYPE_CHECKING,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
)

from langchain_core.stores import ByteStore

from langchain_community.utilities.cassandra import SetupMode, aexecute_cql

if TYPE_CHECKING:
    from cassandra.cluster import Session
    from cassandra.query import PreparedStatement

# CQL statement templates used by CassandraByteStore.
# ``{keyspace}`` and ``{table}`` are substituted with ``str.format`` before the
# statement is executed or prepared; ``?`` is a prepared-statement bind marker.
# The table layout is one row per key: ``row_id`` (TEXT key) -> ``body_blob``.
CREATE_TABLE_CQL_TEMPLATE = (
    "\n"
    "    CREATE TABLE IF NOT EXISTS {keyspace}.{table} \n"
    "    (row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));\n"
)
SELECT_TABLE_CQL_TEMPLATE = (
    "SELECT row_id, body_blob FROM  {keyspace}.{table} "
    "WHERE row_id IN ?;"
)
SELECT_ALL_TABLE_CQL_TEMPLATE = (
    "SELECT row_id, body_blob FROM  {keyspace}.{table};"
)
INSERT_TABLE_CQL_TEMPLATE = (
    "INSERT INTO {keyspace}.{table} (row_id, body_blob) "
    "VALUES (?, ?);"
)
DELETE_TABLE_CQL_TEMPLATE = (
    "DELETE FROM {keyspace}.{table} "
    "WHERE row_id IN ?;"
)


class CassandraByteStore(ByteStore):
    """A ByteStore implementation using Cassandra as the backend.

    Each key is stored in the ``row_id`` TEXT column and its value in the
    ``body_blob`` BLOB column of ``{keyspace}.{table}``. The table is created
    (``CREATE TABLE IF NOT EXISTS``) when the store is constructed.

    Parameters:
        table: The name of the table to use.
        session: A Cassandra session object. If not provided, it will be
            resolved from the cassio config.
        keyspace: The keyspace to use. If not provided, it will be resolved
            from the cassio config.
        setup_mode: The setup mode to use. Default is SYNC (SetupMode.SYNC).
            With ``SetupMode.ASYNC`` the table creation is scheduled with
            ``asyncio.create_task``, so the constructor must then be called
            while an event loop is running.
    """

    def __init__(
        self,
        table: str,
        *,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        if not session or not keyspace:
            # Fall back to the global cassio configuration for whichever of
            # session/keyspace was not supplied explicitly.
            try:
                from cassio.config import check_resolve_keyspace, check_resolve_session
            except ImportError as err:  # also covers ModuleNotFoundError
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from err
            self.keyspace = keyspace or check_resolve_keyspace(keyspace)
            self.session = session or check_resolve_session()
        else:
            self.keyspace = keyspace
            self.session = session
        self.table = table

        # Prepared statements are created lazily on first use and then cached.
        self.select_statement: Optional[PreparedStatement] = None
        self.insert_statement: Optional[PreparedStatement] = None
        self.delete_statement: Optional[PreparedStatement] = None

        create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace,
            table=self.table,
        )
        # Only set in ASYNC mode; the (a)ensure_db_setup methods gate on it.
        self.db_setup_task: Optional[Task[None]] = None
        if setup_mode == SetupMode.ASYNC:
            self.db_setup_task = asyncio.create_task(
                aexecute_cql(self.session, create_cql)
            )
        else:
            self.session.execute(create_cql)

    def ensure_db_setup(self) -> None:
        """Ensure that the DB setup is finished. If not, raise a ValueError.

        Does nothing when no asynchronous setup task was scheduled. If the
        setup task finished with an exception, ``Task.result()`` re-raises it.

        Raises:
            ValueError: If the asynchronous setup task has not completed yet.
        """
        if self.db_setup_task is not None:
            try:
                self.db_setup_task.result()
            except InvalidStateError as err:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: Cassandra components sync methods shouldn't be called from "
                    "the event loop. Consider using their async equivalents."
                ) from err

    async def aensure_db_setup(self) -> None:
        """Ensure that the DB setup is finished. If not, wait for it."""
        if self.db_setup_task is not None:
            await self.db_setup_task

    def _prepare(self, template: str) -> PreparedStatement:
        """Prepare ``template`` formatted with this store's keyspace and table."""
        return self.session.prepare(
            template.format(keyspace=self.keyspace, table=self.table)
        )

    def get_select_statement(self) -> PreparedStatement:
        """Get the prepared select statement for the table.

        If not available, prepare it.

        Returns:
            PreparedStatement: The prepared statement.
        """
        if self.select_statement is None:
            self.select_statement = self._prepare(SELECT_TABLE_CQL_TEMPLATE)
        return self.select_statement

    def get_insert_statement(self) -> PreparedStatement:
        """Get the prepared insert statement for the table.

        If not available, prepare it.

        Returns:
            PreparedStatement: The prepared statement.
        """
        if self.insert_statement is None:
            self.insert_statement = self._prepare(INSERT_TABLE_CQL_TEMPLATE)
        return self.insert_statement

    def get_delete_statement(self) -> PreparedStatement:
        """Get the prepared delete statement for the table.

        If not available, prepare it.

        Returns:
            PreparedStatement: The prepared statement.
        """
        if self.delete_statement is None:
            self.delete_statement = self._prepare(DELETE_TABLE_CQL_TEMPLATE)
        return self.delete_statement

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Get the values for ``keys``.

        Returns:
            One entry per key, in the order of ``keys``; ``None`` for keys
            that are not present in the table.
        """
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Nothing to look up: skip the round trip.
            return []
        docs_dict = {
            row.row_id: row.body_blob
            for row in self.session.execute(
                self.get_select_statement(), [ValueSequence(keys)]
            )
        }
        return [docs_dict.get(key) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Async version of :meth:`mget`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            return []
        rows = await aexecute_cql(
            self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
        )
        docs_dict = {row.row_id: row.body_blob for row in rows}
        return [docs_dict.get(key) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Store each ``(key, value)`` pair, one INSERT per pair.

        Existing keys are overwritten (a CQL ``INSERT`` is an upsert).
        """
        self.ensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            self.session.execute(insert_statement, (k, v))

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Async version of :meth:`mset`; inserts are awaited one at a time."""
        await self.aensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            await aexecute_cql(self.session, insert_statement, parameters=(k, v))

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the rows for ``keys`` with a single ``IN`` delete."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            return
        self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async version of :meth:`mdelete`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            return
        await aexecute_cql(
            self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
        )

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the table, optionally only those starting with ``prefix``.

        The prefix filter is applied client-side over a full-table SELECT.
        """
        self.ensure_db_setup()
        select_all_cql = SELECT_ALL_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace, table=self.table
        )
        for row in self.session.execute(select_all_cql):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async version of :meth:`yield_keys`."""
        await self.aensure_db_setup()
        select_all_cql = SELECT_ALL_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace, table=self.table
        )
        for row in await aexecute_cql(self.session, select_all_cql):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key