"""RWKV models.Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py"""fromtypingimportAny,Dict,List,Mapping,Optional,Setfromlangchain_core.callbacksimportCallbackManagerForLLMRunfromlangchain_core.language_models.llmsimportLLMfromlangchain_core.utilsimportpre_initfrompydanticimportBaseModel,ConfigDictfromlangchain_community.llms.utilsimportenforce_stop_tokens
class RWKV(LLM, BaseModel):
    """RWKV language models.

    To use, you should have the ``rwkv`` python package installed, the
    pre-trained model file, and the model's config information.

    Example:
        .. code-block:: python

            from langchain_community.llms import RWKV
            model = RWKV(model="./models/rwkv-3b-fp16.bin", strategy="cpu fp32")

            # Simplest invocation
            response = model.invoke("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained RWKV model file."""

    tokens_path: str
    """Path to the RWKV tokens file."""

    strategy: str = "cpu fp32"
    """Device/precision strategy string forwarded to the rwkv client
    (e.g. ``"cpu fp32"``, ``"cuda fp16"``)."""

    rwkv_verbose: bool = True
    """Print debug information."""

    temperature: float = 1.0
    """The temperature to use for sampling."""

    top_p: float = 0.5
    """The top-p value to use for sampling."""

    penalty_alpha_frequency: float = 0.4
    """Positive values penalize new tokens based on their existing frequency
    in the text so far, decreasing the model's likelihood to repeat the same
    line verbatim."""

    penalty_alpha_presence: float = 0.4
    """Positive values penalize new tokens based on whether they appear in
    the text so far, increasing the model's likelihood to talk about new
    topics."""

    CHUNK_LEN: int = 256
    """Batch size for prompt processing."""

    max_tokens_per_generation: int = 256
    """Maximum number of tokens to generate."""

    client: Any = None  #: :meta private:
    tokenizer: Any = None  #: :meta private:
    pipeline: Any = None  #: :meta private:
    model_tokens: Any = None  #: :meta private:
    model_state: Any = None  #: :meta private:

    model_config = ConfigDict(
        extra="forbid",
    )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default generation parameters."""
        return {
            "verbose": self.verbose,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_alpha_frequency": self.penalty_alpha_frequency,
            "penalty_alpha_presence": self.penalty_alpha_presence,
            "CHUNK_LEN": self.CHUNK_LEN,
            "max_tokens_per_generation": self.max_tokens_per_generation,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Get the names of kwargs forwarded to the rwkv model client."""
        return {
            "verbose",
        }
[docs]@pre_initdefvalidate_environment(cls,values:Dict)->Dict:"""Validate that the python package exists in the environment."""try:importtokenizersexceptImportError:raiseImportError("Could not import tokenizers python package. ""Please install it with `pip install tokenizers`.")try:fromrwkv.modelimportRWKVasRWKVMODELfromrwkv.utilsimportPIPELINEvalues["tokenizer"]=tokenizers.Tokenizer.from_file(values["tokens_path"])rwkv_keys=cls._rwkv_param_names()model_kwargs={k:vfork,vinvalues.items()ifkinrwkv_keys}model_kwargs["verbose"]=values["rwkv_verbose"]values["client"]=RWKVMODEL(values["model"],strategy=values["strategy"],**model_kwargs)values["pipeline"]=PIPELINE(values["client"],values["tokens_path"])exceptImportError:raiseImportError("Could not import rwkv python package. ""Please install it with `pip install rwkv`.")returnvalues
@propertydef_identifying_params(self)->Mapping[str,Any]:"""Get the identifying parameters."""return{"model":self.model,**self._default_params,**{k:vfork,vinself.__dict__.items()ifkinRWKV._rwkv_param_names()},}@propertydef_llm_type(self)->str:"""Return the type of llm."""return"rwkv"
def_call(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->str:r"""RWKV generation Args: prompt: The prompt to pass into the model. stop: A list of strings to stop generation when encountered. Returns: The string generated by the model. Example: .. code-block:: python prompt = "Once upon a time, " response = model.invoke(prompt, n_predict=55) """text=self.rwkv_generate(prompt)ifstopisnotNone:text=enforce_stop_tokens(text,stop)returntext