# Source code for langchain.retrievers.document_compressors.cross_encoder_rerank
from __future__ import annotations
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from pydantic import ConfigDict
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
class CrossEncoderReranker(BaseDocumentCompressor):
    """Document compressor that uses CrossEncoder for reranking."""

    model: BaseCrossEncoder
    """CrossEncoder model to use for scoring similarity
      between the query and documents."""
    top_n: int = 3
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank documents using CrossEncoder.

        Each document's ``page_content`` is paired with ``query`` and scored
        by ``self.model``. The ``self.top_n`` highest-scoring documents are
        returned, best first. Documents with equal scores keep their input
        order, because ``sorted`` is stable.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Not used by this implementation.

        Returns:
            A list of at most ``top_n`` documents, ordered by descending
            cross-encoder score. Empty if ``documents`` is empty.

        Raises:
            ValueError: If the model returns a different number of scores
                than there are documents.
        """
        # Nothing to rank: skip the model call entirely, since a
        # cross-encoder backend may not accept an empty batch.
        if not documents:
            return []

        scores = list(
            self.model.score([(query, doc.page_content) for doc in documents])
        )
        # zip() would silently drop documents if the model returned too few
        # scores (or ignore extras if it returned too many); fail loudly.
        if len(scores) != len(documents):
            raise ValueError(
                f"Cross-encoder returned {len(scores)} scores for "
                f"{len(documents)} documents; expected one score per document."
            )

        docs_with_scores = list(zip(documents, scores))
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
        return [doc for doc, _ in result[: self.top_n]]