Source code for langchain_google_vertexai.evaluators.evaluation

from abc import ABC
from typing import Any, Dict, List, Optional, Sequence

from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform_v1beta1 import (
    EvaluationServiceAsyncClient,
    EvaluationServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import (
    EvaluateInstancesRequest,
    EvaluateInstancesResponse,
)
from google.protobuf.json_format import MessageToDict

from langchain_google_vertexai._utils import (
    get_client_info,
    get_user_agent,
)
from langchain_google_vertexai.evaluators._core import (
    PairwiseStringEvaluator,
    StringEvaluator,
)

_METRICS = [
    "bleu",
    "exact_match",
    "rouge",
    "coherence",
    "fluency",
    "safety",
    "groundedness",
    "fulfillment",
    "summarization_quality",
    "summarization_helpfulness",
    "summarization_verbosity",
    "question_answering_quality",
    "question_answering_relevance",
    "question_answering_correctness",
]
_PAIRWISE_METRICS = [
    "pairwise_question_answering_quality",
    "pairwise_summarization_quality",
]
_METRICS_INPUTS = {
    "rouge1": {"rouge_type": "rouge1"},
    "rouge2": {"rouge_type": "rouge2"},
    "rougeL": {"rouge_type": "rougeL"},
    "rougeLsum": {"rouge_type": "rougeLsum"},
}
_METRICS_ATTRS = {
    "safety": ["prediction"],
    "coherence": ["prediction"],
    "fluency": ["prediction"],
    "groundedness": ["context", "prediction"],
    "fulfillment": ["prediction", "instruction"],
    "summarization_quality": ["prediction", "instruction", "context"],
    "summarization_helpfulness": ["prediction", "context"],
    "summarization_verbosity": ["prediction", "context"],
    "question_answering_quality": ["prediction", "context", "instruction"],
    "question_answering_relevance": ["prediction", "instruction"],
    "question_answering_correctness": ["prediction", "instruction"],
    "pairwise_question_answering_quality": [
        "prediction",
        "baseline_prediction",
        "context",
        "instruction",
    ],
    "pairwise_summarization_quality": [
        "prediction",
        "baseline_prediction",
        "context",
        "instruction",
    ],
}
_METRICS_OPTIONAL_ATTRS = {
    "summarization_quality": ["reference"],
    "summarization_helpfulness": ["reference", "instruction"],
    "summarization_verbosity": ["reference", "instruction"],
    "question_answering_quality": ["reference"],
    "question_answering_relevance": ["reference", "context"],
    "question_answering_correctness": ["reference", "context"],
    "pairwise_question_answering_quality": ["reference"],
    "pairwise_summarization_quality": ["reference"],
}
# a client supports multiple instances per request for these metrics
_METRICS_MULTIPLE_INSTANCES = ["bleu", "exact_match", "rouge"]


def _format_metric(metric: str) -> str:
    if metric.startswith("rouge"):
        return "rouge"
    return metric


def _format_instance(instance: Dict[str, str], metric: str) -> Dict[str, str]:
    """Select the fields ``metric`` consumes from a raw instance.

    Args:
        instance: Candidate fields (``prediction``, ``reference``, ``context``,
            ``instruction``, ...). Keys the metric does not use are dropped.
        metric: Metric name used to look up required and optional fields in
            ``_METRICS_ATTRS`` / ``_METRICS_OPTIONAL_ATTRS``.

    Returns:
        A new dict holding every required field plus any optional field that
        is present in ``instance``.

    Raises:
        KeyError: If required fields are missing. The message names the metric
            and all missing fields, not just the first one encountered.
    """
    required = _METRICS_ATTRS.get(metric, ["prediction", "reference"])
    missing = [attr for attr in required if attr not in instance]
    if missing:
        # Report keys only (never values): values may be long or sensitive.
        raise KeyError(
            f"Metric {metric} requires field(s) {missing}; "
            f"got fields {sorted(instance)}"
        )
    result = {attr: instance[attr] for attr in required}
    for attr in _METRICS_OPTIONAL_ATTRS.get(metric, []):
        if attr in instance:
            result[attr] = instance[attr]
    return result


def _prepare_request(
    instances: Sequence[Dict[str, str]], metric: str, location: str
) -> EvaluateInstancesRequest:
    """Build an ``EvaluateInstancesRequest`` for ``metric``.

    Args:
        instances: Rows to evaluate; each maps field names (``prediction``,
            ``reference``, ``context``, ...) to strings.
        metric: Metric name, e.g. ``"bleu"``, ``"rougeL"`` or ``"fluency"``.
            ROUGE variants are sent through the ``rouge`` input field with
            their ``rouge_type`` as the metric spec.
        location: Location resource path the request targets (as built by
            ``EvaluationServiceClient.common_location_path``).

    Returns:
        A request whose ``<metric>_input`` field and ``location`` are set.

    Raises:
        ValueError: If a metric that accepts a single instance per request is
            given zero or several instances.
    """
    api_metric = _format_metric(metric)
    metric_input: Dict[str, Any] = {"metric_spec": _METRICS_INPUTS.get(metric, {})}
    if api_metric in _METRICS_MULTIPLE_INSTANCES:
        metric_input["instances"] = [
            _format_instance(instance, metric=metric) for instance in instances
        ]
    else:
        # Exactly one instance is required; an empty sequence previously
        # surfaced as a bare IndexError from ``instances[0]``.
        if len(instances) != 1:
            raise ValueError(
                f"Metric {metric} supports only a single instance per request, "
                f"got {len(instances)}!"
            )
        metric_input["instance"] = _format_instance(instances[0], metric=metric)
    request = EvaluateInstancesRequest()
    setattr(request, f"{api_metric}_input", metric_input)
    request.location = location
    return request


def _parse_response(
    response: EvaluateInstancesResponse, metric: str
) -> List[Dict[str, Any]]:
    """Convert an ``EvaluateInstancesResponse`` into a list of result dicts.

    ``metric`` is normalized with ``_format_metric`` first. Batched metrics
    (those in ``_METRICS_MULTIPLE_INSTANCES``) return the list stored under
    ``<metric>_results`` -> ``<metric>_metric_values``; all other metrics return
    a one-element list wrapping ``<metric>_result``. Field names keep their
    proto spelling (``preserving_proto_field_name=True``).
    """
    api_metric = _format_metric(metric)
    as_dict = MessageToDict(response._pb, preserving_proto_field_name=True)
    if api_metric not in _METRICS_MULTIPLE_INSTANCES:
        return [as_dict[f"{api_metric}_result"]]
    batched = as_dict[f"{api_metric}_results"]
    return batched[f"{api_metric}_metric_values"]


class _EvaluatorBase(ABC):
    """Shared setup for Vertex AI evaluators.

    Holds the metric name, the sync and async ``EvaluationService`` clients
    (pointed at the regional endpoint for ``location``), and the location
    resource path used on every request.
    """

    @property
    def _user_agent(self) -> str:
        """Gets the User Agent."""
        _, user_agent = get_user_agent(f"{type(self).__name__}_{self._metric}")
        return user_agent

    def __init__(self, metric: str, project_id: str, location: str = "us-central1"):
        """Create the evaluation clients.

        Args:
            metric: Metric name this evaluator scores with.
            project_id: Google Cloud project that owns the requests.
            location: Region; selects both the API endpoint and the request
                location path.
        """
        self._metric = metric
        client_options = ClientOptions(
            api_endpoint=f"{location}-{constants.PREDICTION_API_BASE_PATH}"
        )
        # Both clients report the same user agent, so compute it once. Each
        # client still gets its own ClientInfo object.
        user_agent = self._user_agent
        self._client = EvaluationServiceClient(
            client_options=client_options,
            client_info=get_client_info(module=user_agent),
        )
        self._async_client = EvaluationServiceAsyncClient(
            client_options=client_options,
            client_info=get_client_info(module=user_agent),
        )
        self._location = self._client.common_location_path(project_id, location)

    def _prepare_request(
        self,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluateInstancesRequest:
        """Build a single-instance request for this evaluator's metric.

        Args:
            prediction: Model output to score; sent as ``prediction``.
            reference: Optional ground truth; sent as ``reference``.
            input: Optional input text; sent as ``context``.
            **kwargs: Extra instance fields (e.g. ``instruction`` or
                ``baseline_prediction``); they override the fields above.

        Returns:
            The request built by the module-level ``_prepare_request``.
        """
        instance: Dict[str, str] = {"prediction": prediction}
        # Compare against None so an explicit empty string is forwarded. The
        # batch path in ``evaluate`` likewise keeps any key present in an
        # example.
        if reference is not None:
            instance["reference"] = reference
        if input is not None:
            instance["context"] = input
        instance.update(kwargs)
        return _prepare_request(
            [instance], metric=self._metric, location=self._location
        )


class VertexStringEvaluator(_EvaluatorBase, StringEvaluator):
    """Evaluate predicted strings with a Vertex AI evaluation-service metric.

    Accepted metrics are those listed in ``_METRICS``. ROUGE variants such as
    ``rouge1`` or ``rougeLsum`` are also accepted; they are sent as the
    ``rouge`` metric with the matching ``rouge_type``.
    """

    def __init__(self, metric: str, **kwargs: Any):
        """Create the evaluator.

        Args:
            metric: Metric name, e.g. ``"bleu"``, ``"rougeL"`` or ``"fluency"``.
            **kwargs: Forwarded to ``_EvaluatorBase`` (``project_id``,
                ``location``).

        Raises:
            ValueError: If ``metric`` is not a supported metric.
        """
        # Validate first so an unsupported metric fails before the base class
        # constructs any API clients.
        if _format_metric(metric) not in _METRICS:
            raise ValueError(f"Metric {metric} is not supported yet!")
        super().__init__(metric, **kwargs)

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Score one prediction synchronously and return its parsed result."""
        request = self._prepare_request(prediction, reference, input, **kwargs)
        response = self._client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]

    def evaluate(
        self,
        examples: Sequence[Dict[str, str]],
        predictions: Sequence[Dict[str, str]],
        *,
        question_key: str = "context",
        answer_key: str = "reference",
        prediction_key: str = "prediction",
        instruction_key: str = "instruction",
        **kwargs: Any,
    ) -> List[dict]:
        """Evaluate a batch of predictions against their examples.

        Batchable metrics (including every ROUGE variant) are sent in a single
        request. All other metrics are scored one row at a time.

        Args:
            examples: One dict per row with optional reference, context and
                instruction fields.
            predictions: One dict per row holding the prediction, aligned
                index-by-index with ``examples``.
            question_key: Key in each example sent as ``context``.
            answer_key: Key in each example sent as ``reference``.
            prediction_key: Key in each prediction sent as ``prediction``.
            instruction_key: Key in each example sent as ``instruction``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The parsed metric results (see ``_parse_response``); empty when no
            rows are given.

        Raises:
            ValueError: If ``examples`` and ``predictions`` differ in length.
        """
        # ``zip`` would silently drop the unmatched tail; fail loudly instead.
        if len(examples) != len(predictions):
            raise ValueError(
                f"Got {len(examples)} examples but {len(predictions)} "
                "predictions; they must be aligned one-to-one."
            )
        instances: List[dict] = []
        for example, prediction in zip(examples, predictions):
            row: Dict[str, str] = {"prediction": prediction[prediction_key]}
            if answer_key in example:
                row["reference"] = example[answer_key]
            if question_key in example:
                row["context"] = example[question_key]
            if instruction_key in example:
                row["instruction"] = example[instruction_key]
            instances.append(row)
        if not instances:
            # Nothing to score; don't send an empty batch to the service.
            return []
        # Compare the API-level name so ROUGE variants (``rouge1``, ...) are
        # batched like ``rouge`` instead of costing one request per row.
        if _format_metric(self._metric) in _METRICS_MULTIPLE_INSTANCES:
            request = _prepare_request(
                instances, metric=self._metric, location=self._location
            )
            response = self._client.evaluate_instances(request)
            return _parse_response(response, metric=self._metric)
        return [self._evaluate_strings(**instance) for instance in instances]

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Score one prediction via the async client and return its result."""
        request = self._prepare_request(prediction, reference, input, **kwargs)
        response = await self._async_client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]


class VertexPairWiseStringEvaluator(_EvaluatorBase, PairwiseStringEvaluator):
    """Compare two predicted strings with a Vertex AI pairwise metric.

    Accepted metrics are those listed in ``_PAIRWISE_METRICS``. ``prediction``
    is sent as the request's ``baseline_prediction`` and ``prediction_b`` as
    its ``prediction``.
    """

    def __init__(self, metric: str, **kwargs: Any):
        """Create the evaluator.

        Args:
            metric: Pairwise metric name, e.g.
                ``"pairwise_summarization_quality"``.
            **kwargs: Forwarded to ``_EvaluatorBase`` (``project_id``,
                ``location``).

        Raises:
            ValueError: If ``metric`` is not a supported pairwise metric.
        """
        # Validate first so an unsupported metric fails before the base class
        # constructs any API clients.
        if _format_metric(metric) not in _PAIRWISE_METRICS:
            raise ValueError(f"Metric {metric} is not supported yet!")
        super().__init__(metric, **kwargs)

    def _prepare_pairwise_request(
        self,
        prediction: str,
        prediction_b: str,
        reference: Optional[str],
        input: Optional[str],
        **kwargs: Any,
    ) -> EvaluateInstancesRequest:
        """Map a prediction pair onto a request (A -> baseline, B -> prediction)."""
        return self._prepare_request(
            prediction_b, reference, input, baseline_prediction=prediction, **kwargs
        )

    def _evaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Score a prediction pair synchronously and return the parsed result."""
        request = self._prepare_pairwise_request(
            prediction, prediction_b, reference, input, **kwargs
        )
        response = self._client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]

    async def _aevaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Score a prediction pair via the async client and return the result."""
        request = self._prepare_pairwise_request(
            prediction, prediction_b, reference, input, **kwargs
        )
        response = await self._async_client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]