class CerebriumAI(LLM):
    """CerebriumAI large language models.

    To use, you should have the ``cerebrium`` python package installed.
    You should also have the environment variable ``CEREBRIUMAI_API_KEY``
    set with your API key or pass it as a named argument in the constructor.

    Any parameters that are valid to be passed to the call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.llms import CerebriumAI
            cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")

    """

    endpoint_url: str = ""
    """model endpoint to use"""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not
    explicitly specified."""

    cerebriumai_api_key: Optional[SecretStr] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = set(list(cls.model_fields.keys()))

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        cerebriumai_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
        )
        values["cerebriumai_api_key"] = cerebriumai_api_key
        return values
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"endpoint_url": self.endpoint_url},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "cerebriumai"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        headers: Dict = {
            "Authorization": cast(
                SecretStr, self.cerebriumai_api_key
            ).get_secret_value(),
            "Content-Type": "application/json",
        }
        params = self.model_kwargs or {}
        payload = {"prompt": prompt, **params, **kwargs}
        response = requests.post(self.endpoint_url, json=payload, headers=headers)
        if response.status_code == 200:
            data = response.json()
            text = data["result"]
            if stop is not None:
                # I believe this is required since the stop tokens
                # are not enforced by the model parameters
                text = enforce_stop_tokens(text, stop)
            return text
        else:
            response.raise_for_status()
        return ""
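
# --- Usage sketch (illustrative only, not part of the class above) ---
# A minimal sketch assuming a deployed Cerebrium endpoint; the endpoint URL and
# API key below are placeholders. It shows how an extra constructor kwarg such
# as ``max_length`` (a hypothetical model parameter) is moved into
# ``model_kwargs`` by ``build_extra`` (with a warning) and then merged into the
# request payload in ``_call``.
if __name__ == "__main__":
    llm = CerebriumAI(
        endpoint_url="https://example.com/my-cerebrium-endpoint",  # placeholder URL
        cerebriumai_api_key="my-api-key",  # or set the CEREBRIUMAI_API_KEY env var
        max_length=100,  # not a declared field, so it lands in model_kwargs
    )
    # ``invoke`` accepts an optional ``stop`` list; ``_call`` trims the response
    # with ``enforce_stop_tokens`` since the endpoint does not enforce stops itself.
    print(llm.invoke("Tell me a joke", stop=["\n\n"]))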