[docs]classHypotheticalDocumentEmbedder(Chain,Embeddings):"""Generate hypothetical document for query, and then embed that. Based on https://arxiv.org/abs/2212.10496 """base_embeddings:Embeddingsllm_chain:Runnablemodel_config=ConfigDict(arbitrary_types_allowed=True,extra="forbid",)@propertydefinput_keys(self)->List[str]:"""Input keys for Hyde's LLM chain."""returnself.llm_chain.input_schema.model_json_schema()["required"]@propertydefoutput_keys(self)->List[str]:"""Output keys for Hyde's LLM chain."""ifisinstance(self.llm_chain,LLMChain):returnself.llm_chain.output_keyselse:return["text"]
[docs]defembed_documents(self,texts:List[str])->List[List[float]]:"""Call the base embeddings."""returnself.base_embeddings.embed_documents(texts)
[docs]defcombine_embeddings(self,embeddings:List[List[float]])->List[float]:"""Combine embeddings into final embeddings."""try:importnumpyasnpreturnlist(np.array(embeddings).mean(axis=0))exceptImportError:logger.warning("NumPy not found in the current Python environment. ""HypotheticalDocumentEmbedder will use a pure Python implementation ""for internal calculations, which may significantly impact ""performance, especially for large datasets. For optimal speed and ""efficiency, consider installing NumPy: pip install numpy")ifnotembeddings:return[]num_vectors=len(embeddings)return[sum(dim_values)/num_vectorsfordim_valuesinzip(*embeddings)]
[docs]defembed_query(self,text:str)->List[float]:"""Generate a hypothetical document and embedded it."""var_name=self.input_keys[0]result=self.llm_chain.invoke({var_name:text})ifisinstance(self.llm_chain,LLMChain):documents=[result[self.output_keys[0]]]else:documents=[result]embeddings=self.embed_documents(documents)returnself.combine_embeddings(embeddings)
def_call(self,inputs:Dict[str,Any],run_manager:Optional[CallbackManagerForChainRun]=None,)->Dict[str,str]:"""Call the internal llm chain."""_run_manager=run_managerorCallbackManagerForChainRun.get_noop_manager()returnself.llm_chain.invoke(inputs,config={"callbacks":_run_manager.get_child()})
[docs]@classmethoddeffrom_llm(cls,llm:BaseLanguageModel,base_embeddings:Embeddings,prompt_key:Optional[str]=None,custom_prompt:Optional[BasePromptTemplate]=None,**kwargs:Any,)->HypotheticalDocumentEmbedder:"""Load and use LLMChain with either a specific prompt key or custom prompt."""ifcustom_promptisnotNone:prompt=custom_promptelifprompt_keyisnotNoneandprompt_keyinPROMPT_MAP:prompt=PROMPT_MAP[prompt_key]else:raiseValueError(f"Must specify prompt_key if custom_prompt not provided. Should be one "f"of {list(PROMPT_MAP.keys())}.")llm_chain=prompt|llm|StrOutputParser()returncls(base_embeddings=base_embeddings,llm_chain=llm_chain,**kwargs)