import logging
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from uuid import uuid4
import requests
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field
# Set up process-wide logging at import time: INFO threshold, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger named after this module; used by the code below for request tracing.
_logger = logging.getLogger(__name__)
def _is_pydantic_class(obj: Any) -> bool:
    """Tell whether ``obj`` is a class (not an instance) derived from a pydantic model.

    Args:
        obj: Any object, typically a schema passed by the caller.

    Returns:
        ``False`` for anything that is not a class; otherwise whatever
        ``is_basemodel_subclass`` reports for the class.
    """
    if not isinstance(obj, type):
        # Instances, dicts, None, ... are never model classes; skip the
        # subclass check, which is only called with class objects.
        return False
    return is_basemodel_subclass(obj)
def _content_as_text(content: Union[str, List[Union[str, Dict[str, Any]]], None]) -> str:
    """Flatten LangChain message content into one plain-text string.

    String content is returned unchanged. For list content, bare strings and
    ``{"type": "text", "text": ...}`` blocks are concatenated in order; every
    other block kind (images, tool-use blocks, ...) is skipped because the
    text-only payload built from these messages has no place for it.
    """
    if isinstance(content, str):
        return content
    parts: List[str] = []
    for block in content or []:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            text = block.get("text")
            if isinstance(text, str):
                parts.append(text)
    return "".join(parts)


def _convert_messages_to_cloudflare_messages(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Convert LangChain messages to Cloudflare Workers AI format.

    Args:
        messages: LangChain messages in conversation order.

    Returns:
        One dict per input message with the keys ``role`` and ``content``,
        plus ``tool_calls`` (assistant messages that carry tool calls) or
        ``tool_call_id`` (tool messages). Messages of a type not handled
        below keep an empty ``role``.
    """
    cloudflare_messages: List[Dict[str, Any]] = []
    for message in messages:
        # Base structure shared by every message kind. List-style content is
        # flattened to its text parts instead of being replaced by "".
        msg: Dict[str, Any] = {
            "role": "",
            "content": _content_as_text(message.content),
        }
        # Determine role and additional fields based on message type
        if isinstance(message, HumanMessage):
            msg["role"] = "user"
        elif isinstance(message, AIMessage):
            msg["role"] = "assistant"
            # Forward the tool calls the assistant requested, if any. The loop
            # variable is deliberately not named ``tool_call`` so it does not
            # shadow the imported factory of the same name.
            if message.tool_calls:
                msg["tool_calls"] = [
                    {"name": call["name"], "arguments": call["args"]}
                    for call in message.tool_calls
                ]
        elif isinstance(message, SystemMessage):
            msg["role"] = "system"
        elif isinstance(message, ToolMessage):
            msg["role"] = "tool"
            # Lets the tool result be matched to the call that produced it.
            msg["tool_call_id"] = message.tool_call_id
        cloudflare_messages.append(msg)
    return cloudflare_messages
def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
    """Get tool calls from a Cloudflare Workers AI response.

    Args:
        response: HTTP response whose JSON body is expected to look like
            ``{"result": {"tool_calls": [{"name": ..., "arguments": ...}]}}``.

    Returns:
        One ``ToolCall`` per entry under ``result.tool_calls``, each with a
        freshly generated UUID as its id (the entries are read only for
        ``name`` and ``arguments``). An empty list is returned when the body
        has no ``result`` object or no tool calls, e.g. when ``result`` is
        ``null``.
    """
    payload = response.json()  # decode the body once instead of per access
    result = payload.get("result") if isinstance(payload, dict) else None
    if not isinstance(result, dict):
        # Missing or null ``result``: nothing to extract.
        return []
    return [
        tool_call(id=str(uuid4()), name=tc["name"], args=tc["arguments"])
        for tc in result.get("tool_calls") or []
    ]
# [docs]
class ChatCloudflareWorkersAI(BaseChatModel):
    """Custom chat model for Cloudflare Workers AI.

    The endpoint is derived from the configured fields at construction time:
    requests go through the AI Gateway when ``ai_gateway`` is set, otherwise
    straight to the Cloudflare REST API.
    """

    # Cloudflare account identifier; part of the endpoint path.
    account_id: str = Field(...)
    # API token sent as a Bearer credential on every request.
    api_token: str = Field(...)
    # Model identifier appended to the run endpoint.
    model: str = Field(...)
    # Optional AI Gateway name; empty means "call the REST API directly".
    ai_gateway: str = ""
    # Fully resolved endpoint; always recomputed in ``__init__``.
    url: str = ""
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    gateway_url: str = "https://gateway.ai.cloudflare.com/v1"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize with necessary credentials and resolve the endpoint URL."""
        super().__init__(**kwargs)
        if self.ai_gateway:
            self.url = (
                f"{self.gateway_url}/{self.account_id}/"
                f"{self.ai_gateway}/workers-ai/run/{self.model}"
            )
        else:
            self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response based on the messages provided.

        Args:
            messages: Conversation to send; it is flattened into one text prompt.
            stop: Accepted for interface compatibility; not forwarded.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Extra request fields. ``tools`` is normalised to a list;
                every other key is copied into the JSON payload as-is.

        Returns:
            A ``ChatResult`` with a single generation. The message content is
            the string form of the whole decoded JSON body, and any tool calls
            found in it are attached to the message.

        Raises:
            requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        """
        formatted_messages = _convert_messages_to_cloudflare_messages(messages)
        headers = {"Authorization": f"Bearer {self.api_token}"}
        # One line per message: role, content, then optional tool metadata.
        prompt = "\n".join(
            f"role: {msg['role']}, content: {msg['content']}"
            + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
            + (
                f", tool_call_id: {msg['tool_call_id']}"
                if "tool_call_id" in msg
                else ""
            )
            for msg in formatted_messages
        )

        # A single tool definition is wrapped so ``tools`` is always a list
        # (or None when no tools were bound).
        tools = kwargs.get("tools")
        if tools is not None and not isinstance(tools, list):
            tools = [tools]
        data = {
            "prompt": prompt,
            "tools": tools,
            **{key: value for key, value in kwargs.items() if key != "tools"},
        }

        # Lazy %-formatting: the payload is only rendered if INFO is enabled.
        _logger.info("Sending prompt to Cloudflare Workers AI: %s", data)
        response = requests.post(self.url, headers=headers, json=data)
        # Fail with a clear HTTP error instead of tripping over an error body
        # further down.
        response.raise_for_status()

        tool_calls = _get_tool_calls_from_response(response)
        ai_message = AIMessage(content=str(response.json()), tool_calls=tool_calls)
        chat_generation = ChatGeneration(message=ai_message)
        return ChatResult(generations=[chat_generation])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools for use in model generation.

        ``with_structured_output`` calls this method for
        ``method="function_calling"``, so the class must provide it.

        Args:
            tools: Tool definitions (dicts, classes, callables or ``BaseTool``
                instances); each one is converted with ``convert_to_openai_tool``.
            **kwargs: Additional values to bind, e.g. ``tool_choice``.

        Returns:
            A runnable that passes the converted tools (and ``kwargs``) to
            every ``_generate`` call.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: Pydantic model class or dict schema describing the output.
            include_raw: When True, return a dict with the keys ``raw``,
                ``parsed`` and ``parsing_error`` instead of only the parsed
                value; parser failures are captured rather than raised.
            method: ``"function_calling"`` binds ``schema`` as a tool and parses
                the tool call; ``"json_mode"`` binds a JSON ``response_format``
                and parses the message content.
            **kwargs: Not supported; any value raises.

        Returns:
            A runnable chaining the configured model with the output parser.

        Raises:
            ValueError: On extra keyword arguments, a ``None`` schema with
                ``"function_calling"``, or an unrecognised ``method``.
        """
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            tool_name = convert_to_openai_tool(schema)["function"]["name"]
            llm = self.bind_tools([schema], tool_choice=tool_name)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema],  # type: ignore[list-item]
                    first_tool_only=True,
                )
            else:
                output_parser = JsonOutputKeyToolsParser(
                    key_name=tool_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )

        if include_raw:
            # Happy path: parse the raw message and record that no error occurred.
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            # Fallback path: parsing failed, so ``parsed`` is None and the
            # exception is stored under ``parsing_error``.
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    @property
    def _llm_type(self) -> str:
        """Return the type of the LLM (for Langchain compatibility)."""
        return "cloudflare-workers-ai"