class Replicate(LLM):
    """Replicate models.

    To use, you should have
    the ``replicate`` python package installed,
    and the environment variable ``REPLICATE_API_TOKEN`` set with your API token.
    You can find your token here: https://replicate.com/account

    The model param is required, but any other model parameters can also
    be passed in with the format model_kwargs={model_param: value, ...}

    Example:
        .. code-block:: python

            from langchain_community.llms import Replicate

            replicate = Replicate(
                model=(
                    "stability-ai/stable-diffusion: "
                    "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
                ),
                model_kwargs={"image_dimensions": "512x512"}
            )
    """

    # Model identifier, either "owner/name" or "owner/name:version".
    model: str
    # Extra inputs forwarded to the Replicate prediction; accepted under the
    # legacy alias ``input`` as well (see build_extra below).
    model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
    # API token; if None it is resolved from REPLICATE_API_TOKEN at init time.
    replicate_api_token: Optional[str] = None
    # Name of the model input that receives the prompt; auto-detected from the
    # model's OpenAPI schema when left as None.
    prompt_key: Optional[str] = None
    version_obj: Any = Field(default=None, exclude=True)
    """Optionally pass in the model version object during initialization to avoid
        having to make an extra API call to retrieve it during streaming. NOTE: not
        serializable, is excluded from serialization.
    """

    streaming: bool = False
    """Whether to stream the results."""

    stop: List[str] = Field(default_factory=list)
    """Stop sequences to early-terminate generation."""

    model_config = ConfigDict(
        populate_by_name=True,
        extra="forbid",
    )

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Maps the secret-bearing field to its environment variable for
        # serialization, so the token itself is never written out.
        return {"replicate_api_token": "REPLICATE_API_TOKEN"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "replicate"]

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        # Names of declared fields; anything else is folded into model_kwargs.
        all_required_field_names = {field for field in get_fields(cls).keys()}
        # Legacy `input` param is merged into model_kwargs with a deprecation
        # warning; `input` values win over model_kwargs on key collision.
        input = values.pop("input", {})
        if input:
            logger.warning(
                "Init param `input` is deprecated, please use `model_kwargs` instead."
            )
        extra = {**values.pop("model_kwargs", {}), **input}
        # Transfer unknown constructor kwargs into model_kwargs, rejecting
        # duplicates so a param cannot be supplied both ways silently.
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values
[docs]@pre_initdefvalidate_environment(cls,values:Dict)->Dict:"""Validate that api key and python package exists in environment."""replicate_api_token=get_from_dict_or_env(values,"replicate_api_token","REPLICATE_API_TOKEN")values["replicate_api_token"]=replicate_api_tokenreturnvalues
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "model_kwargs": self.model_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "replicate"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to replicate endpoint.

        When ``self.streaming`` is set, the text is accumulated from
        ``_stream``; otherwise a prediction is created and awaited. In both
        cases the result is truncated at the first stop sequence found.
        """
        if self.streaming:
            completion: Optional[str] = None
            for chunk in self._stream(
                prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                if completion is None:
                    completion = chunk.text
                else:
                    completion += chunk.text
        else:
            prediction = self._create_prediction(prompt, **kwargs)
            # Block until the remote prediction finishes.
            prediction.wait()
            if prediction.status == "failed":
                raise RuntimeError(prediction.error)
            # Output may be a single string or an iterable of string pieces.
            if isinstance(prediction.output, str):
                completion = prediction.output
            else:
                completion = "".join(prediction.output)
        # NOTE(review): if the streaming branch yields zero chunks, completion
        # stays None and this assert fires (and is stripped under -O) — an
        # explicit check or "" default may be preferable; confirm intent.
        assert completion is not None
        # Truncate at the earliest occurrence of any stop sequence. The
        # explicit `stop` argument takes precedence over the configured one.
        stop_conditions = stop or self.stop
        for s in stop_conditions:
            if s in completion:
                completion = completion[: completion.find(s)]
        return completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generation chunks, cancelling the prediction at a stop sequence."""
        prediction = self._create_prediction(prompt, **kwargs)
        stop_conditions = stop or self.stop
        stop_condition_reached = False
        # Running transcript of everything streamed so far, used to detect
        # stop sequences that may span multiple output chunks.
        current_completion: str = ""
        for output in prediction.output_iterator():
            current_completion += output

            # test for stop conditions, if specified
            for s in stop_conditions:
                if s in current_completion:
                    # Stop the remote prediction as soon as a stop sequence
                    # appears anywhere in the accumulated text.
                    prediction.cancel()
                    stop_condition_reached = True
                    # Potentially some tokens that should still be yielded before ending
                    # stream.
                    # NOTE(review): when the stop sequence straddles a chunk
                    # boundary, output.find(s) is -1, so max(..., 0) truncates
                    # this chunk to "" while part of the stop string was
                    # already yielded in earlier chunks — confirm acceptable.
                    stop_index = max(output.find(s), 0)
                    output = output[:stop_index]
                    if not output:
                        break
            if output:
                if run_manager:
                    run_manager.on_llm_new_token(
                        output,
                        verbose=self.verbose,
                    )
                yield GenerationChunk(text=output)
            if stop_condition_reached:
                break

    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
        """Create a Replicate prediction for ``prompt``.

        Lazily resolves and caches the model version object and the name of
        the prompt input field on first use (mutates ``self.version_obj`` and
        ``self.prompt_key``).
        """
        try:
            import replicate as replicate_python
        except ImportError:
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )

        # get the model and version
        if self.version_obj is None:
            if ":" in self.model:
                # "owner/name:version" — fetch that specific version.
                model_str, version_str = self.model.split(":")
                model = replicate_python.models.get(model_str)
                self.version_obj = model.versions.get(version_str)
            else:
                # Bare "owner/name" — use the latest published version.
                model = replicate_python.models.get(self.model)
                self.version_obj = model.latest_version

        if self.prompt_key is None:
            # sort through the openapi schema to get the name of the first input
            input_properties = sorted(
                self.version_obj.openapi_schema["components"]["schemas"]["Input"][
                    "properties"
                ].items(),
                key=lambda item: item[1].get("x-order", 0),
            )

            self.prompt_key = input_properties[0][0]

        # Explicit kwargs override model_kwargs, which override the prompt key.
        input_: Dict = {
            self.prompt_key: prompt,
            **self.model_kwargs,
            **kwargs,
        }

        # if it's an official model
        if ":" not in self.model:
            return replicate_python.models.predictions.create(self.model, input=input_)
        else:
            return replicate_python.predictions.create(
                version=self.version_obj, input=input_
            )