Source code for langchain_experimental.llms.rellm_decoder
"""Experimental implementation of RELLM wrapped LLM."""from__future__importannotationsfromtypingimportTYPE_CHECKING,Any,List,Optional,castfromlangchain_community.llms.huggingface_pipelineimportHuggingFacePipelinefromlangchain_community.llms.utilsimportenforce_stop_tokensfromlangchain_core.callbacks.managerimportCallbackManagerForLLMRunfrompydanticimportField,model_validatorifTYPE_CHECKING:importrellmfromregeximportPatternasRegexPatternelse:try:fromregeximportPatternasRegexPatternexceptImportError:pass
def import_rellm() -> rellm:
    """Lazily import the rellm package."""
    try:
        import rellm
    except ImportError:
        raise ImportError(
            "Could not import rellm python package. "
            "Please install it with `pip install rellm`."
        )
    return rellm
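As a quick illustration (not part of the module source): the helper either returns the module or fails fast with an actionable install hint, and the `complete_re` attribute it exposes is what `_call` below relies on.

# Illustrative only: exercising the lazy-import guard.
rellm = import_rellm()  # raises ImportError with an install hint if rellm is absent
assert hasattr(rellm, "complete_re")  # the constrained-decoding entry point used in RELLM._call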
class RELLM(HuggingFacePipeline):
    """RELLM wrapped LLM using HuggingFace Pipeline API."""

    regex: RegexPattern = Field(..., description="The structured format to complete.")
    max_new_tokens: int = Field(
        default=200, description="Maximum number of new tokens to generate."
    )

    @model_validator(mode="before")
    @classmethod
    def check_rellm_installation(cls, values: dict) -> Any:
        import_rellm()
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        rellm = import_rellm()
        from transformers import Text2TextGenerationPipeline

        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)

        text = rellm.complete_re(
            prompt,
            self.regex,
            tokenizer=pipeline.tokenizer,
            model=pipeline.model,
            max_new_tokens=self.max_new_tokens,
        )
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
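For context, a minimal usage sketch (not part of the module source): RELLM inherits the `pipeline` field from HuggingFacePipeline, so a locally constructed transformers pipeline can be passed in directly. The model name and regex pattern below are illustrative assumptions, not requirements of the class.

# Minimal usage sketch; "gpt2" and the pattern are illustrative choices.
import regex
from transformers import pipeline
from langchain_experimental.llms import RELLM

hf_pipeline = pipeline("text-generation", model="gpt2", max_new_tokens=200)
pattern = regex.compile(r'\{"name": "[a-zA-Z ]+"\}')  # decoding is constrained to match this shape

llm = RELLM(pipeline=hf_pipeline, regex=pattern, max_new_tokens=200)
print(llm.invoke('Return a JSON object like {"name": "..."}: '))

Note that `regex` (not the stdlib `re`) must be used to compile the pattern, since rellm and the `RegexPattern` annotation above both expect `regex.Pattern` objects.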