Source code for langchain.chains.retrieval_qa.base
"""Chain for question-answering against a vector database."""from__future__importannotationsimportinspectimportwarningsfromabcimportabstractmethodfromtypingimportAny,Dict,List,Optionalfromlangchain_core._apiimportdeprecatedfromlangchain_core.callbacksimport(AsyncCallbackManagerForChainRun,CallbackManagerForChainRun,Callbacks,)fromlangchain_core.documentsimportDocumentfromlangchain_core.language_modelsimportBaseLanguageModelfromlangchain_core.promptsimportPromptTemplatefromlangchain_core.retrieversimportBaseRetrieverfromlangchain_core.vectorstoresimportVectorStorefrompydanticimportConfigDict,Field,model_validatorfromlangchain.chains.baseimportChainfromlangchain.chains.combine_documents.baseimportBaseCombineDocumentsChainfromlangchain.chains.combine_documents.stuffimportStuffDocumentsChainfromlangchain.chains.llmimportLLMChainfromlangchain.chains.question_answeringimportload_qa_chainfromlangchain.chains.question_answering.stuff_promptimportPROMPT_SELECTOR
@deprecated(
    since="0.2.13",
    removal="1.0",
    message=(
        "This class is deprecated. Use the `create_retrieval_chain` constructor "
        "instead. See migration guide here: "
        "https://python.langchain.com/docs/versions/migrating_chains/retrieval_qa/"
    ),
)
class BaseRetrievalQA(Chain):
    """Base class for question-answering chains.

    Subclasses supply document retrieval by implementing ``_get_docs`` and
    ``_aget_docs``. This base class passes the retrieved documents plus the
    question to ``combine_documents_chain`` and returns its answer under
    ``output_key``. Optionally, the documents are also returned under
    ``"source_documents"``.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents or not."""

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """The single input key this chain reads (``input_key``).

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys: ``output_key``, plus ``source_documents`` if enabled.

        :meta private:
        """
        if self.return_source_documents:
            return [self.output_key, "source_documents"]
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        callbacks: Callbacks = None,
        llm_chain_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Initialize from LLM.

        Wraps ``llm`` in an ``LLMChain`` and feeds it through a
        ``StuffDocumentsChain``. If ``prompt`` is not given, one is picked by
        ``PROMPT_SELECTOR`` for this model. Each document is rendered with a
        ``Context:`` prefix and placed into the prompt's ``context`` variable.
        Any remaining ``kwargs`` are forwarded to the class constructor.
        """
        qa_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        extra_llm_kwargs = llm_chain_kwargs or {}
        answer_chain = LLMChain(
            llm=llm,
            prompt=qa_prompt,
            callbacks=callbacks,
            **extra_llm_kwargs,
        )
        per_doc_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="Context:\n{page_content}",
        )
        stuff_chain = StuffDocumentsChain(
            llm_chain=answer_chain,
            document_variable_name="context",
            document_prompt=per_doc_prompt,
            callbacks=callbacks,
        )
        return cls(
            combine_documents_chain=stuff_chain,
            callbacks=callbacks,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Load chain from chain type.

        ``chain_type`` and ``chain_type_kwargs`` are handed to
        ``load_qa_chain``. Remaining ``kwargs`` go to the class constructor.
        """
        qa_chain = load_qa_chain(
            llm, chain_type=chain_type, **(chain_type_kwargs or {})
        )
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get documents to do question answering over."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents for the query and answer it with them.

        When ``return_source_documents`` is True, the retrieved documents are
        also returned under the ``'source_documents'`` key.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        query = inputs[self.input_key]

        # Subclasses written against older signatures may not accept
        # ``run_manager``; only pass it when the override declares it.
        if "run_manager" in inspect.signature(self._get_docs).parameters:
            docs = self._get_docs(query, run_manager=manager)
        else:
            docs = self._get_docs(query)  # type: ignore[call-arg]

        result = self.combine_documents_chain.run(
            input_documents=docs, question=query, callbacks=manager.get_child()
        )
        outputs: Dict[str, Any] = {self.output_key: result}
        if self.return_source_documents:
            outputs["source_documents"] = docs
        return outputs

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get documents to do question answering over."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async variant of ``_call``: retrieve documents, then answer.

        When ``return_source_documents`` is True, the retrieved documents are
        also returned under the ``'source_documents'`` key.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        query = inputs[self.input_key]

        # Same backwards-compatibility probe as the sync path.
        if "run_manager" in inspect.signature(self._aget_docs).parameters:
            docs = await self._aget_docs(query, run_manager=manager)
        else:
            docs = await self._aget_docs(query)  # type: ignore[call-arg]

        result = await self.combine_documents_chain.arun(
            input_documents=docs, question=query, callbacks=manager.get_child()
        )
        outputs: Dict[str, Any] = {self.output_key: result}
        if self.return_source_documents:
            outputs["source_documents"] = docs
        return outputs
@deprecated(
    since="0.1.17",
    removal="1.0",
    message=(
        "This class is deprecated. Use the `create_retrieval_chain` constructor "
        "instead. See migration guide here: "
        "https://python.langchain.com/docs/versions/migrating_chains/retrieval_qa/"
    ),
)
class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.

    Documents are fetched from ``retriever`` and then answered over by the
    combine-documents chain inherited from ``BaseRetrievalQA``.

    This class is deprecated. An equivalent pipeline built with
    ``create_retrieval_chain`` looks like this:

    .. code-block:: python

        from langchain.chains import create_retrieval_chain
        from langchain.chains.combine_documents import create_stuff_documents_chain
        from langchain_core.prompts import ChatPromptTemplate
        from langchain_openai import ChatOpenAI

        retriever = ...  # Your retriever
        llm = ChatOpenAI()

        system_prompt = (
            "Use the given context to answer the question. "
            "If you don't know the answer, say you don't know. "
            "Use three sentence maximum and keep the answer concise. "
            "Context: {context}"
        )
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
            ]
        )
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        chain = create_retrieval_chain(retriever, question_answer_chain)

        chain.invoke({"input": query})

    Example:
        .. code-block:: python

            from langchain_community.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain_community.vectorstores import FAISS
            from langchain_core.vectorstores import VectorStoreRetriever

            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs from the retriever, forwarding child callbacks."""
        retrieval_config = {"callbacks": run_manager.get_child()}
        return self.retriever.invoke(question, config=retrieval_config)

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs from the retriever asynchronously, forwarding child callbacks."""
        retrieval_config = {"callbacks": run_manager.get_child()}
        return await self.retriever.ainvoke(question, config=retrieval_config)

    @property
    def _chain_type(self) -> str:
        """Identifier for this chain type."""
        return "retrieval_qa"
@deprecated(
    since="0.2.13",
    removal="1.0",
    message=(
        "This class is deprecated. Use the `create_retrieval_chain` constructor "
        "instead. See migration guide here: "
        "https://python.langchain.com/docs/versions/migrating_chains/retrieval_qa/"
    ),
)
class VectorDBQA(BaseRetrievalQA):
    """Chain for question-answering against a vector database.

    Documents are fetched straight from ``vectorstore``. The search is either a
    similarity search or a max-marginal-relevance search, chosen by
    ``search_type``. ``k`` and ``search_kwargs`` are forwarded to it. Both the
    sync and async entry points are supported.
    """

    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
    """Vector Database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    search_type: str = "similarity"
    """Search type to use over vectorstore. `similarity` or `mmr`."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search args."""

    @model_validator(mode="before")
    @classmethod
    def raise_deprecation(cls, values: Dict) -> Any:
        """Warn on construction that ``RetrievalQA`` should be used instead."""
        warnings.warn(
            "`VectorDBQA` is deprecated - "
            "please use `from langchain.chains import RetrievalQA`"
        )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_search_type(cls, values: Dict) -> Any:
        """Validate search type.

        Raises:
            ValueError: If ``search_type`` is given and is neither
                ``"similarity"`` nor ``"mmr"``.
        """
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in ("similarity", "mmr"):
                raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs.

        Raises:
            ValueError: If ``search_type`` was changed after construction to an
                unsupported value (no assignment validation is configured).
        """
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(
                question, k=self.k, **self.search_kwargs
            )
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs asynchronously.

        Mirrors ``_get_docs`` but uses the vector store's async search methods.
        Previously this always raised ``NotImplementedError``, so ``ainvoke``
        could never succeed.

        Raises:
            ValueError: If ``search_type`` is not ``"similarity"`` or ``"mmr"``.
        """
        if self.search_type == "similarity":
            docs = await self.vectorstore.asimilarity_search(
                question, k=self.k, **self.search_kwargs
            )
        elif self.search_type == "mmr":
            docs = await self.vectorstore.amax_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "vector_db_qa"