Source code for langchain.chains.qa_with_sources.retrieval
"""Question-answering with sources over an index."""fromtypingimportAny,Dict,Listfromlangchain_core.callbacksimport(AsyncCallbackManagerForChainRun,CallbackManagerForChainRun,)fromlangchain_core.documentsimportDocumentfromlangchain_core.retrieversimportBaseRetrieverfrompydanticimportFieldfromlangchain.chains.combine_documents.stuffimportStuffDocumentsChainfromlangchain.chains.qa_with_sources.baseimportBaseQAWithSourcesChain
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
    """Question-answering with sources over an index.

    Documents are fetched from ``retriever`` for the question found under
    ``question_key`` in the chain inputs, optionally trimmed to fit a token
    budget, and then returned for the base chain to use.
    """

    retriever: BaseRetriever = Field(exclude=True)
    """Index to connect to."""
    reduce_k_below_max_tokens: bool = False
    """Reduce the number of results to return from store based on tokens limit"""
    max_tokens_limit: int = 3375
    """Restrict the docs to return from store based on tokens,
    enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is
    set to true"""

    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        """Drop trailing documents until the total token count fits the limit.

        Trimming only happens when ``reduce_k_below_max_tokens`` is true and
        ``combine_documents_chain`` is a ``StuffDocumentsChain``; otherwise
        ``docs`` is returned in full.

        Args:
            docs: Retrieved documents, in the order the retriever returned them.

        Returns:
            A prefix of ``docs`` whose summed token count does not exceed
            ``max_tokens_limit`` (possibly empty).
        """
        num_docs = len(docs)

        if self.reduce_k_below_max_tokens and isinstance(
            self.combine_documents_chain, StuffDocumentsChain
        ):
            tokens = [
                self.combine_documents_chain.llm_chain._get_num_tokens(
                    doc.page_content
                )
                for doc in docs
            ]
            token_count = sum(tokens)
            # Stop once every document has been dropped. Without the lower
            # bound, a negative ``max_tokens_limit`` would drive ``num_docs``
            # below zero, which either raises IndexError (empty ``docs``) or
            # makes the final slice ``docs[:-1]`` return almost everything.
            while num_docs > 0 and token_count > self.max_tokens_limit:
                num_docs -= 1
                token_count -= tokens[num_docs]

        return docs[:num_docs]

    def _get_docs(
        self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
    ) -> List[Document]:
        """Retrieve documents for the question in ``inputs``.

        Args:
            inputs: Chain inputs; the question is read from
                ``inputs[self.question_key]``.
            run_manager: Callback manager whose child is passed to the
                retriever as its ``callbacks`` config.

        Returns:
            The retrieved documents after token-limit trimming.
        """
        question = inputs[self.question_key]
        docs = self.retriever.invoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(
        self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
    ) -> List[Document]:
        """Asynchronously retrieve documents for the question in ``inputs``.

        Args:
            inputs: Chain inputs; the question is read from
                ``inputs[self.question_key]``.
            run_manager: Async callback manager whose child is passed to the
                retriever as its ``callbacks`` config.

        Returns:
            The retrieved documents after token-limit trimming.
        """
        question = inputs[self.question_key]
        docs = await self.retriever.ainvoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retrieval_qa_with_sources_chain"