Source code for langchain.chains.qa_with_sources.vector_db
"""Question-answering with sources over a vector database."""importwarningsfromtypingimportAny,Dict,Listfromlangchain_core.callbacksimport(AsyncCallbackManagerForChainRun,CallbackManagerForChainRun,)fromlangchain_core.documentsimportDocumentfromlangchain_core.vectorstoresimportVectorStorefrompydanticimportField,model_validatorfromlangchain.chains.combine_documents.stuffimportStuffDocumentsChainfromlangchain.chains.qa_with_sources.baseimportBaseQAWithSourcesChain
class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
    """Question-answering with sources over a vector database.

    Deprecated: constructing this chain emits a warning that points users to
    ``RetrievalQAWithSourcesChain`` (see ``raise_deprecation``).
    """

    vectorstore: VectorStore = Field(exclude=True)
    """Vector Database to connect to."""
    k: int = 4
    """Number of results to return from store"""
    reduce_k_below_max_tokens: bool = False
    """Reduce the number of results to return from store based on tokens limit"""
    max_tokens_limit: int = 3375
    """Restrict the docs to return from store based on tokens,
    enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is
    set to true"""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search args."""

    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        """Drop trailing documents until the total token count fits the limit.

        Trimming only happens when ``reduce_k_below_max_tokens`` is set and the
        combine-documents chain is a ``StuffDocumentsChain``; otherwise ``docs``
        is returned in full.

        Args:
            docs: Documents in the order returned by the vector store.

        Returns:
            A prefix of ``docs`` whose summed token count does not exceed
            ``max_tokens_limit`` (possibly empty).
        """
        num_docs = len(docs)

        if self.reduce_k_below_max_tokens and isinstance(
            self.combine_documents_chain, StuffDocumentsChain
        ):
            tokens = [
                self.combine_documents_chain.llm_chain._get_num_tokens(
                    doc.page_content
                )
                for doc in docs
            ]
            token_count = sum(tokens)
            # The `num_docs > 0` guard keeps the index from going negative when
            # the limit cannot be satisfied even with zero documents (e.g. a
            # negative ``max_tokens_limit``); without it ``tokens[num_docs]``
            # would wrap around and the final slice would be wrong.
            while num_docs > 0 and token_count > self.max_tokens_limit:
                num_docs -= 1
                token_count -= tokens[num_docs]

        return docs[:num_docs]

    def _get_docs(
        self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
    ) -> List[Document]:
        """Fetch documents for the question via similarity search.

        Reads the question from ``inputs[self.question_key]``, queries the
        vector store for ``k`` results (forwarding ``search_kwargs``), and
        applies the token-limit trimming to the result.
        """
        question = inputs[self.question_key]
        docs = self.vectorstore.similarity_search(
            question, k=self.k, **self.search_kwargs
        )
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(
        self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
    ) -> List[Document]:
        """Async document retrieval is not supported by this chain.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")

    @model_validator(mode="before")
    @classmethod
    def raise_deprecation(cls, values: Dict) -> Any:
        """Emit a deprecation warning and pass the input values through unchanged."""
        warnings.warn(
            "`VectorDBQAWithSourcesChain` is deprecated - "
            "please use `from langchain.chains import RetrievalQAWithSourcesChain`"
        )
        return values

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "vector_db_qa_with_sources_chain"