class ChatWriter(BaseChatModel):
    """Writer chat model.

    To use, you should have the ``writer-sdk`` Python package installed, and the
    environment variable ``WRITER_API_KEY`` set with your API key or pass
    'api_key' init param.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatWriter

            chat = ChatWriter(
                api_key="your key",
                model="palmyra-x-004",
            )
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    api_key: Optional[SecretStr] = Field(default=None)
    """Writer API key."""
    model_name: str = Field(default="palmyra-x-004", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not
    explicitly specified."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    model_config = ConfigDict(populate_by_name=True)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "writer-chat"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Writer API."""
        return {
            "model": self.model_name,
            "temperature": self.temperature,
            "n": self.n,
            "max_tokens": self.max_tokens,
            **self.model_kwargs,
        }

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validates that api key is passed and creates Writer clients."""
        try:
            from writerai import AsyncClient, Client
        except ImportError as e:
            raise ImportError(
                "Could not import writerai python package. "
                "Please install it with `pip install writerai`."
            ) from e
        # Only resolve the API key when a client actually has to be built;
        # callers that supply both clients never need (or have) a key.
        if not values.get("client") or not values.get("async_client"):
            api_key = get_from_dict_or_env(values, "api_key", "WRITER_API_KEY")
            if not values.get("client"):
                values["client"] = Client(api_key=api_key)
            if not values.get("async_client"):
                values["async_client"] = AsyncClient(api_key=api_key)
        # isinstance (not `type(...) is`) so writerai client subclasses are
        # accepted too.
        if not (
            isinstance(values.get("client"), Client)
            and isinstance(values.get("async_client"), AsyncClient)
        ):
            raise ValueError(
                "'client' attribute must be with type 'Client' and "
                "'async_client' must be with type 'AsyncClient' "
                "from 'writerai' package"
            )
        return values

    def _create_chat_result(self, response: Any) -> ChatResult:
        """Build a ChatResult from a non-streaming Writer chat response."""
        generations = []
        for choice in response.choices:
            message = self._convert_writer_to_langchain(choice.message)
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=choice.finish_reason),
            )
            generations.append(gen)
        token_usage = {}
        if response.usage:
            token_usage = response.usage.__dict__
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.system_fingerprint,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    @staticmethod
    def _convert_langchain_to_writer(message: BaseMessage) -> dict:
        """Convert a LangChain message to a Writer message dict.

        Raises:
            ValueError: If the message type is not one of the known
                LangChain message classes.
        """
        message_dict: Dict[str, Any] = {"role": "", "content": message.content}
        if isinstance(message, ChatMessage):
            message_dict["role"] = message.role
        elif isinstance(message, HumanMessage):
            message_dict["role"] = "user"
        elif isinstance(message, AIMessage):
            message_dict["role"] = "assistant"
            if message.tool_calls:
                # Writer uses the OpenAI-style tool_calls wire format.
                message_dict["tool_calls"] = [
                    {
                        "id": tool["id"],
                        "type": "function",
                        "function": {
                            "name": tool["name"],
                            "arguments": tool["args"],
                        },
                    }
                    for tool in message.tool_calls
                ]
        elif isinstance(message, SystemMessage):
            message_dict["role"] = "system"
        elif isinstance(message, ToolMessage):
            message_dict["role"] = "tool"
            message_dict["tool_call_id"] = message.tool_call_id
        else:
            raise ValueError(f"Got unknown message type: {type(message)}")

        if message.name:
            message_dict["name"] = message.name

        return message_dict

    @staticmethod
    def _convert_writer_to_langchain(response_message: Any) -> BaseMessage:
        """Convert a Writer message to a LangChain message."""
        if not isinstance(response_message, dict):
            # Normalize SDK model objects to a plain dict via a JSON
            # round-trip (the SDK objects expose their fields on __dict__).
            response_message = json.loads(
                json.dumps(response_message, default=lambda o: o.__dict__)
            )
        role = response_message.get("role", "")
        content = response_message.get("content") or ""
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            additional_kwargs = {}
            if tool_calls := response_message.get("tool_calls", []):
                additional_kwargs["tool_calls"] = tool_calls
            # AIMessageChunk (not AIMessage) so the same converter can feed
            # ChatGenerationChunk in the streaming paths.
            return AIMessageChunk(
                content=content, additional_kwargs=additional_kwargs
            )
        elif role == "system":
            return SystemMessage(content=content)
        elif role == "tool":
            return ToolMessage(
                content=content,
                tool_call_id=response_message.get("tool_call_id", ""),
                name=response_message.get("name", ""),
            )
        else:
            return ChatMessage(content=content, role=role)

    def _convert_messages_to_writer(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert a list of LangChain messages to List of Writer dicts.

        Returns:
            A tuple of (message dicts, request params for the Writer API).
        """
        params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "n": self.n,
            **self.model_kwargs,
        }
        if stop:
            params["stop"] = stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens

        message_dicts = [self._convert_langchain_to_writer(m) for m in messages]
        return message_dicts, params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._convert_messages_to_writer(messages, stop)
        params = {**params, **kwargs, "stream": True}
        response = self.client.chat.chat(messages=message_dicts, **params)
        for sdk_chunk in response:
            delta = sdk_chunk.choices[0].delta
            # Skip keep-alive / role-only chunks with no text payload.
            if not delta or not delta.content:
                continue
            message_chunk = self._convert_writer_to_langchain(
                {
                    "role": "assistant",
                    "content": delta.content,
                }
            )
            gen_chunk = ChatGenerationChunk(message=message_chunk)
            if run_manager:
                run_manager.on_llm_new_token(gen_chunk.text)
            yield gen_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._convert_messages_to_writer(messages, stop)
        params = {**params, **kwargs, "stream": True}
        response = await self.async_client.chat.chat(
            messages=message_dicts, **params
        )
        async for sdk_chunk in response:
            delta = sdk_chunk.choices[0].delta
            # Skip keep-alive / role-only chunks with no text payload.
            if not delta or not delta.content:
                continue
            message_chunk = self._convert_writer_to_langchain(
                {
                    "role": "assistant",
                    "content": delta.content,
                }
            )
            gen_chunk = ChatGenerationChunk(message=message_chunk)
            if run_manager:
                await run_manager.on_llm_new_token(gen_chunk.text)
            yield gen_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts, params = self._convert_messages_to_writer(messages, stop)
        params = {**params, **kwargs}
        response = self.client.chat.chat(messages=message_dicts, **params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts, params = self._convert_messages_to_writer(messages, stop)
        params = {**params, **kwargs}
        response = await self.async_client.chat.chat(
            messages=message_dicts, **params
        )
        return self._create_chat_result(response)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
        *,
        tool_choice: Optional[Union[str, Literal["auto", "none"]]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools to the chat model.

        Args:
            tools: Tools to bind to the model
            tool_choice: Which tool to require
                ('auto', 'none', or specific tool name)
            **kwargs: Additional parameters to pass to the chat model

        Returns:
            A runnable that will use the tools
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice:
            # Named tools need the OpenAI-style forced-function wrapper;
            # "auto"/"none" pass through as plain strings.
            kwargs["tool_choice"] = (
                (tool_choice)
                if tool_choice in ("auto", "none")
                else {"type": "function", "function": {"name": tool_choice}}
            )
        return super().bind(tools=formatted_tools, **kwargs)