import logging
from typing import List

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable

logger = logging.getLogger(__name__)

# Instruction text sent to the LLM. It has a single placeholder, `{question}`,
# which receives the raw user query. Adjacent literals are concatenated by the
# parser, so the value is one line of text with no embedded newlines.
DEFAULT_TEMPLATE = (
    "You are an assistant tasked with taking a natural language "
    "query from a user and converting it into a query for a vectorstore. "
    "In this process, you strip out information that is not relevant for "
    "the retrieval task. Here is the user query: {question}"
)

# Prompt object built from DEFAULT_TEMPLATE; used as the default `prompt`
# argument of `RePhraseQueryRetriever.from_llm`.
DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(DEFAULT_TEMPLATE)
class RePhraseQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to re-phrase it.

    The re-phrased query is then passed to the wrapped retriever, and the
    documents it returns are handed back to the caller.
    """

    # Retriever that is queried with the re-phrased question.
    retriever: BaseRetriever
    # Runnable that maps the raw user query to the re-phrased question.
    llm_chain: Runnable

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLLM,
        prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
    ) -> "RePhraseQueryRetriever":
        """Initialize from an LLM, using the default template unless overridden.

        The prompt used here expects a single input: `question`.

        Args:
            retriever: retriever to query documents from
            llm: llm used for query generation
            prompt: prompt template for query generation
                (defaults to DEFAULT_QUERY_PROMPT)

        Returns:
            RePhraseQueryRetriever
        """
        # prompt -> llm -> plain string: the chain's output is fed directly
        # to the retriever as its query.
        llm_chain = prompt | llm | StrOutputParser()
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
        )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user question.

        Args:
            query: user question
            run_manager: callback manager for this retriever run; a child
                manager is passed to each nested call

        Returns:
            Relevant documents for the re-phrased question
        """
        re_phrased_question = self.llm_chain.invoke(
            query, config={"callbacks": run_manager.get_child()}
        )
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Re-phrased question: %s", re_phrased_question)
        return self.retriever.invoke(
            re_phrased_question, config={"callbacks": run_manager.get_child()}
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get relevant documents given a user question.

        Mirrors `_get_relevant_documents`, awaiting the `ainvoke` variants of
        the re-phrasing chain and the wrapped retriever.

        Args:
            query: user question
            run_manager: async callback manager for this retriever run; a
                child manager is passed to each nested call

        Returns:
            Relevant documents for the re-phrased question
        """
        re_phrased_question = await self.llm_chain.ainvoke(
            query, config={"callbacks": run_manager.get_child()}
        )
        logger.info("Re-phrased question: %s", re_phrased_question)
        return await self.retriever.ainvoke(
            re_phrased_question, config={"callbacks": run_manager.get_child()}
        )