# Source code for langchain_upstage.embeddings

from __future__ import annotations

import logging
import warnings
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import openai
from langchain_core.embeddings import Embeddings
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)
from typing_extensions import Self

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Default value of `UpstageEmbeddings.embed_batch_size`.
DEFAULT_EMBED_BATCH_SIZE = 10
# Upper bound that `embed_documents` / `aembed_documents` enforce on
# `UpstageEmbeddings.embed_batch_size` before sending any request.
MAX_EMBED_BATCH_SIZE = 100


class UpstageEmbeddings(BaseModel, Embeddings):
    """UpstageEmbeddings embedding model.

    To use, set the environment variable `UPSTAGE_API_KEY` with your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import UpstageEmbeddings

            model = UpstageEmbeddings(model='solar-embedding-1-large')
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = Field(...)
    """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
    Instead, use 'solar-embedding-1-large' for example.
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    upstage_api_key: SecretStr = Field(
        default_factory=secret_from_env(
            "UPSTAGE_API_KEY",
            error_message=(
                "You must specify an api key. "
                "You can pass it an argument as `api_key=...` or "
                "set the environment variable `UPSTAGE_API_KEY`."
            ),
        ),
        alias="api_key",
    )
    """Automatically inferred from env var `UPSTAGE_API_KEY` if not provided."""
    upstage_api_base: Optional[str] = Field(
        default_factory=from_env(
            "UPSTAGE_API_BASE", default="https://api.upstage.ai/v1/solar"
        ),
        alias="base_url",
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
    """Number of texts sent per request by `embed_documents` / `aembed_documents`.

    Must be between 1 and `MAX_EMBED_BATCH_SIZE` (checked when embedding).
    """
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Upstage embedding API.
    Can be float, httpx.Timeout or None."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.
    Defaults to not skipping.

    Not yet supported."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations. Must specify
    http_async_client as well if you'd like a custom client for async invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
    http_client as well if you'd like a custom client for sync invocations."""

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        protected_namespaces=(),
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in.

        Any key in ``values`` that is not a declared field name is moved into
        ``model_kwargs`` (with a warning).

        Args:
            values: Raw constructor input.

        Returns:
            ``values`` with unknown keys folded into ``values["model_kwargs"]``.

        Raises:
            ValueError: If a key is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Work on a copy so the dict object the caller passed as
        # `model_kwargs` is never mutated.
        extra = dict(values.get("model_kwargs", {}))
        # Iterate over a snapshot of the keys because entries are popped below.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"WARNING! {field_name} is not default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Create the sync and async embeddings clients if none were supplied.

        Returns:
            This instance, with ``client`` and ``async_client`` populated.
        """
        client_params: dict = {
            "api_key": (
                self.upstage_api_key.get_secret_value()
                if self.upstage_api_key
                else None
            ),
            "base_url": self.upstage_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        # A client passed in by the caller is kept as-is.
        if not self.client:
            self.client = openai.OpenAI(
                **client_params, http_client=self.http_client
            ).embeddings
        if not self.async_client:
            self.async_client = openai.AsyncOpenAI(
                **client_params, http_client=self.http_async_client
            ).embeddings
        return self

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Return a fresh dict of keyword arguments for the `create` call.

        The ``-query`` / ``-passage`` markers are stripped from the model name
        here; the embed methods append the one they need. The stripped name is
        kept in a local so that reading this property does not modify
        ``self.model``.
        """
        base_model = self.model.replace("-query", "").replace("-passage", "")
        params: Dict[str, Any] = {"model": base_model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    def _params_with_suffix(self, suffix: str) -> Dict[str, Any]:
        """Return invocation params with ``suffix`` appended to the model name."""
        params = self._invocation_params
        params["model"] = params["model"] + suffix
        return params

    def _check_embed_batch_size(self) -> None:
        """Validate ``embed_batch_size`` before any request is sent.

        Raises:
            ValueError: If ``embed_batch_size`` is smaller than 1 or larger than
                ``MAX_EMBED_BATCH_SIZE``.
        """
        # Explicit raises instead of `assert`, which is removed under `-O`.
        if self.embed_batch_size > MAX_EMBED_BATCH_SIZE:
            raise ValueError(
                f"The embed_batch_size should not be larger than "
                f"{MAX_EMBED_BATCH_SIZE}."
            )
        # A zero step makes `range` fail and a negative step makes it empty,
        # which would silently return no embeddings at all.
        if self.embed_batch_size < 1:
            raise ValueError("The embed_batch_size should be at least 1.")

    def _iter_batches(self, texts: List[str]) -> List[List[str]]:
        """Split ``texts`` into consecutive batches of at most ``embed_batch_size``."""
        batch_size = min(self.embed_batch_size, len(texts))
        return [
            texts[start : start + batch_size]
            for start in range(0, len(texts), batch_size)
        ]

    def _first_embedding(self, response: Any) -> List[float]:
        """Extract the first embedding vector from a `create` response."""
        # Responses that are not already plain dicts are converted first.
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using passage model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            ValueError: If ``embed_batch_size`` is outside
                ``1..MAX_EMBED_BATCH_SIZE``.
        """
        self._check_embed_batch_size()
        if not texts:
            return []

        params = self._params_with_suffix("-passage")
        embeddings: List[List[float]] = []
        for batch in self._iter_batches(texts):
            data = self.client.create(input=batch, **params).data
            embeddings.extend(r.embedding for r in data)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text using query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._params_with_suffix("-query")
        response = self.client.create(input=text, **params)
        return self._first_embedding(response)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using passage model asynchronously.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            ValueError: If ``embed_batch_size`` is outside
                ``1..MAX_EMBED_BATCH_SIZE``.
        """
        self._check_embed_batch_size()
        if not texts:
            return []

        params = self._params_with_suffix("-passage")
        embeddings: List[List[float]] = []
        # Batches are awaited one after another, in input order.
        for batch in self._iter_batches(texts):
            response = await self.async_client.create(input=batch, **params)
            embeddings.extend(r.embedding for r in response.data)
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text using query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._params_with_suffix("-query")
        response = await self.async_client.create(input=text, **params)
        return self._first_embedding(response)