Source code for langchain_community.chat_message_histories.sql
importcontextlibimportjsonimportloggingfromabcimportABC,abstractmethodfromtypingimport(Any,AsyncGenerator,Dict,Generator,List,Optional,Sequence,Union,cast,)fromlangchain_core._apiimportdeprecated,warn_deprecatedfromsqlalchemyimportColumn,Integer,Text,delete,selecttry:fromsqlalchemy.ormimportdeclarative_baseexceptImportError:fromsqlalchemy.ext.declarativeimportdeclarative_basefromlangchain_core.chat_historyimportBaseChatMessageHistoryfromlangchain_core.messagesimport(BaseMessage,message_to_dict,messages_from_dict,)fromsqlalchemyimportcreate_enginefromsqlalchemy.engineimportEnginefromsqlalchemy.ext.asyncioimport(AsyncEngine,AsyncSession,create_async_engine,)fromsqlalchemy.ormimport(SessionasSQLSession,)fromsqlalchemy.ormimport(declarative_base,scoped_session,sessionmaker,)try:fromsqlalchemy.ext.asyncioimportasync_sessionmakerexceptImportError:# dummy for sqlalchemy < 2async_sessionmaker=type("async_sessionmaker",(type,),{})# type: ignorelogger=logging.getLogger(__name__)
class BaseMessageConverter(ABC):
    """Abstract bridge between ``BaseMessage`` objects and SQLAlchemy rows.

    A concrete converter decides which SQLAlchemy model holds the chat
    messages and how a single message is turned into a row of that model
    and back again.
    """

    @abstractmethod
    def get_sql_model_class(self) -> Any:
        """Return the SQLAlchemy model class used to store messages."""
        raise NotImplementedError

    @abstractmethod
    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
        """Build a SQLAlchemy model instance holding *message* for *session_id*."""
        raise NotImplementedError

    @abstractmethod
    def from_sql_model(self, sql_message: Any) -> BaseMessage:
        """Rebuild a ``BaseMessage`` from a SQLAlchemy model instance."""
        raise NotImplementedError
def create_message_model(table_name: str, DynamicBase: Any) -> Any:
    """Build a declarative ``Message`` model mapped onto *table_name*.

    The model has three columns:
    - ``id``: integer primary key
    - ``session_id``: text
    - ``message``: text

    Args:
        table_name: Name of the database table the model maps to.
        DynamicBase: Declarative base class the model inherits from.

    Returns:
        The freshly created model class.
    """

    # The class statement executes on every call, which is what lets the
    # table name be chosen at runtime instead of being fixed at import time.
    class Message(DynamicBase):  # type: ignore[valid-type, misc]
        __tablename__ = table_name
        id = Column(Integer, primary_key=True)
        session_id = Column(Text)
        message = Column(Text)

    model_cls = Message
    return model_cls
class DefaultMessageConverter(BaseMessageConverter):
    """The default message converter for SQLChatMessageHistory.

    Each message is stored as a row of a model built by
    ``create_message_model``. The ``message`` column holds the JSON encoding
    of ``message_to_dict(message)``.
    """

    def __init__(self, table_name: str):
        """Create the converter and its backing model.

        Args:
            table_name: Name of the table the generated model maps to.
        """
        # A fresh declarative base per converter keeps each model's metadata
        # independent, so ``metadata.create_all`` only touches this table.
        self.model_class = create_message_model(table_name, declarative_base())

    def from_sql_model(self, sql_message: Any) -> BaseMessage:
        """Decode the JSON payload of a stored row back into a ``BaseMessage``."""
        return messages_from_dict([json.loads(sql_message.message)])[0]

    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
        """Encode *message* as JSON inside a new model instance for *session_id*."""
        return self.model_class(
            session_id=session_id, message=json.dumps(message_to_dict(message))
        )

    def get_sql_model_class(self) -> Any:
        """Return the model class generated for this converter's table."""
        return self.model_class
class SQLChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in an SQL database.

    Example:
        .. code-block:: python

            from langchain_core.messages import HumanMessage

            from langchain_community.chat_message_histories import SQLChatMessageHistory

            # create sync sql message history from a connection string
            message_history = SQLChatMessageHistory(
                session_id='foo', connection='sqlite:///memory.db'
            )
            message_history.add_message(HumanMessage("hello"))
            message_history.messages

            # create async sql message history using aiosqlite
            # from sqlalchemy.ext.asyncio import create_async_engine
            #
            # async_engine = create_async_engine("sqlite+aiosqlite:///memory.db")
            # async_message_history = SQLChatMessageHistory(
            #     session_id='foo', connection=async_engine,
            # )
            # await async_message_history.aadd_message(HumanMessage("hello"))
            # await async_message_history.aget_messages()

    """

    @property
    @deprecated("0.2.2", removal="1.0", alternative="session_maker")
    def Session(self) -> Union[scoped_session, async_sessionmaker]:
        """Deprecated alias for ``session_maker``."""
        return self.session_maker

    def __init__(
        self,
        session_id: str,
        connection_string: Optional[str] = None,
        table_name: str = "message_store",
        session_id_field_name: str = "session_id",
        custom_message_converter: Optional[BaseMessageConverter] = None,
        connection: Union[None, DBConnection] = None,
        engine_args: Optional[Dict[str, Any]] = None,
        async_mode: Optional[bool] = None,  # Use only if connection is a string
    ):
        """Initialize with a SQLChatMessageHistory instance.

        Args:
            session_id: Indicates the id of the same session.
            connection_string: Deprecated; string parameter configuration for
                connecting to the database. Use ``connection`` instead.
            table_name: Table name used to save data. Ignored when
                ``custom_message_converter`` is given.
            session_id_field_name: The name of field of `session_id`.
            custom_message_converter: Custom message converter for converting
                database data and `BaseMessage`.
            connection: Database connection object, which can be a string
                containing connection configuration, Engine object or
                AsyncEngine object.
            engine_args: Additional configuration for creating database
                engines. Only used when the connection is a string.
            async_mode: Whether it is an asynchronous connection. Only used
                when the connection is a string.

        Raises:
            ValueError: If ``connection`` has an unsupported type, or the
                model class lacks the ``session_id_field_name`` attribute.
        """
        global _warned_once_already

        assert not (connection_string and connection), (
            "connection_string and connection are mutually exclusive"
        )
        if connection_string:
            if not _warned_once_already:
                warn_deprecated(
                    since="0.2.2",
                    removal="1.0",
                    name="connection_string",
                    alternative="connection",
                )
                _warned_once_already = True
            connection = connection_string
        # Always define the attribute so readers never hit AttributeError.
        self.connection_string = connection_string

        if isinstance(connection, str):
            self.async_mode = bool(async_mode)
            if self.async_mode:
                self.async_engine = create_async_engine(
                    connection, **(engine_args or {})
                )
            else:
                self.engine = create_engine(url=connection, **(engine_args or {}))
        elif isinstance(connection, Engine):
            self.async_mode = False
            self.engine = connection
        elif isinstance(connection, AsyncEngine):
            self.async_mode = True
            self.async_engine = connection
        else:
            raise ValueError(
                "connection should be a connection string or an instance of "
                "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
            )

        # To be consistent with others SQL implementations, rename to session_maker
        self.session_maker: Union[scoped_session, async_sessionmaker]
        if self.async_mode:
            self.session_maker = async_sessionmaker(bind=self.async_engine)
        else:
            self.session_maker = scoped_session(sessionmaker(bind=self.engine))

        self.session_id_field_name = session_id_field_name
        self.converter = custom_message_converter or DefaultMessageConverter(table_name)
        self.sql_model_class = self.converter.get_sql_model_class()
        if not hasattr(self.sql_model_class, session_id_field_name):
            raise ValueError("SQL model class must have session_id column")
        self._table_created = False
        # Sync engines can create the table eagerly; async engines defer it to
        # the first awaited call because __init__ cannot await.
        if not self.async_mode:
            self._create_table_if_not_exists()

        self.session_id = session_id

    def _create_table_if_not_exists(self) -> None:
        """Create the message table on the sync engine if it is missing."""
        self.sql_model_class.metadata.create_all(self.engine)
        self._table_created = True

    async def _acreate_table_if_not_exists(self) -> None:
        """Create the message table on the async engine once per instance."""
        if not self._table_created:
            assert self.async_mode, "This method must be called with async_mode"
            async with self.async_engine.begin() as conn:
                await conn.run_sync(self.sql_model_class.metadata.create_all)
            self._table_created = True

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve all messages for this session from the database.

        Messages are ordered by ascending ``id``.

        Raises:
            ValueError: If this history was created in async mode.
        """
        with self._make_sync_session() as session:
            result = (
                session.query(self.sql_model_class)
                .where(
                    getattr(self.sql_model_class, self.session_id_field_name)
                    == self.session_id
                )
                .order_by(self.sql_model_class.id.asc())
            )
            return [self.converter.from_sql_model(record) for record in result]

    async def aget_messages(self) -> List[BaseMessage]:
        """Retrieve all messages for this session from the database (async).

        Messages are ordered by ascending ``id``.

        Raises:
            ValueError: If this history was created in sync mode.
        """
        await self._acreate_table_if_not_exists()
        async with self._make_async_session() as session:
            stmt = (
                select(self.sql_model_class)
                .where(
                    getattr(self.sql_model_class, self.session_id_field_name)
                    == self.session_id
                )
                .order_by(self.sql_model_class.id.asc())
            )
            result = await session.execute(stmt)
            return [
                self.converter.from_sql_model(record) for record in result.scalars()
            ]

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the record in db.

        Raises:
            ValueError: If this history was created in async mode.
        """
        with self._make_sync_session() as session:
            session.add(self.converter.to_sql_model(message, self.session_id))
            session.commit()

    async def aadd_message(self, message: BaseMessage) -> None:
        """Add a Message object to the store.

        Args:
            message: A BaseMessage object to store.

        Raises:
            ValueError: If this history was created in sync mode.
        """
        await self._acreate_table_if_not_exists()
        async with self._make_async_session() as session:
            session.add(self.converter.to_sql_model(message, self.session_id))
            await session.commit()

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Append several messages in a single transaction.

        Raises:
            ValueError: If this history was created in async mode.
        """
        with self._make_sync_session() as session:
            session.add_all(
                [self.converter.to_sql_model(m, self.session_id) for m in messages]
            )
            session.commit()

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Append several messages in a single transaction (async).

        Raises:
            ValueError: If this history was created in sync mode.
        """
        await self._acreate_table_if_not_exists()
        # Go through _make_async_session like every other async method so a
        # sync-mode instance gets a clear ValueError rather than a failure
        # from `async with` on a sync Session.
        async with self._make_async_session() as session:
            session.add_all(
                [self.converter.to_sql_model(m, self.session_id) for m in messages]
            )
            await session.commit()

    def clear(self) -> None:
        """Delete every message of this session from the database.

        Raises:
            ValueError: If this history was created in async mode.
        """
        with self._make_sync_session() as session:
            session.query(self.sql_model_class).filter(
                getattr(self.sql_model_class, self.session_id_field_name)
                == self.session_id
            ).delete()
            session.commit()

    async def aclear(self) -> None:
        """Delete every message of this session from the database (async).

        Raises:
            ValueError: If this history was created in sync mode.
        """
        await self._acreate_table_if_not_exists()
        async with self._make_async_session() as session:
            stmt = delete(self.sql_model_class).filter(
                getattr(self.sql_model_class, self.session_id_field_name)
                == self.session_id
            )
            await session.execute(stmt)
            await session.commit()

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[SQLSession, None, None]:
        """Make a sync session.

        Raises:
            ValueError: If this history was created in async mode.
        """
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with self.session_maker() as session:
            yield cast(SQLSession, session)

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session.

        Raises:
            ValueError: If this history was created in sync mode.
        """
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method when sync mode is turned on. "
                "Please use the corresponding sync method instead."
            )
        async with self.session_maker() as session:
            yield cast(AsyncSession, session)