import asyncio
import logging
import warnings
from typing import Dict, Iterable, List

import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import (
    secret_from_env,
)
from tokenizers import Tokenizer  # type: ignore

logger = logging.getLogger(__name__)

MAX_TOKENS = 16_000
"""A batching parameter for the Mistral API. This is NOT the maximum number of
tokens accepted by the embedding model for each document/chunk, but rather the
maximum number of tokens that can be sent in a single request to the Mistral
API (across multiple documents/chunks)."""
class DummyTokenizer:
    """Dummy tokenizer for when the tokenizer cannot be accessed
    (e.g., via Huggingface)."""

    @staticmethod
    def encode_batch(texts: List[str]) -> List[List[str]]:
        # Approximate tokens by characters: one "token" per character.
        return [list(text) for text in texts]
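
# Illustration of the fallback (a sketch, not executed by this module): the
# dummy tokenizer's "tokens" are just characters, so the batch-size accounting
# below degrades to character counts when the real tokenizer is unavailable.
#
#     >>> DummyTokenizer().encode_batch(["hi", "there"])
#     [['h', 'i'], ['t', 'h', 'e', 'r', 'e']]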
class MistralAIEmbeddings(BaseModel, Embeddings):
    """MistralAI embedding model integration.

    Setup:
        Install ``langchain_mistralai`` and set environment variable
        ``MISTRAL_API_KEY``.

        .. code-block:: bash

            pip install -U langchain_mistralai
            export MISTRAL_API_KEY="your-api-key"

    Key init args — completion params:
        model: str
            Name of MistralAI model to use.

    Key init args — client params:
        api_key: Optional[SecretStr]
            The API key for the MistralAI API. If not provided, it will be
            read from the environment variable ``MISTRAL_API_KEY``.
        max_retries: int
            The number of times to retry a request if it fails.
        timeout: int
            The number of seconds to wait for a response before timing out.
        max_concurrent_requests: int
            The maximum number of concurrent requests to make to the
            Mistral API.

    See full list of supported init args and their descriptions in the params
    section.

    Instantiate:
        .. code-block:: python

            from langchain_mistralai import MistralAIEmbeddings

            embed = MistralAIEmbeddings(
                model="mistral-embed",
                # api_key="...",
                # other params...
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Async:
        .. code-block:: python

            vector = await embed.aembed_query(input_text)
            print(vector[:3])

            # multiple:
            # await embed.aembed_documents(input_texts)

        .. code-block:: python

            [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
    """

    client: httpx.Client = Field(default=None)  #: :meta private:
    async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
    mistral_api_key: SecretStr = Field(
        alias="api_key",
        default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
    )
    endpoint: str = "https://api.mistral.ai/v1/"
    max_retries: int = 5
    timeout: int = 120
    max_concurrent_requests: int = 64
    tokenizer: Tokenizer = Field(default=None)
    model: str = "mistral-embed"

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate configuration."""
        api_key_str = values["mistral_api_key"].get_secret_value()
        # todo: handle retries
        if not values.get("client"):
            values["client"] = httpx.Client(
                base_url=values["endpoint"],
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=values["timeout"],
            )
        # todo: handle retries and max_concurrency
        if not values.get("async_client"):
            values["async_client"] = httpx.AsyncClient(
                base_url=values["endpoint"],
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=values["timeout"],
            )
        if values["tokenizer"] is None:
            try:
                values["tokenizer"] = Tokenizer.from_pretrained(
                    "mistralai/Mixtral-8x7B-v0.1"
                )
            except IOError:  # huggingface_hub GatedRepoError
                warnings.warn(
                    "Could not download mistral tokenizer from Huggingface for "
                    "calculating batch sizes. Set a Huggingface token via the "
                    "HF_TOKEN environment variable to download the real tokenizer. "
                    "Falling back to a dummy tokenizer that uses `len()`."
                )
                values["tokenizer"] = DummyTokenizer()
        return values

    def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
        """Split a list of texts into batches of less than 16k tokens
        for the Mistral API."""
        batch: List[str] = []
        batch_tokens = 0

        text_token_lengths = [
            len(encoded) for encoded in self.tokenizer.encode_batch(texts)
        ]

        for text, text_tokens in zip(texts, text_token_lengths):
            if batch_tokens + text_tokens > MAX_TOKENS:
                if len(batch) > 0:
                    # edge case where first batch exceeds max tokens
                    # should not yield an empty batch.
                    yield batch
                batch = [text]
                batch_tokens = text_tokens
            else:
                batch.append(text)
                batch_tokens += text_tokens
        if batch:
            yield batch
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            # Note: this is a generator, so each POST is only sent as the
            # list comprehension below consumes it, one batch at a time.
            batch_responses = (
                self.client.post(
                    url="/embeddings",
                    json=dict(
                        model=self.model,
                        input=batch,
                    ),
                )
                for batch in self._get_batches(texts)
            )
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception as e:
            logger.error(f"An error occurred with MistralAI: {e}")
            raise
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            # Unlike the sync path, all batch requests are issued
            # concurrently and awaited together via asyncio.gather.
            batch_responses = await asyncio.gather(
                *[
                    self.async_client.post(
                        url="/embeddings",
                        json=dict(
                            model=self.model,
                            input=batch,
                        ),
                    )
                    for batch in self._get_batches(texts)
                ]
            )
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception as e:
            logger.error(f"An error occurred with MistralAI: {e}")
            raise
    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]
    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return (await self.aembed_documents([text]))[0]
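
# A minimal sketch of how the `# todo` notes in validate_environment could be
# addressed, assuming httpx's connect-retry transports and an
# asyncio.Semaphore (illustrative only; nothing below is wired into the
# class yet):
#
#     client = httpx.Client(
#         base_url="https://api.mistral.ai/v1/",
#         transport=httpx.HTTPTransport(retries=5),  # retries connect errors
#     )
#     async_client = httpx.AsyncClient(
#         base_url="https://api.mistral.ai/v1/",
#         transport=httpx.AsyncHTTPTransport(retries=5),
#     )
#     semaphore = asyncio.Semaphore(64)  # cap at max_concurrent_requests
#
#     async def _post_with_limit(payload: dict) -> httpx.Response:
#         async with semaphore:
#             return await async_client.post(url="/embeddings", json=payload)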