"""Wrapper around Together AI's Embeddings API."""importloggingimportwarningsfromtypingimport(Any,Dict,List,Literal,Mapping,Optional,Sequence,Set,Tuple,Union,)importopenaifromlangchain_core.embeddingsimportEmbeddingsfromlangchain_core.pydantic_v1import(BaseModel,Field,SecretStr,root_validator,)fromlangchain_core.utilsimport(from_env,get_pydantic_field_names,secret_from_env,)logger=logging.getLogger(__name__)
class TogetherEmbeddings(BaseModel, Embeddings):
    """Together embedding model integration.

    Setup:
        Install ``langchain_together`` and set environment variable
        ``TOGETHER_API_KEY``.

        .. code-block:: bash

            pip install -U langchain_together
            export TOGETHER_API_KEY="your-api-key"

    Key init args — embedding params:
        model: str
            Name of Together model to use.

    Key init args — client params:
        api_key: Optional[SecretStr]
            Together API key. Read from ``TOGETHER_API_KEY`` if not provided.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_together import TogetherEmbeddings

            embed = TogetherEmbeddings(
                model="togethercomputer/m2-bert-80M-8k-retrieval",
                # api_key="...",
                # other params...
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Async:
        .. code-block:: python

            vector = await embed.aembed_query(input_text)
            print(vector[:3])

            # multiple:
            # await embed.aembed_documents(input_texts)

        .. code-block:: python

            [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
    """Embeddings model name to use.

    For example 'togethercomputer/m2-bert-80M-8k-retrieval' (the default).
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Forwarded as ``dimensions`` on every request when set. Not yet supported.
    """
    together_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
    )
    """Together AI API key.

    Automatically read from env variable `TOGETHER_API_KEY` if not provided.
    """
    together_api_base: str = Field(
        default_factory=from_env(
            "TOGETHER_API_BASE", default="https://api.together.xyz/v1/"
        ),
        alias="base_url",
    )
    """Endpoint URL to use.

    Automatically read from env variable `TOGETHER_API_BASE` if not provided.
    """
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Together embedding API.

    Can be float, httpx.Timeout or None.
    """
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.

    Defaults to not skipping. Not yet supported.
    """
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations.

    Must specify http_async_client as well if you'd like a custom client for
    async invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations.

    Must specify http_client as well if you'd like a custom client for sync
    invocations.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any init kwarg that is not a declared field (or field alias) is moved
        into ``model_kwargs`` with a warning. ``model_kwargs`` is spread into
        every ``create`` call.

        Raises:
            ValueError: If a kwarg is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Work on a copy so the caller's own ``model_kwargs`` dict is never
        # mutated by the pops below; ``None`` is treated as "no extra kwargs".
        extra = dict(values.get("model_kwargs") or {})
        # Iterate over a snapshot because ``values`` is popped from in the loop.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"WARNING! {field_name} is not default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator(pre=False, skip_on_failure=True)
    def post_init(cls, values: Dict) -> Dict:
        """Logic that will post Pydantic initialization.

        Builds ``openai.OpenAI(...).embeddings`` / ``openai.AsyncOpenAI(...).embeddings``
        clients pointed at ``together_api_base``, unless ``client`` /
        ``async_client`` were already supplied (those are kept unchanged).
        """
        api_key = values["together_api_key"]
        client_params = {
            # NOTE(review): with no Together key configured this passes
            # api_key=None; the openai SDK presumably then falls back to its
            # own env var (OPENAI_API_KEY) and would send that key to the
            # Together endpoint — confirm and consider failing fast instead.
            "api_key": api_key.get_secret_value() if api_key else None,
            "base_url": values["together_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }
        if not values.get("client"):
            # A custom httpx client is only forwarded when one was provided.
            sync_specific = (
                {"http_client": values["http_client"]}
                if values["http_client"]
                else {}
            )
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).embeddings
        if not values.get("async_client"):
            async_specific = (
                {"http_client": values["http_async_client"]}
                if values["http_async_client"]
                else {}
            )
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Keyword arguments shared by every embeddings ``create`` call.

        A new dict is built on each access, so callers may mutate the result
        without affecting the model. ``model`` cannot be overridden through
        ``model_kwargs`` because ``build_extra`` rejects declared field names.
        """
        params: Dict[str, Any] = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    @staticmethod
    def _response_to_vectors(response: Any) -> List[List[float]]:
        """Return the embedding vectors contained in an embeddings API response.

        Non-dict responses (e.g. SDK response models) are converted with
        ``model_dump()`` first; dict responses are used as-is.
        """
        if not isinstance(response, dict):
            response = response.model_dump()
        return [item["embedding"] for item in response["data"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using passage model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        params = self._invocation_params
        embeddings: List[List[float]] = []
        # One request per text; ``chunk_size`` batching is not yet supported.
        for text in texts:
            response = self.client.create(input=text, **params)
            embeddings.extend(self._response_to_vectors(response))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text using query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        response = self.client.create(input=text, **self._invocation_params)
        return self._response_to_vectors(response)[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using passage model asynchronously.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        params = self._invocation_params
        embeddings: List[List[float]] = []
        # Requests are awaited sequentially, one per text.
        for text in texts:
            response = await self.async_client.create(input=text, **params)
            embeddings.extend(self._response_to_vectors(response))
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text using query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        response = await self.async_client.create(
            input=text, **self._invocation_params
        )
        return self._response_to_vectors(response)[0]