from __future__ import annotations

import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Literal, Optional, Union, overload

from ai21.models import RoleType
from ai21.models.chat import AssistantMessage as AI21AssistantMessage
from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
from ai21.models.chat import ChatMessage as AI21ChatMessage
from ai21.models.chat import SystemMessage as AI21SystemMessage
from ai21.models.chat import ToolCall as AI21ToolCall
from ai21.models.chat import ToolFunction as AI21ToolFunction
from ai21.models.chat import ToolMessage as AI21ToolMessage
from ai21.models.chat import UserMessage as AI21UserMessage
from ai21.stream.stream import Stream as AI21Stream
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers.openai_tools import parse_tool_call
from langchain_core.outputs import ChatGenerationChunk

_ChatMessageTypes = Union[AI21ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
_ROLE_TYPE = Union[str, RoleType]
class ChatAdapter(ABC):
    """Common interface for the different Chat models available in AI21.

    It converts LangChain messages to AI21 messages.
    Calls the appropriate AI21 model API with the converted messages.
    """
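    # Reading aid only, not extra API: the conversion pipeline implemented by
    # the methods below first resolves the AI21 role for a LangChain message
    # via _parse_role, then delegates to the subclass's _chat_message to build
    # the concrete AI21 message, roughly:
    #
    #     role = self._parse_role(message)          # e.g. RoleType.USER
    #     ai21_message = self._chat_message(role=role, message=message)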
    def _convert_message_to_ai21_message(
        self,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        role = self._parse_role(message)

        return self._chat_message(role=role, message=message)

    def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
        if isinstance(message, SystemMessage):
            return RoleType.SYSTEM

        if isinstance(message, HumanMessage):
            return RoleType.USER

        if isinstance(message, AIMessage):
            return RoleType.ASSISTANT

        if isinstance(message, ToolMessage):
            return RoleType.TOOL

        # If it gets here, we rely on the server to handle the role type
        return message.type

    @abstractmethod
    def _chat_message(
        self,
        role: _ROLE_TYPE,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        pass

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]:
        pass

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]:
        pass
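    # Note on the call() overloads above (descriptive comment only): with
    # stream=True the adapter yields ChatGenerationChunk objects incrementally,
    # while with stream=False it returns the full list of BaseMessage results
    # in one shot. A hypothetical caller (client and params are assumptions,
    # not part of this module) might look like:
    #
    #     for chunk in adapter.call(client, stream=True, **params):
    #         print(chunk.text, end="")
    #
    #     messages = adapter.call(client, stream=False, **params)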
    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )

        return message.content
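# Illustrative sketch, not part of the module: with a concrete adapter such as
# JambaChatCompletionsAdapter below, converting LangChain messages looks
# roughly like this (message contents are made up for the example):
#
#     adapter = JambaChatCompletionsAdapter()
#     adapter._convert_message_to_ai21_message(SystemMessage(content="Be terse."))
#     # -> AI21SystemMessage(content="Be terse.")
#     adapter._convert_message_to_ai21_message(HumanMessage(content="Hi there"))
#     # -> AI21UserMessage(content="Hi there")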
class JambaChatCompletionsAdapter(ChatAdapter):
    """Adapter for Jamba Chat Completions."""
    def _convert_lc_tool_calls_to_ai21_tool_calls(
        self, tool_calls: List[ToolCall]
    ) -> Optional[List[AI21ToolCall]]:
        """Convert LangChain ToolCalls to AI21 ToolCalls."""
        ai21_tool_calls: List[AI21ToolCall] = []
        for lc_tool_call in tool_calls:
            if "id" not in lc_tool_call or not lc_tool_call["id"]:
                raise ValueError("Tool call ID is missing or empty.")

            ai21_tool_call = AI21ToolCall(
                id=lc_tool_call["id"],
                type="function",
                function=AI21ToolFunction(
                    name=lc_tool_call["name"],
                    arguments=json.dumps(lc_tool_call["args"]),
                ),
            )
            ai21_tool_calls.append(ai21_tool_call)

        return ai21_tool_calls

    def _get_content_as_string(self, base_message: BaseMessage) -> str:
        if isinstance(base_message.content, str):
            return base_message.content
        elif isinstance(base_message.content, list):
            return "\n".join(str(item) for item in base_message.content)
        else:
            raise ValueError("Unsupported content type")

    def _chat_message(
        self,
        role: _ROLE_TYPE,
        message: BaseMessage,
    ) -> ChatMessageParam:
        content = self._get_content_as_string(message)

        if isinstance(message, AIMessage):
            return AI21AssistantMessage(
                tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
                    message.tool_calls
                ),
                content=content or None,
            )

        if isinstance(message, ToolMessage):
            return AI21ToolMessage(
                tool_call_id=message.tool_call_id,
                content=content,
            )

        if isinstance(message, HumanMessage):
            return AI21UserMessage(
                content=content,
            )

        if isinstance(message, SystemMessage):
            return AI21SystemMessage(
                content=content,
            )

        return AI21ChatMessage(
            role=role.value if isinstance(role, RoleType) else role,
            content=content,
        )

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]:
        ...

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]:
        ...
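    # Illustrative sketch (made-up values): _convert_lc_tool_calls_to_ai21_tool_calls
    # turns a LangChain tool call such as
    #
    #     {"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}
    #
    # into
    #
    #     AI21ToolCall(
    #         id="call_1",
    #         type="function",
    #         function=AI21ToolFunction(
    #             name="get_weather",
    #             arguments='{"city": "Paris"}',
    #         ),
    #     )
    #
    # i.e. the "args" dict is serialized to a JSON string via json.dumps, and a
    # missing or empty "id" raises a ValueError.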