"""Wrapper around Together AI's Embeddings API."""

# Standard library.
import logging
import warnings
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

# Third-party.
import openai
from langchain_core.embeddings import Embeddings
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class TogetherEmbeddings(BaseModel, Embeddings):
    """Together embedding model integration.

    Setup:
        Install ``langchain_together`` and set environment variable
        ``TOGETHER_API_KEY``.

        .. code-block:: bash

            pip install -U langchain_together
            export TOGETHER_API_KEY="your-api-key"

    Key init args — embedding params:
        model: str
            Name of Together model to use.

    Key init args — client params:
        api_key: Optional[SecretStr]

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_together import TogetherEmbeddings

            embed = TogetherEmbeddings(
                model="togethercomputer/m2-bert-80M-8k-retrieval",
                # api_key="...",
                # other params...
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Async:
        .. code-block:: python

            vector = await embed.aembed_query(input_text)
            print(vector[:3])

            # multiple:
            # await embed.aembed_documents(input_texts)

        .. code-block:: python

            [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
    """Embeddings model name to use.

    For example, 'togethercomputer/m2-bert-80M-8k-retrieval'.
    """
    # NOTE: when set, this is forwarded to the API as ``dimensions``
    # (see ``_invocation_params``).
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    together_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
    )
    """Together AI API key.

    Automatically read from env variable `TOGETHER_API_KEY` if not provided.
    """
    together_api_base: str = Field(
        default_factory=from_env(
            "TOGETHER_API_BASE", default="https://api.together.xyz/v1/"
        ),
        alias="base_url",
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Together embedding API.

    Can be float, httpx.Timeout or None.
    """
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.

    Defaults to not skipping.

    Not yet supported.
    """
    # Extra HTTP headers passed to the OpenAI-compatible client constructor.
    default_headers: Union[Mapping[str, str], None] = None
    # Extra query parameters passed to the OpenAI-compatible client constructor.
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client.

    Only used for sync invocations. Must specify http_async_client as well if
    you'd like a custom client for async invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient.

    Only used for async invocations. Must specify http_client as well if you'd
    like a custom client for sync invocations.
    """

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        protected_namespaces=(),
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Move unrecognized init kwargs into ``model_kwargs``.

        Every key of ``values`` that is not among
        ``get_pydantic_field_names(cls)`` is moved into ``model_kwargs``,
        and a warning is emitted for each one moved.

        Args:
            values: Raw keyword arguments given to the model.

        Returns:
            A new dict of values, with ``model_kwargs`` holding the extras.

        Raises:
            ValueError: If a parameter is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a name that
                should be passed as an explicit field.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Work on copies: neither the caller's input dict nor a possibly
        # shared ``model_kwargs`` mapping may be mutated. ``None`` is
        # treated as "no extra kwargs".
        values = dict(values)
        extra = dict(values.get("model_kwargs") or {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"WARNING! {field_name} is not default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def post_init(self) -> Self:
        """Create the sync and async embeddings clients unless already provided.

        Both clients are OpenAI SDK clients pointed at ``together_api_base``;
        only their ``.embeddings`` resource is stored.
        """
        client_params: dict = {
            "api_key": (
                self.together_api_key.get_secret_value()
                if self.together_api_key
                else None
            ),
            "base_url": self.together_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        if not self.client:
            # Only override the SDK's default HTTP client when one was given.
            sync_specific: dict = (
                {"http_client": self.http_client} if self.http_client else {}
            )
            self.client = openai.OpenAI(**client_params, **sync_specific).embeddings
        if not self.async_client:
            async_specific: dict = (
                {"http_client": self.http_async_client}
                if self.http_async_client
                else {}
            )
            self.async_client = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return self

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Build a fresh dict of keyword arguments for each ``create`` call."""
        params: Dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    @staticmethod
    def _extract_embeddings(response: Any) -> List[List[float]]:
        """Return every embedding vector in an embeddings ``create`` response.

        Non-dict SDK response objects are converted with ``model_dump()``
        before the ``data[*].embedding`` entries are read.
        """
        if not isinstance(response, dict):
            response = response.model_dump()
        return [item["embedding"] for item in response["data"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        One API request is issued per text; ``chunk_size`` batching is not
        applied.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        params = self._invocation_params
        embeddings: List[List[float]] = []
        for text in texts:
            response = self.client.create(input=text, **params)
            embeddings.extend(self._extract_embeddings(response))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        response = self.client.create(input=text, **self._invocation_params)
        return self._extract_embeddings(response)[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of document texts.

        One API request is awaited per text, sequentially; ``chunk_size``
        batching is not applied.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        params = self._invocation_params
        embeddings: List[List[float]] = []
        for text in texts:
            response = await self.async_client.create(input=text, **params)
            embeddings.extend(self._extract_embeddings(response))
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        response = await self.async_client.create(
            input=text, **self._invocation_params
        )
        return self._extract_embeddings(response)[0]