class UpstageEmbeddings(BaseModel, Embeddings):
    """UpstageEmbeddings embedding model.

    To use, set the environment variable `UPSTAGE_API_KEY` with your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import UpstageEmbeddings

            model = UpstageEmbeddings(model='solar-embedding-1-large')
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = Field(...)
    """Embeddings model name to use. Do not add suffixes such as `-query` or
    `-passage`; use, for example, 'solar-embedding-1-large'.
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    upstage_api_key: SecretStr = Field(
        default_factory=secret_from_env(
            "UPSTAGE_API_KEY",
            error_message=(
                "You must specify an api key. "
                "You can pass it an argument as `api_key=...` or "
                "set the environment variable `UPSTAGE_API_KEY`."
            ),
        ),
        alias="api_key",
    )
    """Automatically inferred from the env var `UPSTAGE_API_KEY` if not provided."""
    upstage_api_base: Optional[str] = Field(
        default_factory=from_env(
            "UPSTAGE_API_BASE", default="https://api.upstage.ai/v1/solar"
        ),
        alias="base_url",
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to the Upstage embedding API. Can be a float,
    an httpx.Timeout, or None."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call that are not
    explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.
    Defaults to not skipping.

    Not yet supported.
    """
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations. Must specify
    `http_async_client` as well if you'd like a custom client for async
    invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
    `http_client` as well if you'd like a custom client for sync invocations."""

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        protected_namespaces=(),
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that the api key and python package exist in the environment."""
        client_params: dict = {
            "api_key": (
                self.upstage_api_key.get_secret_value()
                if self.upstage_api_key
                else None
            ),
            "base_url": self.upstage_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        if not self.client:
            sync_specific: dict = {"http_client": self.http_client}
            self.client = openai.OpenAI(**client_params, **sync_specific).embeddings
        if not self.async_client:
            async_specific: dict = {"http_client": self.http_async_client}
            self.async_client = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return self

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        # Strip any user-supplied suffix; the embed methods append the right one.
        self.model = self.model.replace("-query", "").replace("-passage", "")
        params: Dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params
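    # Construction sketch (illustrative, not part of the class): the key may come
    # from the UPSTAGE_API_KEY env var or be passed via the `api_key` alias, and
    # build_extra moves unknown kwargs into `model_kwargs` with a warning:
    #
    #     embeddings = UpstageEmbeddings(
    #         model="solar-embedding-1-large",
    #         api_key="...",  # alias for upstage_api_key; omit to use the env var
    #         timeout=10.0,   # alias for request_timeout
    #     )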
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using the passage model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        assert (
            self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
        ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."

        if not texts:
            return []

        params = self._invocation_params
        params["model"] = params["model"] + "-passage"

        embeddings = []
        batch_size = min(self.embed_batch_size, len(texts))
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            data = self.client.create(input=batch, **params).data
            embeddings.extend([r.embedding for r in data])
        return embeddings
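    # Usage sketch (illustrative): embed_documents appends the "-passage" suffix
    # and batches the inputs itself, so callers pass the bare model name:
    #
    #     embeddings = UpstageEmbeddings(model="solar-embedding-1-large")
    #     vectors = embeddings.embed_documents(["first passage", "second passage"])
    #     # len(vectors) == 2; each entry is one embedding vector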
    def embed_query(self, text: str) -> List[float]:
        """Embed query text using the query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        params["model"] = params["model"] + "-query"

        response = self.client.create(input=text, **params)
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]
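    # Usage sketch (illustrative): embed_query appends the "-query" suffix, so
    # queries and passages are embedded by the matching model pair:
    #
    #     query_vector = embeddings.embed_query("Which passage mentions Solar?")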
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using the passage model asynchronously.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        assert (
            self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
        ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."

        if not texts:
            return []

        params = self._invocation_params
        params["model"] = params["model"] + "-passage"

        embeddings = []
        batch_size = min(self.embed_batch_size, len(texts))
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            response = await self.async_client.create(input=batch, **params)
            embeddings.extend([r.embedding for r in response.data])
        return embeddings
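    # Usage sketch (illustrative): the async variant mirrors embed_documents and
    # must be awaited from a running event loop:
    #
    #     vectors = await embeddings.aembed_documents(["first passage", "second passage"])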
    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text using the query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        params["model"] = params["model"] + "-query"

        response = await self.async_client.create(input=text, **params)
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]
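
# A minimal end-to-end sketch (not part of the class): assumes UPSTAGE_API_KEY is
# set in the environment; the model name and input strings are illustrative.
if __name__ == "__main__":
    import asyncio

    embeddings = UpstageEmbeddings(model="solar-embedding-1-large")

    async def _demo() -> None:
        # Query and passage embeddings come from the "-query"/"-passage" pair.
        query_vec = await embeddings.aembed_query("Which passage mentions Solar?")
        doc_vecs = await embeddings.aembed_documents(
            ["first passage", "second passage"]
        )
        print(len(query_vec), len(doc_vecs))

    asyncio.run(_demo())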