Source code for langchain_community.utilities.arcee
# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]

from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import requests
from langchain_core.retrievers import Document
from pydantic import BaseModel, SecretStr, model_validator
class ArceeRoute(str, Enum):
    """Enumeration of the Arcee API route paths used by this module.

    Each member's value is the path segment that follows the API version
    in a request URL.
    """

    # Text generation endpoint.
    generate = "models/generate"

    # Context retrieval endpoint.
    retrieve = "models/retrieve"

    # Status endpoint; the value contains an ``{id_or_name}`` placeholder
    # that the caller must substitute before use.
    model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
    """Enumeration of the filter kinds accepted for a DALM retrieval.

    Each member's value is identical to its name.
    """

    # Fuzzy match on the filtered field.
    fuzzy_search = "fuzzy_search"

    # Exact-substring match on the filtered field.
    strict_search = "strict_search"
class DALMFilter(BaseModel):
    """Filters available for a DALM retrieval and generation.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to filter
            on your document's raw text or title. Any other field will be presumed
            to be a metadata field you included when uploading your context data
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
            'fuzzy_search' means a fuzzy search on the provided field is performed.
            The exact string doesn't need to exist in the document
            for this to find a match.
            Very useful for scanning a document for some keyword terms.
            'strict_search' means that the exact string must appear
            in the provided field.
            This is NOT an exact eq filter. ie a document with content
            "the happy dog crossed the street" will match on a strict_search of
            "dog" but won't match on "the dog".
            Python equivalent of `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    _is_metadata: bool = False

    @model_validator(mode="before")
    @classmethod
    def set_meta(cls, values: Any) -> Any:
        """Tag the raw input with whether ``field_name`` is a metadata field.

        'document' and 'name' are reserved arcee keys. Anything else is
        metadata.

        Args:
            values: The raw, not-yet-validated input for the model.

        Returns:
            A new dict equal to ``values`` plus an ``"_is_meta"`` entry when
            ``values`` is a dict; otherwise ``values`` unchanged so that normal
            validation can accept or reject it.
        """
        # A "before" validator receives whatever the caller passed; only a
        # dict can be tagged, anything else is handed on untouched.
        if not isinstance(values, dict):
            return values

        # NOTE(review): the key written here is "_is_meta", while the private
        # attribute declared on this class is `_is_metadata`. Presumably one
        # of the two names is a leftover -- confirm which one the Arcee API
        # expects before renaming either.
        # Build a new dict rather than mutating the caller's mapping.
        return {
            **values,
            "_is_meta": values.get("field_name") not in ["document", "name"],
        }
class ArceeDocumentSource(BaseModel):
    """Source record attached to an Arcee document.

    All three fields are required strings.
    """

    # Text of the source; used as the page content when the enclosing
    # document is adapted.
    document: str
    # Name of the source.
    name: str
    # Identifier of the source.
    id: str
class ArceeDocumentAdapter:
    """Converts Arcee documents into langchain ``Document`` objects."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Build a langchain ``Document`` from an ``ArceeDocument``.

        Args:
            arcee_document: The Arcee document to convert.

        Returns:
            A ``Document`` whose page content is the source text and whose
            metadata carries the source name/id and the document's index,
            id and score.
        """
        source = arcee_document.source
        metadata = {
            # taken from the document's source
            "name": source.name,
            "source_id": source.id,
            # taken from the document itself
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=source.document, metadata=metadata)
class ArceeWrapper:
    """Wrapper for Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize ArceeWrapper.

        Looks up the training status of ``model_name`` and stores the
        returned ``model_id`` and ``status`` on the instance.

        Arguments:
            arcee_api_key: API key for Arcee API.
            arcee_api_url: URL for Arcee API.
            arcee_api_version: Version of Arcee API.
            model_kwargs: Keyword arguments for Arcee API.
            model_name: Name of an Arcee model.

        Raises:
            ValueError: If the status lookup fails for any reason; the
                original exception is chained as the cause.
        """
        # Normalise a plain string key to SecretStr.
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key
        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(
                id_or_name=model_name
            )
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise unless the stored training status is ``"training_complete"``.

        Raises:
            Exception: If ``self.model_training_status`` is anything other
                than ``"training_complete"``.
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API

        Args:
            method: The HTTP method to use
            route: The route to call
            body: The body of the request
            params: The query params of the request
            headers: The headers of the request

        Returns:
            The decoded JSON body of the response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # `method` names the module-level `requests` helper to call.
        req_type = getattr(requests, method)

        # NOTE(review): no `timeout` is passed, so an unresponsive server
        # blocks this call indefinitely -- consider adding one.
        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Return request headers: caller headers plus auth and content type.

        A new dict is returned; the ``headers`` argument is never modified,
        so the API token does not leak into a dict owned by the caller.
        ``X-Token`` and ``Content-Type`` override same-named caller entries.

        Raises:
            TypeError: If ``self.arcee_api_key`` is not a ``SecretStr``.
        """
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )
        return {
            **(headers or {}),
            "X-Token": self.arcee_api_key.get_secret_value(),
            "Content-Type": "application/json",
        }

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join base URL, API version and route into the request URL."""
        # Formatting an Enum member directly may render its qualified name
        # (e.g. "ArceeRoute.generate") rather than its value, depending on
        # the Python version, so take the value explicitly.
        path = route.value if isinstance(route, Enum) else route
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{path}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Make the request body for generate/retrieve models endpoint

        Per-call ``kwargs`` override ``self.model_kwargs``. Only the
        ``size`` (default 3) and ``filters`` (default none) entries are read.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        # Validate each filter through DALMFilter, then dump it back to
        # plain JSON-compatible data: the body is sent with `json=`, which
        # cannot encode model instances. `or []` tolerates `filters=None`.
        filters = [
            DALMFilter(**f).model_dump(mode="json")
            for f in (_params.get("filters") or [])
        ]
        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults to 3.
                (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The ``"text"`` entry of the API response.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve. Defaults to 3.
                (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One ``Document`` per entry of the response's ``"results"`` list.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]