class ChatAI21(BaseChatModel, AI21Base):
    """ChatAI21 chat model. Different model types support different parameters and
    different parameter values. Please read the
    [AI21 reference documentation](https://docs.ai21.com/reference) for your model
    to understand which parameters are available.

    Example:
        .. code-block:: python

            from langchain_ai21 import ChatAI21

            model = ChatAI21(
                # defaults to os.environ.get("AI21_API_KEY")
                api_key="my_api_key"
            )
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at
    https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
    num_results: int = 1
    """The number of responses to generate for a given prompt."""
    stop: Optional[List[str]] = None
    """Default stop sequences."""
    max_tokens: int = 512
    """The maximum number of tokens to generate for each response."""
    min_tokens: int = 0
    """The minimum number of tokens to generate for each response.
    _Not supported for all models._"""
    temperature: float = 0.4
    """A value controlling the "creativity" of the model's responses."""
    top_p: float = 1
    """A value controlling the diversity of the model's responses."""
    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step.
    _Not supported for all models._"""
    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated.
    _Not supported for all models._"""
    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt.
    _Not supported for all models._"""
    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses.
    _Not supported for all models._"""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    streaming: bool = False

    _chat_adapter: ChatAdapter

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment."""
        model = values["model"]
        values["_chat_adapter"] = create_chat_adapter(model)
        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-ai21"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
            "n": self.n,
        }

        if self.stop:
            base_params["stop_sequences"] = self.stop

        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()

        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()

        return base_params

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="ai21",
            ls_model_name=self.model,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _build_params_for_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Mapping[str, Any]:
        params = {}
        converted_messages = self._chat_adapter.convert_messages(messages)

        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop

        return {
            **converted_messages,
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream or self.streaming

        if should_stream:
            return self._handle_stream_from_generate(
                messages=messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )

        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=should_stream,
            **kwargs,
        )

        messages = self._chat_adapter.call(self.client, **params)
        generations = [ChatGeneration(message=message) for message in messages]
        return ChatResult(generations=generations)

    def _handle_stream_from_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        stream_iter = self._stream(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        return generate_from_stream(stream_iter)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=True,
            **kwargs,
        )

        for chunk in self._chat_adapter.call(self.client, **params):
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), messages, stop, run_manager
        )

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )

        return message.content