Source code for langchain_google_community.vertex_rank
import warnings
from typing import TYPE_CHECKING, Any, Optional, Sequence
from google.api_core import exceptions as core_exceptions # type: ignore
from google.auth.credentials import Credentials # type: ignore
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from pydantic import ConfigDict, Field
from langchain_google_community._utils import get_client_info
if TYPE_CHECKING:
from google.cloud import discoveryengine_v1alpha # type: ignore
class VertexAIRank(BaseDocumentCompressor):
"""
Initializes the Vertex AI Ranker with configurable parameters.
Inherits from BaseDocumentCompressor, which provides document
processing and validation features.
Attributes:
project_id (str): Google Cloud project ID
location_id (str): Location ID for the ranking service.
ranking_config (str):
Required. The name of the rank service config, such as default_config.
It is set to default_config by default if unspecified.
model (str):
The identifier of the model to use. It is one of:
- ``semantic-ranker-512@latest``: Semantic ranking model
with maximum input token size 512.
It is set to ``semantic-ranker-512@latest`` by default if unspecified.
top_n (int):
The number of results to return. If this is
unset or not greater than zero, all results
are returned.
ignore_record_details_in_response (bool):
If true, the response will contain only the
record ID and score. By default it is false,
and the response will contain record details.
id_field (Optional[str]): Specifies a unique document metadata field
to use as an id.
title_field (Optional[str]): Specifies the document metadata field
to use as title.
credentials (Optional[Credentials]): Google Cloud credentials object.
credentials_path (Optional[str]): Path to the Google Cloud service
account credentials file.
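Example:
    A minimal usage sketch. It assumes the ``vertexaisearch`` extra is
    installed and Google Cloud credentials are available; ``my-project``
    and the sample documents are placeholders, not values defined by
    this class.

    .. code-block:: python

        from langchain_core.documents import Document
        from langchain_google_community.vertex_rank import VertexAIRank

        reranker = VertexAIRank(project_id="my-project", top_n=2)
        docs = [
            Document(page_content="Vertex AI offers a semantic ranking API."),
            Document(page_content="This sentence is about something else."),
        ]
        reranked = reranker.compress_documents(documents=docs, query="semantic ranking")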
"""
project_id: str = Field(default=None)
location_id: str = Field(default="global")
ranking_config: str = Field(default="default_config")
model: str = Field(default="semantic-ranker-512@latest")
top_n: int = Field(default=10)
ignore_record_details_in_response: bool = Field(default=False)
id_field: Optional[str] = Field(default=None)
title_field: Optional[str] = Field(default=None)
credentials: Optional[Credentials] = Field(default=None)
credentials_path: Optional[str] = Field(default=None)
client: Any = None
def __init__(self, **kwargs: Any):
"""
Constructor for VertexAIRank, allowing specification of the
ranking configuration and initialization of Google Cloud services.
The parameters accepted are the same as the attributes listed above.
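Example:
    A hedged sketch of injecting a pre-built client; ``my_client`` is a
    placeholder for an existing ``RankServiceClient`` instance. If no
    client is passed, one is created automatically.

    .. code-block:: python

        reranker = VertexAIRank(project_id="my-project", client=my_client)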
"""
super().__init__(**kwargs)
self.client = kwargs.get("client") # type: ignore
if not self.client:
self.client = self._get_rank_service_client()
def _get_rank_service_client(self) -> "discoveryengine_v1alpha.RankServiceClient":
"""
Returns a RankServiceClient instance for making API calls to the
Vertex AI Ranking service.
Returns:
A RankServiceClient instance.
"""
try:
from google.cloud import discoveryengine_v1alpha # type: ignore
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
"Please, install vertexaisearch dependency group: "
"`pip install langchain-google-community[vertexaisearch]`"
) from exc
return discoveryengine_v1alpha.RankServiceClient(
credentials=(
# Prefer an explicitly provided credentials object; otherwise fall
# back to a service-account file, if a path was given.
self.credentials
or (
Credentials.from_service_account_file(self.credentials_path)  # type: ignore[attr-defined]
if self.credentials_path
else None
)
),
client_info=get_client_info(module="vertex-ai-search"),
)
def _rerank_documents(
self, query: str, documents: Sequence[Document]
) -> Sequence[Document]:
"""
Reranks documents based on the provided query.
Args:
query: The query to use for reranking.
documents: The list of documents to rerank.
Returns:
A list of reranked documents.
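Each returned document carries the record ID and the API's relevance
score in its metadata, for example (illustrative values only):

.. code-block:: python

    {"id": "0", "relevance_score": 0.97}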
"""
from google.cloud import discoveryengine_v1alpha # type: ignore
try:
records = [
discoveryengine_v1alpha.RankingRecord(
id=(doc.metadata.get(self.id_field) if self.id_field else str(idx)),
content=doc.page_content,
**(
{"title": doc.metadata.get(self.title_field)}
if self.title_field
else {}
),
)
for idx, doc in enumerate(documents)
if doc.page_content
or (self.title_field and doc.metadata.get(self.title_field))
]
except KeyError:
warnings.warn(f"id_field '{self.id_field}' not found in document metadata.")
ranking_config_path = (
f"projects/{self.project_id}/locations/{self.location_id}"
f"/rankingConfigs/{self.ranking_config}"
)
request = discoveryengine_v1alpha.RankRequest(
ranking_config=ranking_config_path,
model=self.model,
query=query,
records=records,
top_n=self.top_n,
ignore_record_details_in_response=self.ignore_record_details_in_response,
)
try:
response = self.client.rank(request=request)
except core_exceptions.GoogleAPICallError as e:
print(f"Error in Vertex AI Ranking API call: {str(e)}")
raise RuntimeError(f"Error in Vertex AI Ranking API call: {str(e)}") from e
return [
Document(
page_content=record.content
if not self.ignore_record_details_in_response
else "",
metadata={
"id": record.id,
"relevance_score": record.score,
**({self.title_field: record.title} if self.title_field else {}),
},
)
for record in response.records
]
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Compresses documents using Vertex AI's rerank API.
Args:
documents: List of Document instances to compress.
query: Query string to use for compressing the documents.
callbacks: Callbacks to execute during compression (not used here).
Returns:
A list of Document instances, compressed.
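Example:
    A hedged sketch assuming ``reranker`` is an already configured
    ``VertexAIRank`` instance (see the class docstring); the query and
    document text are illustrative only.

    .. code-block:: python

        reranked = reranker.compress_documents(
            documents=[Document(page_content="Overview of the ranking API.")],
            query="How does the ranking API score documents?",
        )
        # Documents come back ordered by the API's relevance score,
        # stored under "relevance_score" in each document's metadata.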
"""
return self._rerank_documents(query, documents)
model_config = ConfigDict(
extra="ignore",
arbitrary_types_allowed=True,
)