Source code for langchain.chains.hyde.base

"""Hypothetical Document Embeddings.

https://arxiv.org/abs/2212.10496
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import numpy as np
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable
from pydantic import ConfigDict

from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain


class HypotheticalDocumentEmbedder(Chain, Embeddings):
    """Generate a hypothetical document for a query, then embed that document.

    Based on https://arxiv.org/abs/2212.10496
    """

    base_embeddings: Embeddings
    llm_chain: Runnable

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Input keys for HyDE's LLM chain."""
        return self.llm_chain.input_schema.model_json_schema()["required"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for HyDE's LLM chain."""
        if isinstance(self.llm_chain, LLMChain):
            return self.llm_chain.output_keys
        else:
            return ["text"]

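    # A minimal sketch of what the two properties above return, assuming
    # llm_chain was built from a prompt whose single input variable is named
    # "QUESTION" (an illustrative name, not guaranteed by this module):
    #
    #     embedder.input_keys   # -> ["QUESTION"]
    #     embedder.output_keys  # -> ["text"] for a prompt | llm | parser chain
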
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call the base embeddings."""
        return self.base_embeddings.embed_documents(texts)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine embeddings into a single embedding via element-wise mean."""
        return list(np.array(embeddings).mean(axis=0))

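    # A minimal worked example of the mean pooling above, with made-up
    # 3-dimensional vectors:
    #
    #     embedder.combine_embeddings([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
    #     # -> [2.0, 3.0, 4.0], the element-wise mean along axis 0
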
    def embed_query(self, text: str) -> List[float]:
        """Generate a hypothetical document and embed it."""
        var_name = self.input_keys[0]
        result = self.llm_chain.invoke({var_name: text})
        if isinstance(self.llm_chain, LLMChain):
            documents = [result[self.output_keys[0]]]
        else:
            documents = [result]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

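    # A minimal sketch of the embed_query flow (the query string is
    # illustrative):
    #
    #     vector = embedder.embed_query("How does HyDE work?")
    #     # 1. llm_chain.invoke(...) drafts a hypothetical answer document.
    #     # 2. embed_documents([...]) embeds that draft with base_embeddings.
    #     # 3. combine_embeddings(...) mean-pools the result into one vector.
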
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Call the internal llm chain."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.llm_chain.invoke(
            inputs, config={"callbacks": _run_manager.get_child()}
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        base_embeddings: Embeddings,
        prompt_key: Optional[str] = None,
        custom_prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> HypotheticalDocumentEmbedder:
        """Construct the LLM chain from either a known prompt key or a custom prompt."""
        if custom_prompt is not None:
            prompt = custom_prompt
        elif prompt_key is not None and prompt_key in PROMPT_MAP:
            prompt = PROMPT_MAP[prompt_key]
        else:
            raise ValueError(
                f"Must specify prompt_key if custom_prompt is not provided. Should be "
                f"one of {list(PROMPT_MAP.keys())}."
            )
        llm_chain = prompt | llm | StrOutputParser()
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

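    # A minimal construction sketch, assuming a custom prompt; `llm` and
    # `embeddings` stand in for any BaseLanguageModel / Embeddings pair:
    #
    #     from langchain_core.prompts import PromptTemplate
    #
    #     prompt = PromptTemplate.from_template(
    #         "Write a short passage that answers the question: {question}"
    #     )
    #     embedder = HypotheticalDocumentEmbedder.from_llm(
    #         llm=llm,
    #         base_embeddings=embeddings,
    #         custom_prompt=prompt,
    #     )
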
    @property
    def _chain_type(self) -> str:
        return "hyde_chain"
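
# A minimal end-to-end sketch, assuming OpenAI-backed models from the
# langchain-openai package and that "web_search" is one of the PROMPT_MAP keys:
#
#     from langchain_openai import ChatOpenAI, OpenAIEmbeddings
#
#     embedder = HypotheticalDocumentEmbedder.from_llm(
#         llm=ChatOpenAI(),
#         base_embeddings=OpenAIEmbeddings(),
#         prompt_key="web_search",
#     )
#     query_vector = embedder.embed_query("How does HyDE improve retrieval?")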