class ChatMaestro(BaseChatModel, AI21Base):
    """Chat model using Maestro LLM.

    Run options may be configured once on the instance (the fields below) or
    supplied per call as keyword arguments. When both are given, the per-call
    keyword argument wins. Instance fields left as ``None`` are not sent.
    """

    output_type: Optional[Dict[str, Any]] = None
    """Optional dictionary specifying the output type."""

    models: Optional[List[str]] = None
    """Optional list of model names to use.
    Available models https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    tools: Optional[List[Dict[str, ToolType]]] = None
    """Optional list of tools."""

    tool_resources: Optional[ToolResources] = None
    """Optional resources for the tools."""

    context: Optional[Dict[str, Any]] = None
    """Optional dictionary providing context for the chat."""

    budget: Optional[Budget] = None
    """Optional budget constraints for the chat."""

    poll_interval_sec: Optional[float] = 1
    """Interval in seconds for polling the run status."""

    poll_timeout_sec: Optional[float] = 120
    """Timeout in seconds for polling the run status."""

    @property
    def _llm_type(self) -> str:
        """Return the type of LLM."""
        return "chat-maestro"

    def _configured_run_options(self) -> Dict[str, Any]:
        """Return the instance-level run options that have a value set.

        Returns:
            A dict keyed by field name containing every run-option field of
            this instance whose value is not ``None``.
        """
        configured = {
            "output_type": self.output_type,
            "models": self.models,
            "tools": self.tools,
            "tool_resources": self.tool_resources,
            "context": self.context,
            "budget": self.budget,
            "poll_interval_sec": self.poll_interval_sec,
            "poll_timeout_sec": self.poll_timeout_sec,
        }
        # Unset options are omitted entirely rather than sent as None.
        return {
            name: value for name, value in configured.items() if value is not None
        }

    def _build_run_payload(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        """Combine instance-level run options with the per-call payload.

        Args:
            messages: Messages to send as the run input.
            **kwargs: Per-call options; forwarded to ``_prepare_payload``.

        Returns:
            The keyword arguments for ``runs.create_and_poll``. Keys produced
            from the call (``input`` and anything in ``kwargs``) override the
            instance-level options of the same name.
        """
        return {
            **self._configured_run_options(),
            **self._prepare_payload(messages, **kwargs),
        }

    @staticmethod
    def _ensure_completed(result: RunResponse) -> RunResponse:
        """Return ``result`` unchanged if its status is ``"completed"``.

        Raises:
            RuntimeError: If ``result.status`` is anything other than
                ``"completed"``.
        """
        if result.status != "completed":
            raise RuntimeError(f"Maestro run failed with status: {result.status}")
        return result

    def _call(self, messages: List[BaseMessage], **kwargs: Any) -> RunResponse:
        """API call to Maestro.

        Args:
            messages: Messages to send as the run input.
            **kwargs: Per-call run options (see ``_prepare_payload``).

        Returns:
            The run response returned by ``runs.create_and_poll``.

        Raises:
            RuntimeError: If the returned run's status is not ``"completed"``.
            ValueError: If ``requirements`` or ``variables`` is not a list of
                strings.
        """
        payload = self._build_run_payload(messages, **kwargs)
        result = self.client.beta.maestro.runs.create_and_poll(**payload)
        return self._ensure_completed(result)

    async def _acall(self, messages: List[BaseMessage], **kwargs: Any) -> RunResponse:
        """Asynchronous API call to Maestro.

        Same contract as ``_call``, but awaits ``create_and_poll`` on the
        async client.
        """
        payload = self._build_run_payload(messages, **kwargs)
        result = await self.async_client.beta.maestro.runs.create_and_poll(**payload)
        return self._ensure_completed(result)

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generates a response using Maestro LLM.

        Args:
            messages: Messages to send as the run input.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Per-call run options forwarded to ``_call``.

        Returns:
            A ``ChatResult`` holding a single generation.
        """
        response_data = self._call(messages, **kwargs)
        return self._handle_chat_result(response_data)

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous agent call to Maestro.

        Same contract as ``_generate``; ``stop`` and ``run_manager`` are
        likewise not used here.
        """
        response_data = await self._acall(messages, **kwargs)
        return self._handle_chat_result(response_data)

    @staticmethod
    def _prepare_payload(messages: List[BaseMessage], **kwargs: Any) -> Dict[str, Any]:
        """Prepare the payload for the API call with validation.

        Args:
            messages: Messages whose ``content`` becomes the run input.
            **kwargs: Extra payload entries, copied through as-is except for:
                ``requirements`` - list of strings; each becomes a
                ``{"name", "description"}`` entry.
                ``variables`` - list of strings; adds one extra requirement
                naming the variables the output should contain.

        Returns:
            A dict with ``input`` plus the remaining keyword arguments. The
            ``variables`` key is never present in the result, and
            ``requirements`` is present only if at least one of the two
            special keys was non-empty.

        Raises:
            ValueError: If a non-empty ``requirements`` or ``variables`` value
                is not a list of strings.
        """
        # NOTE(review): every message is sent with role "user", whatever its
        # message type - confirm this is intended for multi-turn input.
        formatted_messages = [
            {"role": "user", "content": message.content} for message in messages
        ]
        payload = {"input": formatted_messages, **kwargs}

        # Falsy values (None, empty list) are dropped from the payload.
        requirements = payload.pop("requirements", [])
        if requirements:
            ChatMaestro.validate_list(requirements, "requirements")
            payload["requirements"] = [
                {"name": req, "description": req} for req in requirements
            ]

        variables = payload.pop("variables", [])
        if variables:
            ChatMaestro.validate_list(variables, "variables")
            variables_str = " ".join(variables)
            # Appended after any user-supplied requirements.
            payload["requirements"] = payload.get("requirements", []) + [
                {
                    "name": "output should contain only these variables:"
                    f" {variables_str}",
                    "description": variables_str,
                }
            ]

        return payload

    @staticmethod
    def validate_list(obj: List[str], obj_name: str, expected_type: Type = str) -> None:
        """Validate that obj is a list of the expected type.

        ``None`` is accepted without error.

        Args:
            obj: Value to validate.
            obj_name: Name used for the value in the error message.
            expected_type: Type every element must be an instance of.

        Raises:
            ValueError: If ``obj`` is not ``None`` and is either not a list or
                contains an element that is not an ``expected_type`` instance.
        """
        if obj is not None and (
            not isinstance(obj, list)
            or any(not isinstance(var, expected_type) for var in obj)
        ):
            raise ValueError(f"{obj_name} must be a list of {expected_type.__name__}")

    @staticmethod
    def _handle_chat_result(response_data: RunResponse) -> ChatResult:
        """Handle the response data from the Maestro run.

        Wraps ``response_data.result`` in an ``AIMessage`` and returns it as
        the only generation of a ``ChatResult``.
        """
        ai_message = AIMessage(content=response_data.result)
        generation = ChatGeneration(message=ai_message)
        return ChatResult(generations=[generation])