Source code for langchain_ibm.rerank

"""IBM watsonx.ai rerank wrapper."""

from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Any

from ibm_watsonx_ai import APIClient  # type: ignore[import-untyped]
from ibm_watsonx_ai.foundation_models import Rerank  # type: ignore[import-untyped]
from ibm_watsonx_ai.foundation_models.schema import (  # type: ignore[import-untyped]
    RerankParameters,  # noqa: TC002
)
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils.utils import secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import extract_params, resolve_watsonx_credentials

if TYPE_CHECKING:
    from collections.abc import Sequence

    from langchain_core.callbacks import Callbacks


class WatsonxRerank(BaseDocumentCompressor):
    """Document compressor that uses `watsonx Rerank API`.

    ???+ info "Setup"

        To use, you should have `langchain_ibm` python package installed,
        and the environment variable `WATSONX_APIKEY` set with your API key, or pass
        it as a named parameter `apikey` to the constructor.

        ```bash
        pip install -U langchain-ibm

        # or using uv
        uv add langchain-ibm
        ```

        ```bash
        export WATSONX_APIKEY="your-api-key"
        ```

    ??? info "Instantiate"

        ```python
        from langchain_ibm import WatsonxRerank
        from ibm_watsonx_ai.foundation_models.schema import RerankParameters

        parameters = RerankParameters(truncate_input_tokens=20)

        ranker = WatsonxRerank(
            model_id="cross-encoder/ms-marco-minilm-l-12-v2",
            url="https://us-south.ml.cloud.ibm.com",
            project_id="*****",
            params=parameters,
            # apikey="*****"
        )
        ```

    ??? info "Rerank"

        ```python
        query = "red cat chasing a laser pointer"
        documents = [
            "A red cat darts across the living room, pouncing on a red laser dot.",
            "Two dogs play fetch in the park with a tennis ball.",
            "The tabby cat naps on a sunny windowsill all afternoon.",
            "A recipe for tuna casserole with crispy breadcrumbs.",
        ]

        ranker.rerank(documents=documents, query=query)
        ```

        ```python
        [
            {"index": 0, "relevance_score": 0.8719543218612671},
            {"index": 2, "relevance_score": 0.6520894169807434},
            {"index": 1, "relevance_score": 0.6270776391029358},
            {"index": 3, "relevance_score": 0.4607713520526886},
        ]
        ```
    """

    model_id: str
    """Type of model to use."""

    project_id: str | None = None
    """ID of the Watson Studio project."""

    space_id: str | None = None
    """ID of the Watson Studio space."""

    url: SecretStr = Field(
        alias="url",
        default_factory=secret_from_env("WATSONX_URL", default=None),  # type: ignore[assignment]
    )
    """URL to the Watson Machine Learning or CPD instance."""

    apikey: SecretStr | None = Field(
        alias="apikey",
        default_factory=secret_from_env("WATSONX_APIKEY", default=None),
    )
    """API key to the Watson Machine Learning or CPD instance."""

    token: SecretStr | None = Field(
        alias="token",
        default_factory=secret_from_env("WATSONX_TOKEN", default=None),
    )
    """Token to the CPD instance."""

    password: SecretStr | None = Field(
        alias="password",
        default_factory=secret_from_env("WATSONX_PASSWORD", default=None),
    )
    """Password to the CPD instance."""

    username: SecretStr | None = Field(
        alias="username",
        default_factory=secret_from_env("WATSONX_USERNAME", default=None),
    )
    """Username to the CPD instance."""

    instance_id: SecretStr | None = Field(
        alias="instance_id",
        default_factory=secret_from_env("WATSONX_INSTANCE_ID", default=None),
    )
    """Instance_id of the CPD instance."""

    version: SecretStr | None = None
    """Version of the CPD instance."""

    params: dict | RerankParameters | None = None
    """Model parameters to use during request generation."""

    verify: str | bool | None = None
    """You can pass one of following as verify:
        * the path to a CA_BUNDLE file
        * the path of directory with certificates of trusted CAs
        * True - default path to truststore will be taken
        * False - no verification will be made"""

    validate_model: bool = True
    """Model ID validation."""

    streaming: bool = False
    """Whether to stream the results or not."""

    watsonx_rerank: Rerank = Field(default=None, exclude=True)  #: :meta private:

    watsonx_client: APIClient | None = Field(default=None, exclude=True)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        protected_namespaces=(),
    )

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Map each secret field name to the environment variable backing it."""
        secret_env_vars = {
            "url": "WATSONX_URL",
            "apikey": "WATSONX_APIKEY",
            "token": "WATSONX_TOKEN",
            "password": "WATSONX_PASSWORD",
            "username": "WATSONX_USERNAME",
            "instance_id": "WATSONX_INSTANCE_ID",
        }
        return secret_env_vars

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Build the underlying `Rerank` object once the fields are validated.

        A caller-supplied `watsonx_client` is handed to `Rerank` as
        `api_client`; otherwise credentials are resolved from the connection
        fields and handed over as `credentials`.

        Returns:
            This instance, with `watsonx_rerank` populated.
        """
        # Only the way of authenticating differs between the two modes; every
        # other argument passed to `Rerank` is shared.
        connection: dict[str, Any]
        if isinstance(self.watsonx_client, APIClient):
            connection = {"api_client": self.watsonx_client}
        else:
            connection = {
                "credentials": resolve_watsonx_credentials(
                    url=self.url,
                    apikey=self.apikey,
                    token=self.token,
                    password=self.password,
                    username=self.username,
                    instance_id=self.instance_id,
                    version=self.version,
                    verify=self.verify,
                ),
            }

        self.watsonx_rerank = Rerank(
            model_id=self.model_id,
            params=self.params,
            project_id=self.project_id,
            space_id=self.space_id,
            verify=self.verify,
            **connection,
        )
        return self

    def rerank(
        self,
        documents: Sequence[str | Document | dict],
        query: str,
        **kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Score `documents` against `query` with the watsonx rerank model.

        Args:
            documents: Items to rank. `Document` instances are reduced to
                their `page_content`; any other item is forwarded unchanged.
            query: The query the documents are ranked against.
            kwargs: Extra keyword arguments forwarded to `Rerank.generate`.
                They are first passed through `extract_params` together with
                the instance-level `params`.

        Returns:
            One dict per entry of the response's `"results"` list, holding
            the `"index"` and `"relevance_score"` of that entry.
        """
        if len(documents) == 0:
            # Nothing to rank, so skip the API round-trip entirely.
            return []

        inputs = []
        for item in documents:
            if isinstance(item, Document):
                inputs.append(item.page_content)
            else:
                inputs.append(item)

        merged_params = extract_params(kwargs, self.params)
        call_kwargs = {**kwargs, "params": merged_params}
        response = self.watsonx_rerank.generate(
            query=query,
            inputs=inputs,
            **call_kwargs,
        )

        ranking = []
        for entry in response["results"]:
            ranking.append(
                {
                    "index": entry.get("index"),
                    "relevance_score": entry.get("score"),
                },
            )
        return ranking

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
        **kwargs: Any,
    ) -> Sequence[Document]:
        """Compress documents using watsonx's rerank API.

        Args:
            documents: A sequence of documents to compress
            query: The query to use for compressing the documents
            callbacks: Callbacks to run during the compression process
            kwargs: Additional keyword args

        Returns:
            A sequence of compressed documents
        """
        ranking = self.rerank(documents, query, **kwargs)

        reranked = []
        for hit in ranking:
            source = documents[hit["index"]]
            # Deep-copy the metadata so adding the score never touches the
            # caller's original document.
            scored = Document(source.page_content, metadata=deepcopy(source.metadata))
            scored.metadata["relevance_score"] = hit["relevance_score"]
            reranked.append(scored)
        return reranked