# Source code for langchain_elasticsearch.retrievers

import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Union,
    cast,
)

from elasticsearch import Elasticsearch
from langchain_core.callbacks import (
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client

logger = logging.getLogger(__name__)


class ElasticsearchRetriever(BaseRetriever):
    """Elasticsearch retriever.

    Args:
        es_client: Elasticsearch client connection. Alternatively, you can use
            the `from_es_params` method with parameters to initialize the client.
        index_name: The name of the index to query. Can also be a list of names.
        body_func: Function to create an Elasticsearch DSL query body from a
            search string. The returned query body must match what you would
            normally send in a POST request to the _search endpoint. If
            applicable, it also includes parameters such as `size`.
        content_field: The document field name that contains the page content.
            If multiple indices are queried, specify a dict
            {index_name: field_name} here.
        document_mapper: Function to map Elasticsearch hits to LangChain
            Documents.

    Exactly one of ``content_field`` and ``document_mapper`` must be provided;
    when ``content_field`` is given, a default mapper is derived from it.
    """

    # Marks that ``_get_relevant_documents`` accepts extra keyword arguments
    # (it forwards them to ``body_func``). NOTE(review): this flag is read by
    # langchain_core's BaseRetriever; confirm semantics there.
    _expects_other_args = True

    es_client: Elasticsearch
    index_name: Union[str, Sequence[str]]
    body_func: Callable[[str], Dict]
    content_field: Optional[Union[str, Mapping[str, str]]] = None
    document_mapper: Optional[Callable[[Mapping], Document]] = None

    def __init__(self, **kwargs: Any) -> None:
        """Validate the configuration and set up the document mapper.

        Raises:
            ValueError: If neither or both of ``content_field`` and
                ``document_mapper`` are given, or if ``content_field`` is
                neither a string nor a mapping.
        """
        super().__init__(**kwargs)

        if self.content_field is None and self.document_mapper is None:
            raise ValueError("One of content_field or document_mapper must be defined.")
        if self.content_field is not None and self.document_mapper is not None:
            raise ValueError(
                "Both content_field and document_mapper are defined. "
                "Please provide only one."
            )

        if not self.document_mapper:
            # Derive a default mapper from the shape of content_field:
            # a single field name, or a per-index {index_name: field_name} map.
            if isinstance(self.content_field, str):
                self.document_mapper = self._single_field_mapper
            elif isinstance(self.content_field, Mapping):
                self.document_mapper = self._multi_field_mapper
            else:
                raise ValueError(
                    "unknown type for content_field, expected string or dict."
                )

        self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")

    @staticmethod
    def from_es_params(
        index_name: Union[str, Sequence[str]],
        body_func: Callable[[str], Dict],
        content_field: Optional[Union[str, Mapping[str, str]]] = None,
        document_mapper: Optional[Callable[[Mapping], Document]] = None,
        url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> "ElasticsearchRetriever":
        """Build a retriever, creating the Elasticsearch client from parameters.

        The connection arguments (``url``, ``cloud_id``, ``api_key``,
        ``username``, ``password``, ``params``) are passed to
        ``create_elasticsearch_client``; the remaining arguments are passed
        to the ``ElasticsearchRetriever`` constructor.

        Raises:
            Exception: Whatever ``create_elasticsearch_client`` raises is
                logged and re-raised unchanged.
        """
        try:
            client = create_elasticsearch_client(
                url=url,
                cloud_id=cloud_id,
                api_key=api_key,
                username=username,
                password=password,
                params=params,
            )
        except Exception as err:
            logger.error("Error connecting to Elasticsearch: %s", err)
            # Bare raise re-raises the original exception untouched.
            raise

        return ElasticsearchRetriever(
            es_client=client,
            index_name=index_name,
            body_func=body_func,
            content_field=content_field,
            document_mapper=document_mapper,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Run ``body_func(query, **kwargs)`` against the index and map each hit.

        Raises:
            ValueError: If the client or document mapper is missing (should
                not happen after ``__init__`` validation).
        """
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        body = self.body_func(query, **kwargs)
        results = self.es_client.search(index=self.index_name, body=body)
        return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

    def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document:
        """Map a hit using the single ``content_field`` as page content.

        The content field is popped from ``hit["_source"]`` (mutating the hit
        in place) so it is not duplicated in the Document metadata.
        """
        content = hit["_source"].pop(self.content_field)
        return Document(page_content=content, metadata=hit)

    def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document:
        """Map a hit using the content field configured for its ``_index``.

        The content field is popped from ``hit["_source"]`` (mutating the hit
        in place). A KeyError is raised if the hit's index has no entry in
        ``content_field``.
        """
        # Narrow the type locally instead of reassigning model state per hit.
        field_by_index = cast(Mapping[str, str], self.content_field)
        field = field_by_index[hit["_index"]]
        content = hit["_source"].pop(field)
        return Document(page_content=content, metadata=hit)