import asyncio
import logging
import warnings
from typing import Iterable, List

import httpx
from httpx import Response
from langchain_core.embeddings import Embeddings
from langchain_core.utils import (
    secret_from_env,
)
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tokenizers import Tokenizer  # type: ignore
from typing_extensions import Self

logger = logging.getLogger(__name__)

MAX_TOKENS = 16_000
"""A batching parameter for the Mistral API. This is NOT the maximum number of
tokens accepted by the embedding model for each document/chunk, but rather the
maximum number of tokens that can be sent in a single request to the Mistral API
(across multiple documents/chunks)."""
class DummyTokenizer:
    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""

    @staticmethod
    def encode_batch(texts: List[str]) -> List[List[str]]:
        # One "token" per character, so batch sizing falls back to `len()`.
        return [list(text) for text in texts]
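# Illustrative only (not part of the module): with this fallback, token counts
# equal character counts, e.g.
#     DummyTokenizer().encode_batch(["abc", "de"]) -> [["a", "b", "c"], ["d", "e"]]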
class MistralAIEmbeddings(BaseModel, Embeddings):
    """MistralAI embedding model integration.

    Setup:
        Install ``langchain_mistralai`` and set environment variable
        ``MISTRAL_API_KEY``.

        .. code-block:: bash

            pip install -U langchain_mistralai
            export MISTRAL_API_KEY="your-api-key"

    Key init args — completion params:
        model: str
            Name of MistralAI model to use.

    Key init args — client params:
        api_key: Optional[SecretStr]
            The API key for the MistralAI API. If not provided, it will be read
            from the environment variable `MISTRAL_API_KEY`.
        max_retries: int
            The number of times to retry a request if it fails.
        timeout: int
            The number of seconds to wait for a response before timing out.
        wait_time: int
            The number of seconds to wait before retrying a request in case of
            a 429 error.
        max_concurrent_requests: int
            The maximum number of concurrent requests to make to the Mistral API.

    See full list of supported init args and their descriptions in the params
    section.

    Instantiate:
        .. code-block:: python

            from langchain_mistralai import MistralAIEmbeddings

            embed = MistralAIEmbeddings(
                model="mistral-embed",
                # api_key="...",
                # other params...
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Async:
        .. code-block:: python

            vector = await embed.aembed_query(input_text)
            print(vector[:3])

            # multiple:
            # await embed.aembed_documents(input_texts)

        .. code-block:: python

            [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
    """

    # The types for client and async_client are ignored because they are not
    # Optional after the model is initialized and the model_validator has run.
    client: httpx.Client = Field(default=None)  # type: ignore  # :meta private:
    async_client: httpx.AsyncClient = Field(  # type: ignore  # :meta private:
        default=None
    )
    mistral_api_key: SecretStr = Field(
        alias="api_key",
        default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
    )
    endpoint: str = "https://api.mistral.ai/v1/"
    max_retries: int = 5
    timeout: int = 120
    wait_time: int = 30
    max_concurrent_requests: int = 64
    tokenizer: Tokenizer = Field(default=None)

    model: str = "mistral-embed"

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate configuration."""
        api_key_str = self.mistral_api_key.get_secret_value()
        # todo: handle retries
        if not self.client:
            self.client = httpx.Client(
                base_url=self.endpoint,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=self.timeout,
            )
        # todo: handle retries and max_concurrency
        if not self.async_client:
            self.async_client = httpx.AsyncClient(
                base_url=self.endpoint,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=self.timeout,
            )
        if self.tokenizer is None:
            try:
                self.tokenizer = Tokenizer.from_pretrained(
                    "mistralai/Mixtral-8x7B-v0.1"
                )
            except IOError:  # huggingface_hub GatedRepoError
                warnings.warn(
                    "Could not download mistral tokenizer from Huggingface for "
                    "calculating batch sizes. Set a Huggingface token via the "
                    "HF_TOKEN environment variable to download the real tokenizer. "
                    "Falling back to a dummy tokenizer that uses `len()`."
                )
                self.tokenizer = DummyTokenizer()
        return self

    def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
        """Split a list of texts into batches of less than 16k tokens for the
        Mistral API."""
        batch: List[str] = []
        batch_tokens = 0

        text_token_lengths = [
            len(encoded) for encoded in self.tokenizer.encode_batch(texts)
        ]

        for text, text_tokens in zip(texts, text_token_lengths):
            if batch_tokens + text_tokens > MAX_TOKENS:
                if len(batch) > 0:
                    # Edge case: if the first text alone exceeds MAX_TOKENS,
                    # do not yield an empty batch.
                    yield batch
                batch = [text]
                batch_tokens = text_tokens
            else:
                batch.append(text)
                batch_tokens += text_tokens
        if batch:
            yield batch
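    # Illustrative sketch (not part of the library): how `_get_batches` groups
    # texts under the MAX_TOKENS budget. Assuming the DummyTokenizer fallback
    # (one token per character), three 6,000-character strings yield a batch
    # of two followed by a batch of one:
    #
    #     embed = MistralAIEmbeddings(api_key="...")
    #     embed.tokenizer = DummyTokenizer()
    #     texts = ["a" * 6_000] * 3
    #     assert [len(b) for b in embed._get_batches(texts)] == [2, 1]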
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            batch_responses = []

            @retry(
                retry=retry_if_exception_type(
                    (httpx.TimeoutException, httpx.HTTPStatusError)
                ),
                wait=wait_fixed(self.wait_time),
                stop=stop_after_attempt(self.max_retries),
            )
            def _embed_batch(batch: List[str]) -> Response:
                response = self.client.post(
                    url="/embeddings",
                    json=dict(
                        model=self.model,
                        input=batch,
                    ),
                )
                response.raise_for_status()
                return response

            for batch in self._get_batches(texts):
                batch_responses.append(_embed_batch(batch))
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception as e:
            logger.error(f"An error occurred with MistralAI: {e}")
            raise
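    # Usage sketch (hypothetical values): the tenacity policy above retries on
    # timeouts and HTTP status errors, waiting `wait_time` seconds between
    # attempts and giving up after `max_retries` attempts; both are init args:
    #
    #     embed = MistralAIEmbeddings(model="mistral-embed", max_retries=3, wait_time=5)
    #     vectors = embed.embed_documents(["doc one", "doc two"])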
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            batch_responses = await asyncio.gather(
                *[
                    self.async_client.post(
                        url="/embeddings",
                        json=dict(
                            model=self.model,
                            input=batch,
                        ),
                    )
                    for batch in self._get_batches(texts)
                ]
            )
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception as e:
            logger.error(f"An error occurred with MistralAI: {e}")
            raise
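    # Sketch (not the library's implementation; see the "todo: handle retries
    # and max_concurrency" note in `validate_environment`): one way the
    # currently unused `max_concurrent_requests` field could bound the
    # fan-out above, using an asyncio.Semaphore around each POST:
    #
    #     sem = asyncio.Semaphore(self.max_concurrent_requests)
    #
    #     async def _post(batch: List[str]) -> Response:
    #         async with sem:
    #             response = await self.async_client.post(
    #                 url="/embeddings",
    #                 json=dict(model=self.model, input=batch),
    #             )
    #             response.raise_for_status()
    #             return response
    #
    #     batch_responses = await asyncio.gather(
    #         *[_post(batch) for batch in self._get_batches(texts)]
    #     )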
    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]
    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return (await self.aembed_documents([text]))[0]