Source code for langchain.retrievers.document_compressors.cross_encoder_rerank

from __future__ import annotations

import operator
from typing import Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from pydantic import ConfigDict

from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder


class CrossEncoderReranker(BaseDocumentCompressor):
    """Document compressor that uses CrossEncoder for reranking."""

    model: BaseCrossEncoder
    """CrossEncoder model to use for scoring similarity
      between the query and documents."""
    top_n: int = 3
    """Number of documents to return."""

    # Allow the non-pydantic cross-encoder type as a field and reject unknown
    # constructor keyword arguments.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents using CrossEncoder.

        Each document's ``page_content`` is paired with ``query`` and scored by
        ``self.model``. Documents are ordered by descending score and the first
        ``top_n`` are returned. Documents with equal scores keep their input
        order, because the sort is stable. A non-positive ``top_n``, or an empty
        ``documents``, yields an empty list without calling the model.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            A sequence of compressed documents (at most ``top_n``), highest
            score first.

        Raises:
            ValueError: If the model returns a different number of scores than
                the number of documents it was asked to score.
        """
        # Nothing to rank, or nothing requested: skip the model call. Some
        # scoring backends may not accept an empty batch. A negative top_n
        # would otherwise slice as "all but the last k" documents.
        if not documents or self.top_n <= 0:
            return []

        pairs = [(query, doc.page_content) for doc in documents]
        # Materialize the scores: BaseCrossEncoder implementations may return an
        # array or a lazy iterable rather than a list. len() is needed below.
        scores = list(self.model.score(pairs))
        if len(scores) != len(documents):
            # zip() would silently drop the unmatched documents; fail loudly.
            raise ValueError(
                f"Expected {len(documents)} scores from the cross encoder, "
                f"got {len(scores)}."
            )

        # sorted() is stable even with reverse=True, so ties keep input order.
        ranked = sorted(
            zip(documents, scores),
            key=operator.itemgetter(1),
            reverse=True,
        )
        return [doc for doc, _ in ranked[: self.top_n]]