Source code for langchain.chains.elasticsearch_database.base

"""Chain for interacting with Elasticsearch Database."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable

from langchain.chains.base import Chain
from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


class ElasticsearchDatabaseChain(Chain):
    """Chain for interacting with Elasticsearch Database.

    The chain lists the indices it is allowed to use, describes them (mappings
    plus a few sample documents) to a query-generation chain, executes the
    generated query against Elasticsearch, and passes the raw search result to
    an answer chain.

    Example:
        .. code-block:: python

            from langchain.chains import ElasticsearchDatabaseChain
            from langchain_community.llms import OpenAI
            from elasticsearch import Elasticsearch

            database = Elasticsearch("http://localhost:9200")
            db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database)
    """

    query_chain: Runnable
    """Chain for creating the ES query."""
    answer_chain: Runnable
    """Chain for answering the user question."""
    database: Any
    """Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
    top_k: int = 10
    """Number of results to return from the query"""
    ignore_indices: Optional[List[str]] = None
    """Index names to exclude. Cannot be combined with ``include_indices``."""
    include_indices: Optional[List[str]] = None
    """Index names to restrict to. Cannot be combined with ``ignore_indices``."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    sample_documents_in_index_info: int = 3
    """Number of sample documents per index appended to the mapping description.
    Values of 0 or less disable sampling."""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""

    class Config:
        # `database` is an arbitrary client object; unknown fields are rejected.
        arbitrary_types_allowed = True
        extra = "forbid"

    @root_validator(pre=False, skip_on_failure=True)
    def validate_indices(cls, values: dict) -> dict:
        """Reject configurations that set both index filters at once.

        Raises:
            ValueError: If both ``include_indices`` and ``ignore_indices`` are
                non-empty.
        """
        if values["include_indices"] and values["ignore_indices"]:
            raise ValueError(
                "Cannot specify both 'include_indices' and 'ignore_indices'."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if self.return_intermediate_steps:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]
        return [self.output_key]

    def _list_indices(self) -> List[str]:
        """Return the index names the chain may use.

        Names come from ``database.cat.indices`` and are then narrowed by
        ``include_indices`` and ``ignore_indices`` when those are set.
        """
        all_indices = [
            index["index"] for index in self.database.cat.indices(format="json")
        ]

        if self.include_indices:
            all_indices = [i for i in all_indices if i in self.include_indices]
        if self.ignore_indices:
            all_indices = [i for i in all_indices if i not in self.ignore_indices]

        return all_indices

    def _get_indices_infos(self, indices: List[str]) -> str:
        """Build the textual description of ``indices`` given to the query chain.

        Args:
            indices: Index names to describe.

        Returns:
            One "Mapping for index ..." section per index, separated by blank
            lines. When ``sample_documents_in_index_info`` is positive, each
            section also carries that many sample documents inside a
            ``/* ... */`` block. An empty string when ``indices`` is empty.
        """
        if not indices:
            # An empty index string makes the Elasticsearch client target every
            # index, which would bypass include_indices / ignore_indices.
            return ""

        mappings = self.database.indices.get_mapping(index=",".join(indices))
        if self.sample_documents_in_index_info > 0:
            for k, v in mappings.items():
                sample_hits = self.database.search(
                    index=k,
                    query={"match_all": {}},
                    size=self.sample_documents_in_index_info,
                )["hits"]["hits"]
                samples = "\n".join(str(hit["_source"]) for hit in sample_hits)
                # The per-index entry is overwritten with mapping text + samples.
                mappings[k]["mappings"] = str(v) + "\n\n/*\n" + samples + "\n*/"
        return "\n\n".join(
            f"Mapping for index {index}:\n{mappings[index]['mappings']}"
            for index in mappings
        )

    def _search(self, indices: List[str], query: str) -> str:
        """Run ``query`` as the search body against ``indices``.

        Args:
            indices: Index names to search.
            query: Search request body produced by the query chain.

        Returns:
            The search response converted with ``str``.

        Raises:
            ValueError: If ``indices`` is empty; an empty index string would
                make the client search every index.
        """
        if not indices:
            raise ValueError("Cannot run an Elasticsearch search without indices.")
        result = self.database.search(index=",".join(indices), body=query)
        return str(result)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate an ES query for the question, run it, and answer from the result.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Optional callback manager for this run.

        Returns:
            A dict with the answer under ``self.output_key`` and, when
            ``return_intermediate_steps`` is set, the recorded steps under
            ``INTERMEDIATE_STEPS_KEY``.

        Raises:
            ValueError: If no index remains after applying ``include_indices``
                / ``ignore_indices``.
            Exception: Any error raised while generating, searching or
                answering is re-raised with an ``intermediate_steps`` attribute
                attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nESQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)

        indices = self._list_indices()
        if not indices:
            # Fail before any LLM call: continuing would describe and search
            # every index in the cluster instead of the configured subset.
            raise ValueError(
                "No Elasticsearch indices available to query; check "
                "'include_indices' and 'ignore_indices'."
            )
        indices_info = self._get_indices_infos(indices)
        query_inputs: dict = {
            "input": input_text,
            "top_k": str(self.top_k),
            "indices_info": indices_info,
            "stop": ["\nESResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(query_inputs)  # input: es generation
            es_cmd = self.query_chain.invoke(
                query_inputs,
                config={"callbacks": _run_manager.get_child()},
            )

            _run_manager.on_text(es_cmd, color="green", verbose=self.verbose)
            intermediate_steps.append(
                es_cmd
            )  # output: elasticsearch dsl generation (no checker)
            intermediate_steps.append({"es_cmd": es_cmd})  # input: ES search
            result = self._search(indices=indices, query=es_cmd)
            intermediate_steps.append(str(result))  # output: ES search

            _run_manager.on_text("\nESResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)

            _run_manager.on_text("\nAnswer:", verbose=self.verbose)
            answer_inputs: dict = {"data": result, "input": input_text}
            intermediate_steps.append(answer_inputs)  # input: final answer
            final_result = self.answer_chain.invoke(
                answer_inputs,
                config={"callbacks": _run_manager.get_child()},
            )

            intermediate_steps.append(final_result)  # output: final answer
            _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "elasticsearch_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        database: Elasticsearch,
        *,
        query_prompt: Optional[BasePromptTemplate] = None,
        answer_prompt: Optional[BasePromptTemplate] = None,
        query_output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> ElasticsearchDatabaseChain:
        """Convenience method to construct ElasticsearchDatabaseChain from an LLM.

        Args:
            llm: The language model to use.
            database: The Elasticsearch db.
            query_prompt: The prompt to use for query construction.
                Defaults to DSL_PROMPT.
            answer_prompt: The prompt to use for answering user question given data.
                Defaults to ANSWER_PROMPT.
            query_output_parser: The output parser to use for parsing model-generated
                ES query. Defaults to SimpleJsonOutputParser.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A chain whose query and answer chains are both built on ``llm``.
        """
        query_prompt = query_prompt or DSL_PROMPT
        query_output_parser = query_output_parser or SimpleJsonOutputParser()
        query_chain = query_prompt | llm | query_output_parser
        answer_prompt = answer_prompt or ANSWER_PROMPT
        answer_chain = answer_prompt | llm | StrOutputParser()
        return cls(
            query_chain=query_chain,
            answer_chain=answer_chain,
            database=database,
            **kwargs,
        )