"""Source code for ``langchain_cohere.embeddings``."""

import typing
from typing import Any, Dict, List, Optional, Sequence, Union

import cohere
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from .utils import _create_retry_decorator


class CohereEmbeddings(BaseModel, Embeddings):
    """
    Implements the Embeddings interface with Cohere's text representation
    language models.

    Find out more about us at https://cohere.com and
    https://huggingface.co/CohereForAI

    This implementation uses the Embed API - see
    https://docs.cohere.com/reference/embed

    To use this you'll need a Cohere API key - either pass it to the
    cohere_api_key parameter or set the COHERE_API_KEY environment variable.
    API keys are available on https://cohere.com - it's free to sign up and
    trial API keys work with this implementation.

    Basic Example:
        .. code-block:: python

            cohere_embeddings = CohereEmbeddings(model="embed-english-light-v3.0")
            text = "This is a test document."

            query_result = cohere_embeddings.embed_query(text)
            print(query_result)

            doc_result = cohere_embeddings.embed_documents([text])
            print(doc_result)
    """

    client: Any  #: :meta private:
    """Cohere client."""
    async_client: Any  #: :meta private:
    """Cohere async client."""
    model: Optional[str] = None
    """Model name to use.
    It is mandatory to specify the model name."""
    truncate: Optional[str] = None
    """Truncate embeddings that are too long from start or end
    ("NONE"|"START"|"END")"""
    cohere_api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("COHERE_API_KEY", default=None)
    )
    """Cohere API key. Read from the COHERE_API_KEY env var if not passed."""
    embedding_types: Sequence[str] = ["float"]
    "Specifies the types of embeddings you want to get back"
    max_retries: int = 3
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[float] = None
    """Timeout in seconds for the Cohere API request."""
    user_agent: str = "langchain:partner"
    """Identifier for the application making the request."""
    base_url: Optional[str] = None
    """Override the default Cohere API URL."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        protected_namespaces=(),
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api key and python package exists in environment.

        Builds both the sync and async Cohere clients from the resolved API
        key, the optional request timeout, user agent, and base URL.
        """
        cohere_api_key = get_from_dict_or_env(
            values, "cohere_api_key", "COHERE_API_KEY"
        )
        if isinstance(cohere_api_key, SecretStr):
            cohere_api_key = cohere_api_key.get_secret_value()
        request_timeout = values.get("request_timeout")
        client_name = values.get("user_agent", "langchain:partner")
        values["client"] = cohere.Client(
            cohere_api_key,
            timeout=request_timeout,
            client_name=client_name,
            base_url=values.get("base_url"),
        )
        values["async_client"] = cohere.AsyncClient(
            cohere_api_key,
            timeout=request_timeout,
            client_name=client_name,
            base_url=values.get("base_url"),
        )
        return values

    @model_validator(mode="after")
    def validate_model_specified(self) -> Self:  # type: ignore[valid-type]
        """Validate that model is specified."""
        if not self.model:
            # Fixed: the original message concatenated "Please " + " pass",
            # producing a double space in the user-facing error.
            raise ValueError(
                "Did not find `model`! Please"
                " pass `model` as a named parameter."
                " Please check out"
                " https://docs.cohere.com/reference/embed"
                " for available models."
            )
        return self

    def embed_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the embed call.

        Args:
            **kwargs: Forwarded verbatim to ``self.client.embed``.

        Returns:
            The Embed API response object.
        """
        retry_decorator = _create_retry_decorator(self.max_retries)

        @retry_decorator
        def _embed_with_retry(**kwargs: Any) -> Any:
            return self.client.embed(**kwargs)

        return _embed_with_retry(**kwargs)

    def aembed_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the async embed call.

        Note: this method itself is synchronous but returns an awaitable
        (the retried coroutine), so callers must ``await`` the result.

        Args:
            **kwargs: Forwarded verbatim to ``self.async_client.embed``.
        """
        retry_decorator = _create_retry_decorator(self.max_retries)

        @retry_decorator
        async def _embed_with_retry(**kwargs: Any) -> Any:
            return await self.async_client.embed(**kwargs)

        return _embed_with_retry(**kwargs)

    @staticmethod
    def _embeddings_as_floats(
        response: Any, embedding_types: Sequence[str]
    ) -> List[List[float]]:
        """Flatten an Embed API response into plain float vectors.

        The response groups vectors by embedding type; the groups are
        concatenated in ``embedding_types`` order and every value is coerced
        to ``float``.
        """
        # Default to {} (not []) so a response with no "embeddings" key
        # yields an empty result instead of raising AttributeError on the
        # .get() call below (a list has no .get method).
        embeddings = response.dict().get("embeddings", {})
        embeddings_as_float: List[List[float]] = []
        for embedding_type in embedding_types:
            vectors: Optional[List[List[Union[int, float]]]] = embeddings.get(
                embedding_type
            )
            if not vectors:
                continue
            embeddings_as_float.extend(
                [float(value) for value in vector] for vector in vectors
            )
        return embeddings_as_float

    def embed(
        self,
        texts: List[str],
        *,
        input_type: typing.Optional[cohere.EmbedInputType] = None,
    ) -> List[List[float]]:
        """Embed a list of texts with the configured model.

        Args:
            texts: The texts to embed.
            input_type: Cohere input type hint (e.g. "search_document" or
                "search_query").

        Returns:
            One embedding (list of floats) per text, per requested
            embedding type.
        """
        response = self.embed_with_retry(
            model=self.model,
            texts=texts,
            input_type=input_type,
            truncate=self.truncate,
            embedding_types=self.embedding_types,
        )
        return self._embeddings_as_floats(response, self.embedding_types)

    async def aembed(
        self,
        texts: List[str],
        *,
        input_type: typing.Optional[cohere.EmbedInputType] = None,
    ) -> List[List[float]]:
        """Asynchronously embed a list of texts with the configured model.

        Args:
            texts: The texts to embed.
            input_type: Cohere input type hint (e.g. "search_document" or
                "search_query").

        Returns:
            One embedding (list of floats) per text, per requested
            embedding type.
        """
        response = await self.aembed_with_retry(
            model=self.model,
            texts=texts,
            input_type=input_type,
            truncate=self.truncate,
            embedding_types=self.embedding_types,
        )
        return self._embeddings_as_floats(response, self.embedding_types)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self.embed(texts, input_type="search_document")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Cohere's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return await self.aembed(texts, input_type="search_document")

    def embed_query(self, text: str) -> List[float]:
        """Call out to Cohere's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed([text], input_type="search_query")[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Cohere's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return (await self.aembed([text], input_type="search_query"))[0]