Source code for langchain_community.embeddings.baichuan
from typing import Any, List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.utils import (
    secret_from_env,
)
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)
from requests import RequestException
from typing_extensions import Self

# Endpoint that BaichuanTextEmbeddings POSTs embedding requests to.
BAICHUAN_API_URL: str = "https://api.baichuan-ai.com/v1/embeddings"

# BaichuanTextEmbeddings is an embedding model provided by Baichuan Inc.
# (https://www.baichuan-ai.com/home).
# As of Jan 25th, 2024, BaichuanTextEmbeddings ranks #1 on the C-MTEB
# (Chinese Multi-Task Embedding Benchmark) leaderboard.
# Leaderboard (under Overall -> Chinese section):
#   https://huggingface.co/spaces/mteb/leaderboard
# Official website: https://platform.baichuan-ai.com/docs/text-Embedding
# An API key is required to use this embedding model. You can get one by
# registering at https://platform.baichuan-ai.com/docs/text-Embedding.
# BaichuanTextEmbeddings supports a 512-token context window and produces
# 1024-dimensional vectors.
# NOTE: BaichuanTextEmbeddings only supports Chinese text embedding (as of the
# date above); multi-language support is coming soon.
class BaichuanTextEmbeddings(BaseModel, Embeddings):
    """Baichuan Text Embedding models.

    Setup:
        To use, you should set the environment variable ``BAICHUAN_API_KEY`` to
        your API key or pass it as a named parameter to the constructor.

        .. code-block:: bash

            export BAICHUAN_API_KEY="your-api-key"

    Instantiate:
        .. code-block:: python

            from langchain_community.embeddings import BaichuanTextEmbeddings

            embeddings = BaichuanTextEmbeddings()

    Embed:
        .. code-block:: python

            # embed the documents
            vectors = embeddings.embed_documents([text1, text2, ...])

            # embed the query
            vector = embeddings.embed_query(text)
    """  # noqa: E501

    # HTTP session built by ``validate_environment``; carries the auth headers.
    session: Any = None  #: :meta private:
    model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
    """The model used to embed the documents."""
    baichuan_api_key: SecretStr = Field(
        alias="api_key",
        default_factory=secret_from_env(["BAICHUAN_API_KEY", "BAICHUAN_AUTH_TOKEN"]),
    )
    """Read from env var `BAICHUAN_API_KEY` (or `BAICHUAN_AUTH_TOKEN`) if not
    provided."""
    chunk_size: int = 16
    """Maximum number of texts sent to the API in a single request."""

    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Create the HTTP session used for all API calls.

        The session is pre-configured with a Bearer ``Authorization`` header
        built from ``baichuan_api_key`` plus JSON content headers.
        """
        session = requests.Session()
        session.headers.update(
            {
                "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
            }
        )
        self.session = session
        return self

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Call the Baichuan Embedding API and return one embedding per text.

        Texts are sent in batches of at most ``chunk_size`` per request.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of embeddings, one per input text, in the order of ``texts``.

        Raises:
            requests.HTTPError: If the API responds with a 4xx/5xx status code.
            requests.RequestException: If the API responds with any other
                status code than 200.
            ValueError: If the API returns a different number of embeddings
                than the number of texts sent in a request.
        """
        embed_results: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            chunk = texts[start : start + self.chunk_size]
            response = self.session.post(
                BAICHUAN_API_URL, json={"input": chunk, "model": self.model_name}
            )
            # Raises requests.HTTPError for status codes from 400 to 599.
            response.raise_for_status()
            # Any remaining non-200 status (e.g. other 1xx/2xx/3xx codes) is
            # also treated as a failed call.
            if response.status_code != 200:
                raise RequestException(
                    f"Error: Received status code {response.status_code} from "
                    "`BaichuanEmbedding` API"
                )
            data = response.json().get("data", [])
            # Order results by the per-item "index" field so they line up with
            # the texts in `chunk`.
            sorted_embeddings = sorted(data, key=lambda e: e.get("index", 0))
            # A short or missing "data" list would otherwise silently misalign
            # embeddings with their texts (and break embed_query's [0]).
            if len(sorted_embeddings) != len(chunk):
                raise ValueError(
                    f"`BaichuanEmbedding` API returned {len(sorted_embeddings)} "
                    f"embeddings for {len(chunk)} input texts"
                )
            embed_results.extend(item.get("embedding", []) for item in sorted_embeddings)
        return embed_results

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents.

        Args:
            texts: The list of texts to embed.

        Returns:
            A list of embeddings, one for each text, in input order.

        Raises:
            requests.RequestException: If an API call fails (see ``_embed``).
            ValueError: If the API returns the wrong number of embeddings.
        """
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding for the text.

        Raises:
            requests.RequestException: If the API call fails (see ``_embed``).
            ValueError: If the API returns no embedding for the text.
        """
        return self._embed([text])[0]