import asyncio
from functools import partial
from typing import Any, Iterator, List, Mapping, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    generate_from_stream,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import ConfigDict, model_validator
from typing_extensions import Self

from langchain_ai21.ai21_base import AI21Base
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter


class ChatAI21(BaseChatModel, AI21Base):
    """ChatAI21 chat model.

    Different model types support different parameters and different parameter
    values. Please read the
    [AI21 reference documentation](https://docs.ai21.com/reference)
    for your model to understand which parameters are available.

    Example:
        .. code-block:: python

            from langchain_ai21 import ChatAI21

            model = ChatAI21(
                # defaults to os.environ.get("AI21_API_KEY")
                api_key="my_api_key"
            )
    """

    model: str
    """Model type you wish to interact with. You can view the options at
    https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    stop: Optional[List[str]] = None
    """Default stop sequences."""

    max_tokens: int = 512
    """The maximum number of tokens to generate for each response."""

    temperature: float = 0.4
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    n: int = 1
    """Number of chat completions to generate for each prompt."""

    streaming: bool = False

    _chat_adapter: ChatAdapter

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate the environment."""
        model = self.model
        self._chat_adapter = create_chat_adapter(model)
        return self

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-ai21"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        base_params = {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": self.n,
        }

        if self.stop:
            base_params["stop_sequences"] = self.stop

        return base_params

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="ai21",
            ls_model_name=self.model,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _build_params_for_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Mapping[str, Any]:
        """Merge converted messages, default params, stop sequences, and kwargs."""
        params = {}
        converted_messages = self._chat_adapter.convert_messages(messages)

        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop

        return {
            **converted_messages,
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Streaming can be requested per call (stream=True) or as the
        # model-wide default (self.streaming).
        should_stream = stream or self.streaming

        if should_stream:
            return self._handle_stream_from_generate(
                messages=messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )

        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=should_stream,
            **kwargs,
        )

        messages = self._chat_adapter.call(self.client, **params)
        generations = [ChatGeneration(message=message) for message in messages]

        return ChatResult(generations=generations)

    def _handle_stream_from_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Collect a streamed response into a single ChatResult."""
        stream_iter = self._stream(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        return generate_from_stream(stream_iter)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=True,
            **kwargs,
        )

        for chunk in self._chat_adapter.call(self.client, **params):
            # Report each new token to the callback manager before yielding.
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # The adapter call is blocking, so run the synchronous _generate in
        # the default executor to avoid blocking the event loop.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), messages, stop, run_manager
        )

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )

        return message.content