# Source code for langchain_community.chat_models.coze

import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
)

logger = logging.getLogger(__name__)

# Fallback Coze endpoint. ``ChatCoze.validate_environment`` uses it when neither
# a ``coze_api_base`` value nor the COZE_API_BASE environment variable is given.
DEFAULT_API_BASE = "https://api.coze.com"


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into a Coze ``chat_history`` entry.

    Human messages become ``"user"`` turns. Every other message type (AI,
    system, generic chat, ...) is sent as an ``"assistant"`` turn. Entries
    are always tagged as plain text.
    """
    role = "user" if isinstance(message, HumanMessage) else "assistant"
    return {
        "role": role,
        "content": message.content,
        "content_type": "text",
    }


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
    """Build a LangChain message from one entry of a Coze ``messages`` list.

    Only entries whose ``type`` is ``"answer"`` are converted. Any other type
    yields ``None`` so the caller can drop it.

    The ``role`` field selects the message class:

    - ``"user"`` -> ``HumanMessage``.
    - ``"assistant"`` -> ``AIMessage``; a missing or empty ``content``
      becomes ``""``.
    - Any other role -> ``ChatMessage`` carrying that role.

    Raises:
        KeyError: if ``type`` or ``role`` is missing, or ``content`` is
            missing for a non-assistant role.
    """
    if _dict["type"] != "answer":
        return None

    role = _dict["role"]
    if role == "assistant":
        return AIMessage(content=_dict.get("content") or "")
    if role == "user":
        return HumanMessage(content=_dict["content"])
    return ChatMessage(content=_dict["content"], role=role)


def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
    """Turn one streamed Coze message payload into a LangChain message chunk.

    A missing or empty ``content`` becomes ``""``. The ``role`` picks the
    chunk class: ``"assistant"`` -> ``AIMessageChunk``, ``"user"`` ->
    ``HumanMessageChunk``. Anything else (including a missing role) becomes a
    ``ChatMessageChunk``.
    """
    text = _dict.get("content") or ""
    role = _dict.get("role")

    if role == "assistant":
        return AIMessageChunk(content=text)
    if role == "user":
        return HumanMessageChunk(content=text)
    return ChatMessageChunk(content=text, role=role)  # type: ignore[arg-type]


class ChatCoze(BaseChatModel):
    """ChatCoze chat models API by coze.com

    For more information, see https://www.coze.com/open/docs/chat
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Secret field name -> environment variable it is read from.
        return {
            "coze_api_key": "COZE_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    coze_api_base: str = Field(default=DEFAULT_API_BASE)
    """Coze custom endpoints"""
    coze_api_key: Optional[SecretStr] = None
    """Coze API Key"""

    request_timeout: int = Field(default=60, alias="timeout")
    """request timeout for chat http requests"""
    bot_id: str = Field(default="")
    """The ID of the bot that the API interacts with."""
    conversation_id: str = Field(default="")
    """Indicate which conversation the dialog is taking place in. If there is no
    need to distinguish the context of the conversation (just a question and
    answer), skip this parameter. It will be generated by the system."""
    user: str = Field(default="")
    """The user who calls the API to chat with the bot."""
    streaming: bool = False
    """Whether to stream the response to the client.

    false: if no value is specified or set to false, a non-streaming response
    is returned. "Non-streaming response" means that all responses will be
    returned at once after they are all ready, and the client does not need to
    concatenate the content.

    true: if set to true, partial message deltas will be sent. "Streaming
    response" will provide real-time response of the model to the client, and
    the client needs to assemble the final reply based on the type of message.

    This only selects the request mode for ``invoke``/``generate``;
    ``stream()`` always requests a streamed response.
    """

    class Config:
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill the endpoint and API key from ``values`` or the environment.

        ``coze_api_base`` falls back to COZE_API_BASE, then DEFAULT_API_BASE.
        ``coze_api_key`` falls back to COZE_API_KEY and is wrapped in a
        ``SecretStr``. ``get_from_dict_or_env`` raises if neither source
        provides it.
        """
        values["coze_api_base"] = get_from_dict_or_env(
            values,
            "coze_api_base",
            "COZE_API_BASE",
            DEFAULT_API_BASE,
        )
        values["coze_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "coze_api_key",
                "COZE_API_KEY",
            )
        )

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Coze API."""
        return {
            "bot_id": self.bot_id,
            "conversation_id": self.conversation_id,
            "user": self.user,
            "streaming": self.streaming,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one chat turn and return the bot's ``answer`` messages.

        A ``streaming`` keyword argument overrides ``self.streaming`` for this
        call. When streaming, the SSE chunks from ``_stream`` are merged by
        ``generate_from_stream``.

        Raises:
            ValueError: if the HTTP status is not 200, or the JSON body
                reports a non-zero ``code``.
        """
        should_stream = kwargs.pop("streaming", self.streaming)
        if should_stream:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        # Force a non-streaming request: the body below is parsed as a single
        # JSON document, which an SSE response is not.
        response = self._chat(messages, streaming=False, **kwargs)
        res = response.json()
        if res.get("code") != 0:
            raise ValueError(
                f"Error from Coze api response: {res.get('code')}: {res.get('msg')}, "
                f"logid: {response.headers.get('X-Tt-Logid')}"
            )
        return self._create_chat_result(res.get("messages") or [])

    @staticmethod
    def _parse_stream_line(raw_line: bytes) -> Optional[Dict[str, Any]]:
        """Decode one SSE line; return its ``data:`` JSON payload, else None."""
        line = raw_line.decode("utf-8").strip("\r\n")
        # Only ``data:`` fields carry payloads. ``event:`` lines, comments and
        # keep-alive blank lines are skipped.
        if not line.startswith("data:"):
            return None
        payload = line[len("data:") :].strip()
        if not payload:
            return None
        return json.loads(payload)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield ``answer`` deltas from a streamed Coze chat response.

        Stops at the ``done`` event. Events other than ``message``/``answer``
        are skipped.

        Raises:
            ValueError: on a non-200 HTTP status or an ``error`` stream event.
        """
        # Always request SSE regardless of ``self.streaming``. Iterating a plain
        # JSON body line by line would otherwise yield no chunks at all.
        kwargs = {**kwargs, "streaming": True}
        # ``with`` closes the streamed connection even when we break on
        # ``done`` or the consumer abandons this generator early.
        with self._chat(messages, **kwargs) as response:
            for raw_line in response.iter_lines():
                event = self._parse_stream_line(raw_line)
                if event is None:
                    continue

                event_type = event.get("event")
                if event_type == "done":
                    break
                if event_type == "error":
                    # Surface server-side failures instead of silently
                    # returning truncated output.
                    raise ValueError(
                        f"Error from Coze api stream: {event}, "
                        f"logid: {response.headers.get('X-Tt-Logid')}"
                    )
                if event_type != "message" or event["message"]["type"] != "answer":
                    continue

                message_chunk = _convert_delta_to_message_chunk(event["message"])
                cg_chunk = ChatGenerationChunk(message=message_chunk)
                if run_manager:
                    run_manager.on_llm_new_token(
                        message_chunk.content, chunk=cg_chunk
                    )
                yield cg_chunk

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        """POST the conversation to Coze's ``/open_api/v2/chat`` endpoint.

        ``kwargs`` override ``_default_params``, including ``streaming``, which
        sets both the payload's ``stream`` flag and ``requests``' ``stream=``.
        The content of the last HumanMessage is sent as ``query``, and every
        message is also sent in ``chat_history``.

        Raises:
            ValueError: if the HTTP status is not 200.
        """
        parameters = {**self._default_params, **kwargs}

        query = ""
        chat_history = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                query = f"{msg.content}"  # overwrite, to get last user message as query
            chat_history.append(_convert_message_to_dict(msg))
        # NOTE(review): the message used as ``query`` is also included in
        # ``chat_history``. Confirm against the Coze API docs whether the bot
        # is meant to see it twice.

        conversation_id = parameters.pop("conversation_id")
        bot_id = parameters.pop("bot_id")
        user = parameters.pop("user")
        streaming = parameters.pop("streaming")

        payload: Dict[str, Any] = {
            "conversation_id": conversation_id,
            "bot_id": bot_id,
            "user": user,
            "query": query,
            "stream": streaming,
        }
        if chat_history:
            payload["chat_history"] = chat_history

        url = self.coze_api_base + "/open_api/v2/chat"
        api_key = ""
        if self.coze_api_key:
            api_key = self.coze_api_key.get_secret_value()

        res = requests.post(
            url=url,
            timeout=self.request_timeout,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json=payload,
            stream=streaming,
        )
        if res.status_code != 200:
            logid = res.headers.get("X-Tt-Logid")
            # Include the body: ``str(res)`` alone only shows the status code.
            raise ValueError(
                f"Error from Coze api response: {res}, logid: {logid}, "
                f"body: {res.text}"
            )

        return res

    def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
        """Wrap every convertible (``answer``) Coze message in a ChatGeneration."""
        generations = []
        for c in messages:
            msg = _convert_dict_to_message(c)
            if msg is not None:
                generations.append(ChatGeneration(message=msg))

        llm_output = {"token_usage": "", "model": ""}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        return "coze-chat"