class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings):
    """Cloudflare Workers AI embedding model.

    To use, you need to provide an API token and
    account ID to access Cloudflare Workers AI.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import CloudflareWorkersAIEmbeddings

            account_id = "my_account_id"
            api_token = "my_secret_api_token"
            model_name = "@cf/baai/bge-small-en-v1.5"

            cf = CloudflareWorkersAIEmbeddings(
                account_id=account_id,
                api_token=api_token,
                model_name=model_name
            )
    """

    api_base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    account_id: str
    api_token: str
    model_name: str = DEFAULT_MODEL_NAME
    # Maximum number of texts sent in one request by ``embed_documents``.
    batch_size: int = 50
    # When True, newlines in input texts are replaced by spaces before embedding.
    strip_new_lines: bool = True
    # Request headers; always rebuilt from ``api_token`` in ``__init__``.
    headers: Dict[str, str] = {"Authorization": "Bearer "}

    def __init__(self, **kwargs: Any):
        """Initialize the Cloudflare Workers AI client."""
        super().__init__(**kwargs)
        # Overwrites any caller-supplied ``headers`` with the bearer-token header.
        self.headers = {"Authorization": f"Bearer {self.api_token}"}

    model_config = ConfigDict(
        extra="forbid",
        # Allow the ``model_name`` field despite pydantic's reserved
        # ``model_`` attribute namespace.
        protected_namespaces=(),
    )

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Send a single embedding request for ``texts`` and return the vectors.

        Args:
            texts: Texts to embed in one API call (already preprocessed).

        Returns:
            The ``result.data`` list from the JSON response, one entry per text.

        Raises:
            requests.HTTPError: If the API responds with an error status code.
        """
        response = requests.post(
            f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}",
            headers=self.headers,
            json={"text": texts},
        )
        # Surface HTTP failures directly instead of failing later with an
        # unhelpful KeyError/TypeError while indexing into the JSON body.
        response.raise_for_status()
        return response.json()["result"]["data"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using Cloudflare Workers AI.

        Texts are sent in consecutive batches of at most ``batch_size`` per
        request; the returned embeddings preserve the input order.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
            requests.HTTPError: If any request returns an error status code.
        """
        if self.batch_size < 1:
            # A zero step crashes ``range``; a negative step yields no batches
            # and would silently return no embeddings at all.
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size}"
            )
        if self.strip_new_lines:
            texts = [text.replace("\n", " ") for text in texts]
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            embeddings.extend(self._embed_batch(batch))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using Cloudflare Workers AI.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.

        Raises:
            requests.HTTPError: If the request returns an error status code.
        """
        if self.strip_new_lines:
            text = text.replace("\n", " ")
        return self._embed_batch([text])[0]