# Source code for langchain_community.chat_models.cloudflare_workersai
import logging
from operator import itemgetter
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    Union,
    cast,
)
from uuid import uuid4

import requests
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field

# Initialize logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
_logger = logging.getLogger(__name__)


def _is_pydantic_class(obj: Any) -> bool:
    """Return True if ``obj`` is a class derived from pydantic ``BaseModel``."""
    return isinstance(obj, type) and is_basemodel_subclass(obj)


def _convert_messages_to_cloudflare_messages(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Convert LangChain messages to Cloudflare Workers AI format.

    Each LangChain message is mapped to a dict with at least ``role`` and
    ``content`` keys; AI messages additionally carry ``tool_calls`` and tool
    messages carry ``tool_call_id``. Non-string content is replaced with "".
    """
    cloudflare_messages: List[Dict[str, Any]] = []
    for message in messages:
        # Base structure for each message.
        msg: Dict[str, Any] = {
            "role": "",
            "content": message.content if isinstance(message.content, str) else "",
        }

        # Determine role and additional fields based on message type.
        if isinstance(message, HumanMessage):
            msg["role"] = "user"
        elif isinstance(message, AIMessage):
            msg["role"] = "assistant"
            # If the AIMessage includes tool calls, format them as needed.
            # NOTE: the loop variable is named ``tc`` so it does not shadow
            # the imported ``tool_call`` factory function.
            if message.tool_calls:
                msg["tool_calls"] = [
                    {"name": tc["name"], "arguments": tc["args"]}
                    for tc in message.tool_calls
                ]
        elif isinstance(message, SystemMessage):
            msg["role"] = "system"
        elif isinstance(message, ToolMessage):
            msg["role"] = "tool"
            # Use tool_call_id if it's a ToolMessage.
            msg["tool_call_id"] = message.tool_call_id

        # Add the formatted message to the list.
        cloudflare_messages.append(msg)

    return cloudflare_messages


def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
    """Extract tool calls from a Cloudflare Workers AI response.

    Returns an empty list when the response ``result`` carries no
    ``tool_calls`` key. Each extracted call gets a fresh UUID id because the
    API does not supply one.
    """
    tool_calls: List[ToolCall] = []
    # Parse the JSON body once instead of re-parsing on every access.
    result = response.json()["result"]
    if "tool_calls" in result:
        for tc in result["tool_calls"]:
            tool_calls.append(
                tool_call(id=str(uuid4()), name=tc["name"], args=tc["arguments"])
            )
    return tool_calls
class ChatCloudflareWorkersAI(BaseChatModel):
    """Custom chat model for Cloudflare Workers AI.

    Sends prompts to the Cloudflare Workers AI REST endpoint (optionally via
    an AI Gateway) and returns the raw JSON body as the message content,
    plus any tool calls the model produced.
    """

    # Cloudflare account identifier (required).
    account_id: str = Field(...)
    # API token used as the Bearer credential on every request (required).
    api_token: str = Field(...)
    # Workers AI model name to run (required).
    model: str = Field(...)
    # Optional AI Gateway name; when non-empty, requests route via the gateway.
    ai_gateway: str = ""
    # Fully resolved endpoint URL, computed in __init__ from the fields above.
    url: str = ""
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    gateway_url: str = "https://gateway.ai.cloudflare.com/v1"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize with necessary credentials and resolve the endpoint URL."""
        super().__init__(**kwargs)
        if self.ai_gateway:
            self.url = (
                f"{self.gateway_url}/{self.account_id}/"
                f"{self.ai_gateway}/workers-ai/run/{self.model}"
            )
        else:
            self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response based on the messages provided.

        The conversation is flattened into a single ``prompt`` string of
        ``role: ..., content: ...`` lines and POSTed to ``self.url``; the raw
        JSON response becomes the AIMessage content.
        """
        formatted_messages = _convert_messages_to_cloudflare_messages(messages)

        headers = {"Authorization": f"Bearer {self.api_token}"}
        prompt = "\n".join(
            f"role: {msg['role']}, content: {msg['content']}"
            + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
            + (
                f", tool_call_id: {msg['tool_call_id']}"
                if "tool_call_id" in msg
                else ""
            )
            for msg in formatted_messages
        )

        # Initialize `data` with `prompt`; `tools` is always present
        # (None when the caller bound no tools), other kwargs pass through.
        data = {
            "prompt": prompt,
            "tools": kwargs.get("tools"),
            **{key: value for key, value in kwargs.items() if key != "tools"},
        }

        # Ensure `tools` is a list if it's included in `kwargs`.
        if data["tools"] is not None and not isinstance(data["tools"], list):
            data["tools"] = [data["tools"]]

        # Lazy %-args: the message is only built when INFO is enabled.
        _logger.info("Sending prompt to Cloudflare Workers AI: %s", data)
        # NOTE(review): no timeout is set, so a stalled request can hang
        # indefinitely — consider requests.post(..., timeout=...).
        response = requests.post(self.url, headers=headers, json=data)

        tool_calls = _get_tool_calls_from_response(response)
        # Fix: the original cast the tool-call list to ``AIMessageChunk`` (a
        # message type, not a list type); pass the List[ToolCall] directly.
        ai_message = AIMessage(
            content=str(response.json()),
            tool_calls=tool_calls,
        )
        chat_generation = ChatGeneration(message=ai_message)
        return ChatResult(generations=[chat_generation])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools for use in model generation.

        Tools are converted to the OpenAI tool schema before binding, since
        that is the format ``_generate`` forwards to the API.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        method: Optional[
            Literal["json_mode", "function_calling", "json_schema"]
        ] = "function_calling",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: A dict schema or pydantic model class describing the output.
            include_raw: When True, return a mapping with ``raw``, ``parsed``
                and ``parsing_error`` keys instead of only the parsed value.
            method: One of ``"function_calling"``, ``"json_mode"``, or
                ``"json_schema"`` (silently downgraded to function calling,
                which this backend supports).

        Raises:
            ValueError: On unsupported kwargs, an unrecognized ``method``, or
                a None schema with ``method="function_calling"``.
        """
        _ = kwargs.pop("strict", None)
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)

        if method == "json_schema":
            # Some applications require that incompatible parameters (e.g.,
            # unsupported methods) be handled: fall back to function calling.
            method = "function_calling"

        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            tool_name = convert_to_openai_tool(schema)["function"]["name"]
            llm = self.bind_tools([schema], tool_choice=tool_name)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema],  # type: ignore[list-item]
                    first_tool_only=True,  # type: ignore[list-item]
                )
            else:
                output_parser = JsonOutputKeyToolsParser(
                    key_name=tool_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )

        if include_raw:
            # Parse the raw output, but never let a parse failure propagate:
            # fall back to parsed=None with the error recorded.
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    @property
    def _llm_type(self) -> str:
        """Return the type of the LLM (for Langchain compatibility)."""
        return "cloudflare-workers-ai"