Source code for langchain.retrievers.contextual_compression
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from pydantic import ConfigDict
from typing_extensions import override
[docs]
class ContextualCompressionRetriever(BaseRetriever):
    """Retriever that wraps a base retriever and compresses the results.

    The query is first sent to ``base_retriever``. Any documents that come
    back are handed, together with the query, to ``base_compressor``, and the
    compressor's output is returned as a list. When the base retriever yields
    nothing, the compressor is not called and an empty list is returned.
    """

    base_compressor: BaseDocumentCompressor
    """Compressor applied to the documents returned by the base retriever."""

    base_retriever: RetrieverLike
    """Underlying retriever that is queried for candidate documents."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @override
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Retrieve documents for ``query`` and compress them synchronously.

        Args:
            query: Query string passed to both the base retriever and the
                compressor.
            run_manager: Callback manager for this run; a separate child
                manager is requested for the retrieval step and for the
                compression step.
            **kwargs: Extra keyword arguments forwarded to the base
                retriever's ``invoke``.

        Returns:
            The compressor's output as a list, or an empty list when the base
            retriever returned no documents.
        """
        retrieval_config = {"callbacks": run_manager.get_child()}
        retrieved = self.base_retriever.invoke(
            query,
            config=retrieval_config,
            **kwargs,
        )
        if not retrieved:
            # Nothing to compress: skip the compressor entirely.
            return []

        # A second child manager is requested only once compression runs.
        compression_callbacks = run_manager.get_child()
        compressed = self.base_compressor.compress_documents(
            retrieved,
            query,
            callbacks=compression_callbacks,
        )
        return list(compressed)

    @override
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Retrieve documents for ``query`` and compress them asynchronously.

        Args:
            query: Query string passed to both the base retriever and the
                compressor.
            run_manager: Async callback manager for this run; a separate
                child manager is requested for the retrieval step and for the
                compression step.
            **kwargs: Extra keyword arguments forwarded to the base
                retriever's ``ainvoke``.

        Returns:
            The compressor's output as a list, or an empty list when the base
            retriever returned no documents.
        """
        retrieval_config = {"callbacks": run_manager.get_child()}
        retrieved = await self.base_retriever.ainvoke(
            query,
            config=retrieval_config,
            **kwargs,
        )
        if not retrieved:
            # Nothing to compress: skip the compressor entirely.
            return []

        # A second child manager is requested only once compression runs.
        compression_callbacks = run_manager.get_child()
        compressed = await self.base_compressor.acompress_documents(
            retrieved,
            query,
            callbacks=compression_callbacks,
        )
        return list(compressed)