class ChatSnowflakeCortexError(Exception):
    """Error raised by the Snowpark client integration.

    Used for failures while creating the Snowpark session and for failures
    while sending a request to Snowflake Cortex.
    """
def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
    """Convert a LangChain message to the dict sent to Snowflake Cortex.

    Args:
        message: The LangChain message.

    Returns:
        A dict with the message ``content`` and its ``role``.

    Raises:
        TypeError: If the message cannot be mapped to a role (for example an
            unrecognized message type).
    """
    message_dict: Dict[str, Any] = {
        "content": message.content,
    }

    # populate role and additional message data
    if isinstance(message, ChatMessage) and message.role in SUPPORTED_ROLES:
        # Generic chat messages keep their own role when it is one we accept.
        message_dict["role"] = message.role
    elif isinstance(message, SystemMessage):
        message_dict["role"] = "system"
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
    else:
        raise TypeError(f"Got unknown type {message}")
    return message_dict


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncate ``text`` at the earliest occurrence of any stop token.

    Every token is searched for in the *original* text, and the cut is made
    at the smallest match position. Truncating token by token can miss an
    earlier stop token that overlaps a later token's match, so this approach
    is used instead. Empty tokens are ignored, because ``str.find("")``
    matches at index 0 and would erase the whole text.

    Args:
        text: The generated text.
        stop: Stop tokens, or ``None``/empty to disable truncation.

    Returns:
        The text up to (not including) the earliest stop token, or the
        unchanged text when no stop token occurs in it.
    """
    if not stop:
        return text

    cut = len(text)
    for stop_token in stop:
        if not stop_token:
            continue
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1 and stop_token_idx < cut:
            cut = stop_token_idx
    return text[:cut]
class ChatSnowflakeCortex(BaseChatModel):
    """Snowflake Cortex based Chat model

    To use you must have the ``snowflake-snowpark-python`` Python package
    installed and either:

        1. environment variables set with your snowflake credentials or
        2. directly passed in as kwargs to the ChatSnowflakeCortex constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatSnowflakeCortex
            chat = ChatSnowflakeCortex()
    """

    _sp_session: Any = None
    """Snowpark session object."""

    model: str = "snowflake-arctic"
    """Snowflake cortex hosted LLM model name, defaulted to `snowflake-arctic`.
        Refer to docs for more options."""

    cortex_function: str = "complete"
    """Cortex function to use, defaulted to `complete`.
        Refer to docs for more options."""

    temperature: float = 0.7
    """Model temperature. Value should be >= 0 and <= 1.0"""

    max_tokens: Optional[int] = None
    """The maximum number of output tokens in the response."""

    top_p: Optional[float] = None
    """top_p adjusts the number of choices for each predicted tokens based on
        cumulative probabilities. Value should be ranging between 0.0 and 1.0.
    """

    snowflake_username: Optional[str] = Field(default=None, alias="username")
    """Automatically inferred from env var `SNOWFLAKE_USERNAME` if not provided."""
    snowflake_password: Optional[SecretStr] = Field(default=None, alias="password")
    """Automatically inferred from env var `SNOWFLAKE_PASSWORD` if not provided."""
    snowflake_account: Optional[str] = Field(default=None, alias="account")
    """Automatically inferred from env var `SNOWFLAKE_ACCOUNT` if not provided."""
    snowflake_database: Optional[str] = Field(default=None, alias="database")
    """Automatically inferred from env var `SNOWFLAKE_DATABASE` if not provided."""
    snowflake_schema: Optional[str] = Field(default=None, alias="schema")
    """Automatically inferred from env var `SNOWFLAKE_SCHEMA` if not provided."""
    snowflake_warehouse: Optional[str] = Field(default=None, alias="warehouse")
    """Automatically inferred from env var `SNOWFLAKE_WAREHOUSE` if not provided."""
    snowflake_role: Optional[str] = Field(default=None, alias="role")
    """Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        # NOTE(review): no `model_kwargs` field is declared on this class and
        # `_generate` never reads it, so extra kwargs collected here appear to
        # be dropped -- confirm whether they are meant to reach Cortex.
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve Snowflake credentials and open a Snowpark session.

        Each ``snowflake_*`` value is taken from the constructor kwargs or,
        failing that, from the matching ``SNOWFLAKE_*`` environment variable.

        Raises:
            ImportError: If ``snowflake-snowpark-python`` is not installed.
            ChatSnowflakeCortexError: If the Snowpark session cannot be created.
        """
        try:
            from snowflake.snowpark import Session
        except ImportError as e:
            raise ImportError(
                "`snowflake-snowpark-python` package not found, please install it with "
                "`pip install snowflake-snowpark-python`"
            ) from e

        values["snowflake_username"] = get_from_dict_or_env(
            values, "snowflake_username", "SNOWFLAKE_USERNAME"
        )
        values["snowflake_password"] = convert_to_secret_str(
            get_from_dict_or_env(values, "snowflake_password", "SNOWFLAKE_PASSWORD")
        )
        values["snowflake_account"] = get_from_dict_or_env(
            values, "snowflake_account", "SNOWFLAKE_ACCOUNT"
        )
        values["snowflake_database"] = get_from_dict_or_env(
            values, "snowflake_database", "SNOWFLAKE_DATABASE"
        )
        values["snowflake_schema"] = get_from_dict_or_env(
            values, "snowflake_schema", "SNOWFLAKE_SCHEMA"
        )
        values["snowflake_warehouse"] = get_from_dict_or_env(
            values, "snowflake_warehouse", "SNOWFLAKE_WAREHOUSE"
        )
        values["snowflake_role"] = get_from_dict_or_env(
            values, "snowflake_role", "SNOWFLAKE_ROLE"
        )

        connection_params = {
            "account": values["snowflake_account"],
            "user": values["snowflake_username"],
            "password": values["snowflake_password"].get_secret_value(),
            "database": values["snowflake_database"],
            "schema": values["snowflake_schema"],
            "warehouse": values["snowflake_warehouse"],
            "role": values["snowflake_role"],
        }

        try:
            values["_sp_session"] = Session.builder.configs(connection_params).create()
        except Exception as e:
            raise ChatSnowflakeCortexError(f"Failed to create session: {e}") from e

        return values

    def __del__(self) -> None:
        # Close the Snowpark session if one was opened; getattr guards against
        # partially constructed instances.
        if getattr(self, "_sp_session", None) is not None:
            self._sp_session.close()

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return f"snowflake-cortex-{self.model}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send the chat history to Snowflake Cortex and wrap the reply.

        Args:
            messages: Chat history to send.
            stop: Optional stop tokens. The reply is truncated client-side at
                the earliest one found.
            run_manager: Unused.
            **kwargs: Unused.

        Returns:
            A ``ChatResult`` holding one ``AIMessage`` whose
            ``response_metadata`` is the ``usage`` section of the Cortex
            response.

        Raises:
            ChatSnowflakeCortexError: If ``cortex_function`` is not a plain
                identifier, the query fails, or no row is returned.
        """
        message_dicts = [_convert_message_to_dict(m) for m in messages]

        options: Dict[str, Any] = {"temperature": self.temperature}
        if self.top_p is not None:
            options["top_p"] = self.top_p
        if self.max_tokens is not None:
            options["max_tokens"] = self.max_tokens

        # The function name is part of the SQL text (identifiers cannot be bind
        # parameters), so only plain identifiers are allowed. This keeps quotes,
        # semicolons and whitespace out of the statement.
        if not self.cortex_function.isidentifier():
            raise ChatSnowflakeCortexError(
                f"Invalid Snowflake Cortex function name: {self.cortex_function!r}"
            )

        # Model name, chat history and options are passed as qmark bind
        # parameters instead of being formatted into the SQL. Message content
        # is arbitrary user text: interpolating its Python repr allowed SQL
        # injection and broke on content containing quotes.
        sql_stmt = f"""
            select snowflake.cortex.{self.cortex_function}(
                ?,
                parse_json(?),
                parse_json(?)
            ) as llm_response;"""
        params = [self.model, json.dumps(message_dicts), json.dumps(options)]

        try:
            l_rows = self._sp_session.sql(sql_stmt, params=params).collect()
        except Exception as e:
            raise ChatSnowflakeCortexError(
                f"Error while making request to Snowflake Cortex via Snowpark: {e}"
            ) from e

        if not l_rows:
            raise ChatSnowflakeCortexError(
                "Snowflake Cortex returned no rows for the request"
            )

        response = json.loads(l_rows[0]["LLM_RESPONSE"])
        ai_message_content = response["choices"][0]["messages"]

        content = _truncate_at_stop_tokens(ai_message_content, stop)
        message = AIMessage(
            content=content,
            response_metadata=response["usage"],
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])