Source code for langchain_community.chat_models.google_palm
"""Wrapper around Google's PaLM Chat API."""from__future__importannotationsimportloggingfromtypingimportTYPE_CHECKING,Any,Callable,Dict,List,Optional,castfromlangchain_core.callbacksimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_models.chat_modelsimportBaseChatModelfromlangchain_core.messagesimport(AIMessage,BaseMessage,ChatMessage,HumanMessage,SystemMessage,)fromlangchain_core.outputsimport(ChatGeneration,ChatResult,)fromlangchain_core.pydantic_v1importBaseModel,SecretStrfromlangchain_core.utilsimportconvert_to_secret_str,get_from_dict_or_env,pre_initfromtenacityimport(before_sleep_log,retry,retry_if_exception_type,stop_after_attempt,wait_exponential,)ifTYPE_CHECKING:importgoogle.generativeaiasgenailogger=logging.getLogger(__name__)
class ChatGooglePalmError(Exception):
    """Error with the `Google PaLM` API."""
def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a PaLM API response into a LangChain ChatResult."""
    if not response.candidates:
        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")

    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        author = candidate.get("author")
        if author is None:
            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")

        content = _truncate_at_stop_tokens(candidate.get("content", ""), stop)
        if content is None:
            raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")

        if author == "ai":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif author == "human":
            generations.append(
                ChatGeneration(
                    text=content,
                    message=HumanMessage(content=content),
                )
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content,
                    message=ChatMessage(role=author, content=content),
                )
            )

    return ChatResult(generations=generations)


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
    """Converts a list of LangChain messages into a PaLM API MessagePrompt structure."""
    import google.generativeai as genai

    context: str = ""
    examples: List[genai.types.MessageDict] = []
    messages: List[genai.types.MessageDict] = []

    remaining = list(enumerate(input_messages))

    while remaining:
        index, input_message = remaining.pop(0)

        if isinstance(input_message, SystemMessage):
            if index != 0:
                raise ChatGooglePalmError(
                    "System message must be first input message."
                )
            context = cast(str, input_message.content)
        elif isinstance(input_message, HumanMessage) and input_message.example:
            if messages:
                raise ChatGooglePalmError(
                    "Message examples must come before other messages."
                )
            # Guard against a trailing example message with no paired response,
            # which would otherwise raise a bare IndexError from pop().
            if not remaining:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
            _, next_input_message = remaining.pop(0)
            if (
                isinstance(next_input_message, AIMessage)
                and next_input_message.example
            ):
                examples.extend(
                    [
                        genai.types.MessageDict(
                            author="human", content=input_message.content
                        ),
                        genai.types.MessageDict(
                            author="ai", content=next_input_message.content
                        ),
                    ]
                )
            else:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
        elif isinstance(input_message, AIMessage) and input_message.example:
            raise ChatGooglePalmError(
                "AI example message must be immediately preceded by a Human "
                "example message."
            )
        elif isinstance(input_message, AIMessage):
            messages.append(
                genai.types.MessageDict(author="ai", content=input_message.content)
            )
        elif isinstance(input_message, HumanMessage):
            messages.append(
                genai.types.MessageDict(author="human", content=input_message.content)
            )
        elif isinstance(input_message, ChatMessage):
            messages.append(
                genai.types.MessageDict(
                    author=input_message.role, content=input_message.content
                )
            )
        else:
            raise ChatGooglePalmError(
                "Messages without an explicit role not supported by PaLM API."
            )

    return genai.types.MessagePromptDict(
        context=context,
        examples=examples,
        messages=messages,
    )
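# Illustrative sketch (not part of the original module) of how the two helpers
# above behave, assuming genai.types.MessageDict acts like a plain dict:
#
#     _truncate_at_stop_tokens("Hello, world!", stop=[","])
#     # -> "Hello"
#
#     _messages_to_prompt_dict(
#         [
#             SystemMessage(content="You are a helpful assistant."),
#             HumanMessage(content="Hello!"),
#         ]
#     )
#     # -> {"context": "You are a helpful assistant.",
#     #     "examples": [],
#     #     "messages": [{"author": "human", "content": "Hello!"}]}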
exceptions"""importgoogle.api_core.exceptionsmultiplier=2min_seconds=1max_seconds=60max_retries=10returnretry(reraise=True,stop=stop_after_attempt(max_retries),wait=wait_exponential(multiplier=multiplier,min=min_seconds,max=max_seconds),retry=(retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)|retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)|retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)),before_sleep=before_sleep_log(logger,logging.WARNING),)
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    def _chat_with_retry(**kwargs: Any) -> Any:
        return llm.client.chat(**kwargs)

    return _chat_with_retry(**kwargs)
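# Hypothetical usage sketch: the wrapper forwards all keyword arguments to the
# underlying google.generativeai chat call under the retry policy above, e.g.
#
#     response = chat_with_retry(
#         llm,  # a ChatGooglePalm instance
#         model="models/chat-bison-001",
#         prompt=prompt,
#         candidate_count=1,
#     )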
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    async def _achat_with_retry(**kwargs: Any) -> Any:
        # Use the client's async chat endpoint.
        return await llm.client.chat_async(**kwargs)

    return await _achat_with_retry(**kwargs)
class ChatGooglePalm(BaseChatModel, BaseModel):
    """`Google PaLM` Chat models API.

    To use, you must have the ``google.generativeai`` Python package installed and
    either:

    1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
    2. Pass your API key using the ``google_api_key`` kwarg to the ChatGooglePalm
       constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatGooglePalm
            chat = ChatGooglePalm()

    """

    client: Any  #: :meta private:
    model_name: str = "models/chat-bison-001"
    """Model name to use."""
    google_api_key: Optional[SecretStr] = None
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "google_palm"]

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k."""
        google_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key.get_secret_value())
        except ImportError:
            raise ChatGooglePalmError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`"
            )

        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)
        response: genai.types.ChatResponse = chat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            **kwargs,
        )
        return _response_to_result(response, stop)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)
        response: genai.types.ChatResponse = await achat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
        )
        return _response_to_result(response, stop)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        return "google-palm-chat"
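# Minimal usage sketch (not part of the original module), assuming the
# google-generativeai package is installed and GOOGLE_API_KEY is set:
#
#     from langchain_core.messages import HumanMessage, SystemMessage
#
#     chat = ChatGooglePalm(temperature=0.7, n=1)
#     result = chat.invoke(
#         [
#             SystemMessage(content="You are a helpful assistant."),
#             HumanMessage(content="Say hello in French."),
#         ]
#     )
#     print(result.content)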