Source code for langchain_community.embeddings.sambanova

import json
from typing import Dict, Generator, List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import BaseModel


class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``,
    ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_API_KEY``
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings

            embeddings = SambaStudioEmbeddings(
                sambastudio_embeddings_base_url=base_url,
                sambastudio_embeddings_base_uri=base_uri,
                sambastudio_embeddings_project_id=project_id,
                sambastudio_embeddings_endpoint_id=endpoint_id,
                sambastudio_embeddings_api_key=api_key,
                batch_size=32,
            )
            (or)

            embeddings = SambaStudioEmbeddings(batch_size=32)

            (or)

            # CoE example
            embeddings = SambaStudioEmbeddings(
                batch_size=1,
                model_kwargs={
                    'select_expert':'e5-mistral-7b-instruct'
                }
            )
    """

    sambastudio_embeddings_base_url: str = ""
    """Base url to use"""

    sambastudio_embeddings_base_uri: str = ""
    """endpoint base uri"""

    sambastudio_embeddings_project_id: str = ""
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ""
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ""
    """sambastudio api key"""

    model_kwargs: dict = {}
    """Key word arguments to pass to the model."""

    batch_size: int = 32
    """Batch size for the embedding models"""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the connection settings from ``values`` or the environment.

        Each setting is looked up with ``get_from_dict_or_env`` using the
        constructor keyword first and the matching ``SAMBASTUDIO_EMBEDDINGS_*``
        environment variable second. Only the base URI has a fallback default
        (``api/predict/generic``).

        Args:
            values: the raw constructor values.

        Returns:
            The same dict with the five connection settings filled in.
        """
        values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
            values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
        )
        values["sambastudio_embeddings_base_uri"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_base_uri",
            "SAMBASTUDIO_EMBEDDINGS_BASE_URI",
            default="api/predict/generic",
        )
        values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_project_id",
            "SAMBASTUDIO_EMBEDDINGS_PROJECT_ID",
        )
        values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_endpoint_id",
            "SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID",
        )
        values["sambastudio_embeddings_api_key"] = get_from_dict_or_env(
            values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY"
        )
        return values

    def _get_tuning_params(self) -> str:
        """
        Get the tuning parameters to use when calling the model

        For the ``api/v2/predict/generic`` URI the model kwargs are sent as-is;
        for every other URI each value is wrapped as
        ``{"type": <python type name>, "value": <str(value)>}``.

        Returns:
            The tuning parameters as a JSON string.
        """
        if "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
            tuning_params_dict = self.model_kwargs
        else:
            tuning_params_dict = {
                k: {"type": type(v).__name__, "value": str(v)}
                for k, v in (self.model_kwargs.items())
            }
        tuning_params = json.dumps(tuning_params_dict)
        return tuning_params

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :rtype: str
        """
        return f"{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}"  # noqa: E501

    def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
        """Generator for creating batches in the embed documents method

        Args:
            texts (List[str]): list of strings to embed
            batch_size (int, optional): batch size to be used for the embedding model.
            Will depend on the RDU endpoint used.

        Yields:
            List[str]: list (batch) of strings of size batch size
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i : i + batch_size]

    def _resolve_api_flavor(self) -> str:
        """Classify the configured base URI into one of the supported APIs.

        The checks are substring matches, evaluated in the order nlp,
        v2 generic, generic.

        Returns:
            ``"nlp"``, ``"generic_v2"`` or ``"generic"``.

        Raises:
            ValueError: if the base URI matches none of the supported APIs.
        """
        base_uri = self.sambastudio_embeddings_base_uri
        if "api/predict/nlp" in base_uri:
            return "nlp"
        if "api/v2/predict/generic" in base_uri:
            return "generic_v2"
        if "api/predict/generic" in base_uri:
            return "generic"
        raise ValueError(
            f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented"  # noqa: E501
        )

    def _embed_batch(
        self,
        http_session: requests.Session,
        url: str,
        batch: List[str],
        params: Dict,
        api_flavor: str,
    ) -> List[List[float]]:
        """POST one batch of texts to the endpoint and return its embeddings.

        Args:
            http_session: session used to send the request.
            url: full endpoint URL.
            batch: texts to embed in this single request.
            params: tuning parameters sent alongside the texts.
            api_flavor: value returned by ``_resolve_api_flavor``; selects the
                request payload layout and the response key to read.

        Returns:
            The embeddings extracted from the response for this batch.

        Raises:
            RuntimeError: if the endpoint answers with a non-200 status code.
            KeyError: if the expected key is missing from the response body.
        """
        if api_flavor == "nlp":
            data = {"inputs": batch, "params": params}
            result_key = "data"
        elif api_flavor == "generic_v2":
            items = [{"id": f"item{i}", "value": item} for i, item in enumerate(batch)]
            data = {"items": items, "params": params}
            result_key = "items"
        else:
            data = {"instances": batch, "params": params}
            result_key = "predictions"

        response = http_session.post(
            url,
            headers={"key": self.sambastudio_embeddings_api_key},
            json=data,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}.\n Details: {response.text}"
            )

        # Parse the body once so the error path can report it without re-parsing.
        body = response.json()
        try:
            result = body[result_key]
            if api_flavor == "generic_v2":
                return [item["value"] for item in result]
            if api_flavor == "generic" and params.get("select_expert"):
                # With ``select_expert`` set, the batch embeddings are read
                # from the first element of "predictions".
                return result[0]
            return result
        except KeyError as err:
            raise KeyError(
                f"'{result_key}' not found in endpoint response",
                body,
            ) from err

    def embed_documents(
        self, texts: List[str], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        """Return one embedding per input text.

        The texts are sent to the endpoint in consecutive batches over a single
        HTTP session, which is closed when the call finishes.

        Args:
            texts: texts to embed.
            batch_size: number of texts per request; defaults to
                ``self.batch_size`` when ``None``.

        Returns:
            The embeddings, in the same order as ``texts``.

        Raises:
            ValueError: if ``batch_size`` is not positive, or the configured
                base URI is not supported.
            RuntimeError: if the endpoint answers with a non-200 status code.
            KeyError: if the expected key is missing from a response body.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size < 1:
            # A negative step would make range() yield nothing and silently
            # return no embeddings at all.
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        url = self._get_full_url(
            f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
        )
        params = json.loads(self._get_tuning_params())
        api_flavor = self._resolve_api_flavor()

        embeddings: List[List[float]] = []
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(
                    self._embed_batch(http_session, url, batch, params, api_flavor)
                )
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a single text.

        Args:
            text: the text to embed.

        Returns:
            The embedding of ``text``.

        Raises:
            ValueError: if the configured base URI is not supported.
            RuntimeError: if the endpoint answers with a non-200 status code.
            KeyError: if the expected key is missing from the response body.
        """
        url = self._get_full_url(
            f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
        )
        params = json.loads(self._get_tuning_params())
        api_flavor = self._resolve_api_flavor()

        with requests.Session() as http_session:
            # A query is a one-element batch; its embedding is the first result.
            embedding = self._embed_batch(
                http_session, url, [text], params, api_flavor
            )[0]
        return embedding