def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
    """Converts a list of messages to a chat prompt for Gemma."""
    messages: List[str] = []
    # Shortcut: a single human message is passed through as a bare prompt.
    if len(history) == 1:
        content = cast(str, history[0].content)
        if isinstance(history[0], SystemMessage):
            raise ValueError("Gemma currently doesn't support system message!")
        return content
    for message in history:
        content = cast(str, message.content)
        if isinstance(message, SystemMessage):
            raise ValueError("Gemma currently doesn't support system message!")
        elif isinstance(message, AIMessage):
            messages.append(MODEL_CHAT_TEMPLATE.format(prompt=content))
        elif isinstance(message, HumanMessage):
            messages.append(USER_CHAT_TEMPLATE.format(prompt=content))
        else:
            raise ValueError(f"Unexpected message with type {type(message)}")
    messages.append("<start_of_turn>model\n")
    return "".join(messages)
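
# Illustrative only (not called by the library): a minimal sketch of the prompt
# that gemma_messages_to_prompt assembles, assuming the module-level
# USER_CHAT_TEMPLATE / MODEL_CHAT_TEMPLATE constants follow the usual Gemma
# turn format ("<start_of_turn>user\n...<end_of_turn>\n", etc.). The message
# contents below are placeholders.
def _example_gemma_messages_to_prompt() -> None:
    history = [
        HumanMessage(content="Hi! Who are you?"),
        AIMessage(content="I'm Gemma."),
        HumanMessage(content="Nice to meet you."),
    ]
    prompt = gemma_messages_to_prompt(history)
    # Expected shape: one user/model block per message, followed by an open
    # "<start_of_turn>model\n" turn for the model to complete.
    print(prompt)
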
def _parse_gemma_chat_response(response: str) -> str:
    """Removes chat history from the response."""
    pattern = "<start_of_turn>model\n"
    pos = response.rfind(pattern)
    if pos == -1:
        return response
    text = response[(pos + len(pattern)) :]
    pos = text.find("<start_of_turn>user\n")
    if pos > 0:
        return text[:pos]
    return text


class _GemmaBase(BaseModel):
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    temperature: Optional[float] = None
    """The temperature to use for sampling."""
    top_p: Optional[float] = None
    """The top-p value to use for sampling."""
    top_k: Optional[int] = None
    """The top-k value to use for sampling."""

    model_config = ConfigDict(protected_namespaces=())

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma."""
        params = {
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        return {k: v for k, v in params.items()}

    def _get_params(self, **kwargs) -> Dict[str, Any]:
        params = {k: kwargs.get(k, v) for k, v in self._default_params.items()}
        return {k: v for k, v in params.items() if v is not None}
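
# Illustrative only: how _parse_gemma_chat_response strips the echoed history,
# keeping just the text after the last "<start_of_turn>model\n" marker and
# truncating any trailing user turn. The strings are placeholders.
def _example_parse_gemma_chat_response() -> None:
    raw = "<start_of_turn>user\nHi!<end_of_turn>\n<start_of_turn>model\nHello there."
    assert _parse_gemma_chat_response(raw) == "Hello there."

    # A follow-up user turn appended by the model gets cut off as well.
    raw += "<end_of_turn>\n<start_of_turn>user\nAnd you?"
    assert _parse_gemma_chat_response(raw) == "Hello there.<end_of_turn>\n"
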
class GemmaVertexAIModelGarden(VertexAIModelGarden):
    allowed_model_args: Optional[List[str]] = [
        "temperature",
        "top_p",
        "top_k",
        "max_tokens",
    ]

    @property
    def _llm_type(self) -> str:
        return "gemma_vertexai_model_garden"

    # Needed so that mypy doesn't flag missing aliased init args.
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
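
# Usage sketch (not executed at import): calling a Gemma model deployed through
# Vertex AI Model Garden. The project, location, and endpoint id below are
# placeholders; credentials are picked up from the environment as usual for
# Vertex AI clients.
def _example_gemma_vertexai_model_garden() -> None:
    llm = GemmaVertexAIModelGarden(
        endpoint_id="123456789",  # placeholder endpoint id
        project="my-project",  # placeholder project
        location="us-central1",  # placeholder region
    )
    print(llm.invoke("What is the meaning of life?"))
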
class GemmaChatVertexAIModelGarden(
    _GemmaBase, _BaseVertexAIModelGarden, BaseChatModel
):
    allowed_model_args: Optional[List[str]] = [
        "temperature",
        "top_p",
        "top_k",
        "max_tokens",
        "max_length",
    ]
    parse_response: bool = False
    """Whether to post-process the chat response and clean repetitions
    or multi-turn statements."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    model_config = ConfigDict(
        populate_by_name=True,
        protected_namespaces=(),
    )

    @property
    def _llm_type(self) -> str:
        return "gemma_vertexai_model_garden"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma."""
        # Support deployments that expect either max_tokens or max_length.
        params = super()._default_params
        params["max_length"] = self.max_tokens
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        request = self._get_params(**kwargs)
        request["prompt"] = gemma_messages_to_prompt(messages)
        output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
        text = output.predictions[0]
        if self.parse_response or kwargs.get("parse_response"):
            text = _parse_gemma_chat_response(text)
        if stop:
            text = enforce_stop_tokens(text, stop)
        generations = [
            ChatGeneration(
                message=AIMessage(content=text),
            )
        ]
        return ChatResult(generations=generations)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top-level async call."""
        request = self._get_params(**kwargs)
        request["prompt"] = gemma_messages_to_prompt(messages)
        output = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=[request]
        )
        text = output.predictions[0]
        if self.parse_response or kwargs.get("parse_response"):
            text = _parse_gemma_chat_response(text)
        if stop:
            text = enforce_stop_tokens(text, stop)
        generations = [
            ChatGeneration(
                message=AIMessage(content=text),
            )
        ]
        return ChatResult(generations=generations)
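
# Usage sketch (not executed at import): the chat counterpart of the class
# above. Endpoint details are placeholders; parse_response=True strips the
# echoed prompt from the endpoint's raw prediction via
# _parse_gemma_chat_response.
def _example_gemma_chat_vertexai_model_garden() -> None:
    chat = GemmaChatVertexAIModelGarden(
        endpoint_id="123456789",  # placeholder endpoint id
        project="my-project",  # placeholder project
        location="us-central1",  # placeholder region
        parse_response=True,
    )
    answer = chat.invoke([HumanMessage(content="Hi! Tell me a joke.")])
    print(answer.content)
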
class _GemmaLocalKaggleBase(_GemmaBase):
    """Local gemma model loaded from Kaggle."""

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    keras_backend: str = "jax"
    model_name: str = Field(default="gemma_2b_en", alias="model")
    """Gemma model name."""

    model_config = ConfigDict(
        populate_by_name=True,
    )

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that the keras-nlp library is installed and load the model."""
        try:
            os.environ["KERAS_BACKEND"] = self.keras_backend
            from keras_nlp.models import GemmaCausalLM  # type: ignore
        except ImportError:
            raise ImportError(
                "Could not import GemmaCausalLM library. "
                "Please install the GemmaCausalLM library to "
                "use this model: pip install keras-nlp keras>=3 kaggle"
            )
        self.client = GemmaCausalLM.from_preset(self.model_name)
        return self

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma."""
        params = {"max_length": self.max_tokens}
        return {k: v for k, v in params.items() if v is not None}

    def _get_params(self, **kwargs) -> Dict[str, Any]:
        mapping = {"max_tokens": "max_length"}
        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
        return {**self._default_params, **params}
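
# Note (illustrative): the Keras Gemma client expects ``max_length`` rather
# than ``max_tokens``; ``_get_params`` above performs that rename so either
# spelling reaches the same place. ``model_construct`` is used here only to
# build the object without triggering the Kaggle weight download, since
# pydantic validators are skipped.
def _example_kaggle_param_mapping() -> None:
    base = _GemmaLocalKaggleBase.model_construct(max_tokens=64)
    assert base._get_params() == {"max_length": 64}
    assert base._get_params(max_tokens=30) == {"max_length": 30}
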
class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
    """Local gemma model loaded from Kaggle."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Only needed for typing."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts."""
        params = self._get_params(**kwargs)
        results = self.client.generate(prompts, **params)
        results = [results] if isinstance(results, str) else results
        if stop:
            results = [enforce_stop_tokens(text, stop) for text in results]
        return LLMResult(generations=[[Generation(text=result)] for result in results])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_kaggle"
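
# Usage sketch (not executed at import): running the Keras/Kaggle-backed Gemma
# locally. Instantiation downloads the preset weights, so Kaggle credentials
# must be configured; the model name and prompt are placeholders.
def _example_gemma_local_kaggle() -> None:
    llm = GemmaLocalKaggle(model_name="gemma_2b_en", keras_backend="jax")
    print(llm.invoke("What is the meaning of life?", max_tokens=30))
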
class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
    parse_response: bool = False
    """Whether to post-process the chat response and clean repetitions
    or multi-turn statements."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._get_params(**kwargs)
        prompt = gemma_messages_to_prompt(messages)
        text = self.client.generate(prompt, **params)
        if self.parse_response or kwargs.get("parse_response"):
            text = _parse_gemma_chat_response(text)
        if stop:
            text = enforce_stop_tokens(text, stop)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_chat_kaggle"
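
# Usage sketch (not executed at import): the chat variant of the Kaggle-backed
# model. parse_response=True removes the prompt that the local generation
# echoes back; the message content is a placeholder.
def _example_gemma_chat_local_kaggle() -> None:
    chat = GemmaChatLocalKaggle(
        model_name="gemma_2b_en",
        parse_response=True,
    )
    answer = chat.invoke([HumanMessage(content="Hi! Tell me a joke.")], max_tokens=30)
    print(answer.content)
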
class _GemmaLocalHFBase(_GemmaBase):
    """Local gemma model loaded from HuggingFace."""

    tokenizer: Any = None  #: :meta private:
    client: Any = Field(default=None, exclude=True)  #: :meta private:
    hf_access_token: str
    cache_dir: Optional[str] = None
    model_name: str = Field(default="google/gemma-2b", alias="model")
    """Gemma model name."""

    model_config = ConfigDict(
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that the transformers library is installed and load the model."""
        try:
            from transformers import AutoTokenizer, GemmaForCausalLM  # type: ignore
        except ImportError:
            raise ImportError(
                "Could not import GemmaForCausalLM library. "
                "Please install the GemmaForCausalLM library to "
                "use this model: pip install transformers>=4.38.1"
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, token=self.hf_access_token
        )
        self.client = GemmaForCausalLM.from_pretrained(
            self.model_name,
            token=self.hf_access_token,
            cache_dir=self.cache_dir,
        )
        return self

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma."""
        params = {"max_length": self.max_tokens}
        return {k: v for k, v in params.items() if v is not None}

    def _get_params(self, **kwargs) -> Dict[str, Any]:
        mapping = {"max_tokens": "max_length"}
        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
        return {**self._default_params, **params}

    def _run(self, prompt: str, **kwargs: Any) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt")
        params = self._get_params(**kwargs)
        generate_ids = self.client.generate(inputs.input_ids, **params)
        return self.tokenizer.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM):
    """Local gemma model loaded from HuggingFace."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts."""
        results = [self._run(prompt, **kwargs) for prompt in prompts]
        if stop:
            results = [enforce_stop_tokens(text, stop) for text in results]
        return LLMResult(generations=[[Generation(text=text)] for text in results])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_hf"
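
# Usage sketch (not executed at import): the HuggingFace-backed local model.
# Instantiation downloads the weights, so a HuggingFace access token with
# access to the Gemma checkpoints is required; the token string is a
# placeholder.
def _example_gemma_local_hf() -> None:
    llm = GemmaLocalHF(
        model_name="google/gemma-2b",
        hf_access_token="hf_...",  # placeholder token
    )
    print(llm.invoke("What is the meaning of life?", max_tokens=50))
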
class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
    parse_response: bool = False
    """Whether to post-process the chat response and clean repetitions
    or multi-turn statements."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = gemma_messages_to_prompt(messages)
        text = self._run(prompt, **kwargs)
        if self.parse_response or kwargs.get("parse_response"):
            text = _parse_gemma_chat_response(text)
        if stop:
            text = enforce_stop_tokens(text, stop)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_chat_hf"
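
# Usage sketch (not executed at import): the HuggingFace-backed chat model.
# As above, the access token is a placeholder; parse_response=True trims the
# echoed chat history from the decoded output.
def _example_gemma_chat_local_hf() -> None:
    chat = GemmaChatLocalHF(
        model_name="google/gemma-2b",
        hf_access_token="hf_...",  # placeholder token
        parse_response=True,
    )
    answer = chat.invoke([HumanMessage(content="Hi! Tell me a joke.")], max_tokens=50)
    print(answer.content)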