# Source code for langchain_experimental.retrievers.vector_sql_database

"""Vector SQL Database Chain Retriever"""

from typing import Any, Dict, List

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain


class VectorSQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that uses Vector SQL Database.

    Each query is passed to ``sql_db_chain``; every entry of the chain
    output's ``"result"`` value is turned into one ``Document``.
    """

    sql_db_chain: VectorSQLDatabaseChain
    """SQL Database Chain"""
    page_content_key: str = "content"
    """column name for page content of documents"""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run the SQL chain for ``query`` and wrap each result row.

        Args:
            query: Text handed to ``sql_db_chain`` as its input.
            run_manager: Callback manager for this retriever run; its child
                manager is passed to the chain as ``callbacks``.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                chain call.

        Returns:
            One ``Document`` per row, with ``page_content`` taken from the
            row's ``page_content_key`` entry and the full row as metadata.

        Raises:
            KeyError: If the chain output has no ``"result"`` entry or a row
                lacks ``page_content_key``.
        """
        # NOTE(review): the annotation assumes the chain yields a list of
        # dict rows under "result" - confirm the chain is configured so.
        ret: List[Dict[str, Any]] = self.sql_db_chain(
            query, callbacks=run_manager.get_child(), **kwargs
        )["result"]
        # The whole row (content column included) is kept as metadata.
        return [
            Document(page_content=r[self.page_content_key], metadata=r) for r in ret
        ]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronous retrieval; not implemented for this retriever.

        ``**kwargs`` is accepted for parity with the synchronous method, so
        extra arguments reach the ``NotImplementedError`` below instead of
        failing earlier with a ``TypeError``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support async retrieval"
        )