"""Wrapper around Google's PaLM Chat API."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import BaseModel, SecretStr
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
if TYPE_CHECKING:
    # Imported for type annotations only. With `from __future__ import
    # annotations`, references such as `genai.types.ChatResponse` stay
    # unevaluated strings; at runtime the package is imported lazily inside
    # the functions that actually call it.
    import google.generativeai as genai
# Module-level logger; also handed to tenacity's `before_sleep_log` so each
# retry attempt is logged.
logger = logging.getLogger(__name__)
[docs]
class ChatGooglePalmError(Exception):
    """Raised for failures in the `Google PaLM` chat integration.

    Used for malformed API responses, unsupported or mis-ordered input
    messages, and a missing ``google.generativeai`` package.
    """
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""Truncates text at the earliest stop token found."""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Convert a PaLM chat API response into a LangChain ``ChatResult``.

    Each candidate becomes one ``ChatGeneration``. Candidates authored by
    ``"ai"`` map to ``AIMessage``, ``"human"`` to ``HumanMessage``, and any
    other author to a ``ChatMessage`` using that author as its role. Content
    is truncated client-side at the earliest of the given stop sequences.

    Args:
        response: The response object returned by the client's ``chat`` /
            ``chat_async`` call; its ``candidates`` are read as dicts.
        stop: Optional stop sequences used to truncate each candidate's text.

    Returns:
        A ``ChatResult`` with one generation per candidate, in order.

    Raises:
        ChatGooglePalmError: If the response has no candidates, or a candidate
            has no author or has ``None`` content.
    """
    if not response.candidates:
        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")

    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        author = candidate.get("author")
        if author is None:
            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")

        # A missing "content" key is treated as empty text. An explicit None
        # is rejected *before* truncation; truncating first would fail with an
        # AttributeError on ``None.find`` whenever stop sequences are given.
        raw_content = candidate.get("content", "")
        if raw_content is None:
            raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")
        content = _truncate_at_stop_tokens(raw_content, stop)

        if author == "ai":
            message: BaseMessage = AIMessage(content=content)
        elif author == "human":
            message = HumanMessage(content=content)
        else:
            message = ChatMessage(role=author, content=content)
        generations.append(ChatGeneration(text=content, message=message))

    return ChatResult(generations=generations)
def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
    """Convert LangChain messages into a PaLM ``MessagePromptDict``.

    Mapping rules:

    * A ``SystemMessage`` becomes the prompt ``context``; it is only
      accepted as the very first message.
    * A ``HumanMessage`` with ``example=True`` must be immediately followed
      by an ``AIMessage`` with ``example=True``; the pair is appended to
      ``examples``. Example pairs must precede all regular messages.
    * ``AIMessage`` / ``HumanMessage`` become messages authored by ``"ai"`` /
      ``"human"``; a ``ChatMessage`` keeps its ``role`` as the author.

    Args:
        input_messages: The conversation to convert.

    Returns:
        A ``MessagePromptDict`` with ``context``, ``examples`` and ``messages``.

    Raises:
        ChatGooglePalmError: If the messages violate any rule above, or a
            message type without an explicit role is given.
    """
    import google.generativeai as genai

    context: str = ""
    examples: List[genai.types.MessageDict] = []
    messages: List[genai.types.MessageDict] = []

    # One shared iterator lets the example branch consume the paired AI
    # message in O(1), instead of repeatedly popping the head of a list.
    message_iter = iter(enumerate(input_messages))
    for index, input_message in message_iter:
        if isinstance(input_message, SystemMessage):
            if index != 0:
                raise ChatGooglePalmError(
                    "System message must be first input message."
                )
            context = cast(str, input_message.content)
        elif isinstance(input_message, HumanMessage) and input_message.example:
            if messages:
                raise ChatGooglePalmError(
                    "Message examples must come before other messages."
                )
            # The default avoids a bare IndexError/StopIteration when the
            # human example is the last message in the list.
            _, next_input_message = next(message_iter, (None, None))
            if not (
                isinstance(next_input_message, AIMessage)
                and next_input_message.example
            ):
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
            examples.extend(
                [
                    genai.types.MessageDict(
                        author="human", content=input_message.content
                    ),
                    genai.types.MessageDict(
                        author="ai", content=next_input_message.content
                    ),
                ]
            )
        elif isinstance(input_message, AIMessage) and input_message.example:
            raise ChatGooglePalmError(
                "AI example message must be immediately preceded by a Human "
                "example message."
            )
        elif isinstance(input_message, AIMessage):
            messages.append(
                genai.types.MessageDict(author="ai", content=input_message.content)
            )
        elif isinstance(input_message, HumanMessage):
            messages.append(
                genai.types.MessageDict(author="human", content=input_message.content)
            )
        elif isinstance(input_message, ChatMessage):
            messages.append(
                genai.types.MessageDict(
                    author=input_message.role, content=input_message.content
                )
            )
        else:
            raise ChatGooglePalmError(
                "Messages without an explicit role not supported by PaLM API."
            )

    return genai.types.MessagePromptDict(
        context=context,
        examples=examples,
        messages=messages,
    )
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Return a tenacity retry decorator configured for transient PaLM errors.

    The decorator makes at most 10 attempts, waiting with exponential backoff
    (multiplier 2, each wait clamped to 1-60 seconds) and logging a warning
    before each sleep. Once attempts are exhausted, the last exception is
    re-raised unchanged (``reraise=True``).

    Only errors that are plausibly transient are retried: rate limiting
    (``ResourceExhausted``), ``ServiceUnavailable``, ``InternalServerError``
    and ``DeadlineExceeded``. Client errors such as ``InvalidArgument``,
    ``PermissionDenied`` or ``NotFound`` propagate immediately; retrying the
    ``GoogleAPIError`` base class would keep failing requests (bad API key,
    bad model name) in backoff for minutes.
    """
    import google.api_core.exceptions as api_exceptions

    multiplier = 2
    min_seconds = 1
    max_seconds = 60
    max_retries = 10

    transient_errors = (
        api_exceptions.ResourceExhausted,
        api_exceptions.ServiceUnavailable,
        api_exceptions.InternalServerError,
        api_exceptions.DeadlineExceeded,
    )
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
        retry=retry_if_exception_type(transient_errors),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
[docs]
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Call ``llm.client.chat`` under the policy from ``_create_retry_decorator``.

    Args:
        llm: The chat model whose ``client`` performs the request.
        **kwargs: Forwarded unchanged to ``llm.client.chat``.

    Returns:
        Whatever ``llm.client.chat`` returns.
    """

    def _chat_with_retry(**call_kwargs: Any) -> Any:
        return llm.client.chat(**call_kwargs)

    retrying_call = _create_retry_decorator()(_chat_with_retry)
    return retrying_call(**kwargs)
[docs]
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Await ``llm.client.chat_async`` under the ``_create_retry_decorator`` policy.

    Args:
        llm: The chat model whose ``client`` performs the request.
        **kwargs: Forwarded unchanged to ``llm.client.chat_async``.

    Returns:
        The awaited result of ``llm.client.chat_async``.
    """

    async def _achat_with_retry(**call_kwargs: Any) -> Any:
        # Async variant of the client's chat call; awaited on every attempt.
        return await llm.client.chat_async(**call_kwargs)

    retrying_call = _create_retry_decorator()(_achat_with_retry)
    return await retrying_call(**kwargs)
[docs]
class ChatGooglePalm(BaseChatModel, BaseModel):
    """`Google PaLM` Chat models API.

    To use you must have the google.generativeai Python package installed and
    either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           ChatGooglePalm constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatGooglePalm
            chat = ChatGooglePalm()
    """

    client: Any  #: :meta private:
    model_name: str = "models/chat-bison-001"
    """Model name to use."""
    google_api_key: Optional[SecretStr] = None
    """API key; when unset, read from the ``GOOGLE_API_KEY`` environment variable."""
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed
       interval [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
       probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
       Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
       not return the full n completions if duplicates are generated."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret constructor argument to its environment variable."""
        return {"google_api_key": "GOOGLE_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this model supports LangChain serialization."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "google_palm"]

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k.

        Resolves the API key from ``values`` or the ``GOOGLE_API_KEY``
        environment variable, configures ``google.generativeai`` with it and
        stores that module as ``client``.

        Raises:
            ChatGooglePalmError: If ``google.generativeai`` cannot be imported.
            ValueError: If ``temperature`` or ``top_p`` is outside [0.0, 1.0],
                or ``top_k`` is not positive.
        """
        google_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
        )
        # Keep the try body to the import so only a missing package is
        # reported as such; chain the original error for debugging.
        try:
            import google.generativeai as genai
        except ImportError as e:
            raise ChatGooglePalmError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`"
            ) from e

        genai.configure(api_key=google_api_key.get_secret_value())
        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")
        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")
        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")
        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronously chat with PaLM; ``kwargs`` go to the client call."""
        prompt = _messages_to_prompt_dict(messages)
        response: genai.types.ChatResponse = chat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            **kwargs,
        )
        # Stop sequences are applied client-side to each candidate's text.
        return _response_to_result(response, stop)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously chat with PaLM; ``kwargs`` go to the client call.

        Extra keyword arguments are forwarded exactly as in ``_generate``, so
        per-call options behave the same for sync and async invocations.
        """
        prompt = _messages_to_prompt_dict(messages)
        response: genai.types.ChatResponse = await achat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            **kwargs,
        )
        # Stop sequences are applied client-side to each candidate's text.
        return _response_to_result(response, stop)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this chat model type."""
        return "google-palm-chat"