"""Wrapper around EdenAI's Generation API."""importloggingfromtypingimportAny,Dict,List,Literal,OptionalfromaiohttpimportClientSessionfromlangchain_core.callbacksimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_models.llmsimportLLMfromlangchain_core.utilsimportget_from_dict_or_env,pre_initfromlangchain_core.utils.pydanticimportget_fieldsfrompydanticimportConfigDict,Field,model_validatorfromlangchain_community.llms.utilsimportenforce_stop_tokensfromlangchain_community.utilities.requestsimportRequestslogger=logging.getLogger(__name__)
class EdenAI(LLM):
    """EdenAI models.

    To use, you should have the environment variable ``EDENAI_API_KEY`` set
    with your API token. You can find your token here:
    https://app.edenai.run/admin/account/settings

    `feature` and `subfeature` are required, but any other model parameters
    can also be passed in with the format
    params={model_param: value, ...}

    for api reference check edenai documentation:
    http://docs.edenai.co.
    """

    base_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[str] = None

    feature: Literal["text", "image"] = "text"
    """Which generative feature to use, use text by default"""

    subfeature: Literal["generation"] = "generation"
    """Subfeature of above feature, use generation by default"""

    provider: str
    """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""

    model: Optional[str] = None
    """
    model name for above provider (eg: 'gpt-3.5-turbo-instruct' for openai)
    available models are shown on https://docs.edenai.co/ under 'available providers'
    """

    # Optional parameters to add depending of chosen feature
    # see api reference for more infos
    temperature: Optional[float] = Field(default=None, ge=0, le=1)  # for text
    max_tokens: Optional[int] = Field(default=None, ge=0)  # for text
    resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None  # for image

    params: Dict[str, Any] = Field(default_factory=dict)
    """
    DEPRECATED: use temperature, max_tokens, resolution directly
    optional parameters to pass to api
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra parameters"""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    model_config = ConfigDict(
        extra="forbid",
    )
[docs]@pre_initdefvalidate_environment(cls,values:Dict)->Dict:"""Validate that api key exists in environment."""values["edenai_api_key"]=get_from_dict_or_env(values,"edenai_api_key","EDENAI_API_KEY")returnvalues
@model_validator(mode="before")@classmethoddefbuild_extra(cls,values:Dict[str,Any])->Any:"""Build extra kwargs from additional params that were passed in."""all_required_field_names={field.aliasforfieldinget_fields(cls).values()}extra=values.get("model_kwargs",{})forfield_nameinlist(values):iffield_namenotinall_required_field_names:iffield_nameinextra:raiseValueError(f"Found {field_name} supplied twice.")logger.warning(f"""{field_name} was transferred to model_kwargs. Please confirm that {field_name} is what you intended.""")extra[field_name]=values.pop(field_name)values["model_kwargs"]=extrareturnvalues@propertydef_llm_type(self)->str:"""Return type of model."""return"edenai"def_format_output(self,output:dict)->str:ifself.feature=="text":returnoutput[self.provider]["generated_text"]else:returnoutput[self.provider]["items"][0]["image"]
def_call(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->str:"""Call out to EdenAI's text generation endpoint. Args: prompt: The prompt to pass into the model. Returns: json formatted str response. """stops=Noneifself.stop_sequencesisnotNoneandstopisnotNone:raiseValueError("stop sequences found in both the input and default params.")elifself.stop_sequencesisnotNone:stops=self.stop_sequenceselse:stops=stopurl=f"{self.base_url}/{self.feature}/{self.subfeature}"headers={"Authorization":f"Bearer {self.edenai_api_key}","User-Agent":self.get_user_agent(),}payload:Dict[str,Any]={"providers":self.provider,"text":prompt,"max_tokens":self.max_tokens,"temperature":self.temperature,"resolution":self.resolution,**self.params,**kwargs,"num_images":1,# always limit to 1 (ignored for text)}# filter None values to not pass them to the http payloadpayload={k:vfork,vinpayload.items()ifvisnotNone}ifself.modelisnotNone:payload["settings"]={self.provider:self.model}request=Requests(headers=headers)response=request.post(url=url,data=payload)ifresponse.status_code>=500:raiseException(f"EdenAI Server: Error {response.status_code}")elifresponse.status_code>=400:raiseValueError(f"EdenAI received an invalid payload: {response.text}")elifresponse.status_code!=200:raiseException(f"EdenAI returned an unexpected response with status "f"{response.status_code}: {response.text}")data=response.json()provider_response=data[self.provider]ifprovider_response.get("status")=="fail":err_msg=provider_response.get("error",{}).get("message")raiseException(err_msg)output=self._format_output(data)ifstopsisnotNone:output=enforce_stop_tokens(output,stops)returnoutputasyncdef_acall(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->str:"""Call EdenAi model to get predictions based on the prompt. Args: prompt: The prompt to pass into the model. stop: A list of stop words (optional). 
run_manager: A callback manager for async interaction with LLMs. Returns: The string generated by the model. """stops=Noneifself.stop_sequencesisnotNoneandstopisnotNone:raiseValueError("stop sequences found in both the input and default params.")elifself.stop_sequencesisnotNone:stops=self.stop_sequenceselse:stops=stopurl=f"{self.base_url}/{self.feature}/{self.subfeature}"headers={"Authorization":f"Bearer {self.edenai_api_key}","User-Agent":self.get_user_agent(),}payload:Dict[str,Any]={"providers":self.provider,"text":prompt,"max_tokens":self.max_tokens,"temperature":self.temperature,"resolution":self.resolution,**self.params,**kwargs,"num_images":1,# always limit to 1 (ignored for text)}# filter `None` values to not pass them to the http payload as nullpayload={k:vfork,vinpayload.items()ifvisnotNone}ifself.modelisnotNone:payload["settings"]={self.provider:self.model}asyncwithClientSession()assession:asyncwithsession.post(url,json=payload,headers=headers)asresponse:ifresponse.status>=500:raiseException(f"EdenAI Server: Error {response.status}")elifresponse.status>=400:raiseValueError(f"EdenAI received an invalid payload: {response.text}")elifresponse.status!=200:raiseException(f"EdenAI returned an unexpected response with status "f"{response.status}: {response.text}")response_json=awaitresponse.json()provider_response=response_json[self.provider]ifprovider_response.get("status")=="fail":err_msg=provider_response.get("error",{}).get("message")raiseException(err_msg)output=self._format_output(response_json)ifstopsisnotNone:output=enforce_stop_tokens(output,stops)returnoutput