Source code for langchain_community.chat_models.mlx

"""MLX Chat Wrapper."""

from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    LLMResult,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_community.llms.mlx_pipeline import MLXPipeline

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


class ChatMLX(BaseChatModel):
    """MLX chat models.

    Works with `MLXPipeline` LLM.

    To use, you should have the ``mlx-lm`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMLX
            from langchain_community.llms import MLXPipeline

            llm = MLXPipeline.from_model_id(
                model_id="mlx-community/quantized-gemma-2b-it",
            )
            chat = ChatMLX(llm=llm)

    """

    llm: MLXPipeline
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.tokenizer = self.llm.tokenizer

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)

        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)

        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
        tokenize: bool = False,
        return_tensors: Optional[str] = None,
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM."""
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts,
            tokenize=tokenize,
            add_generation_prompt=True,
            return_tensors=return_tensors,
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format."""
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        chat_generations = []

        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    @property
    def _llm_type(self) -> str:
        return "mlx-chat-wrapper"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        try:
            import mlx.core as mx
            from mlx_lm.utils import generate_step
        except ImportError:
            raise ImportError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
        temp: float = model_kwargs.get("temp", 0.0)
        max_new_tokens: int = model_kwargs.get("max_tokens", 100)
        repetition_penalty: Optional[float] = model_kwargs.get(
            "repetition_penalty", None
        )
        repetition_context_size: Optional[int] = model_kwargs.get(
            "repetition_context_size", None
        )

        llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")

        prompt_tokens = mx.array(llm_input[0])

        eos_token_id = self.tokenizer.eos_token_id

        for (token, prob), n in zip(
            generate_step(
                prompt_tokens,
                self.llm.model,
                temp=temp,
                repetition_penalty=repetition_penalty,
                repetition_context_size=repetition_context_size,
            ),
            range(max_new_tokens),
        ):
            # identify text to yield
            text: Optional[str] = None
            if not isinstance(token, int):
                text = self.tokenizer.decode(token.item())
            else:
                text = self.tokenizer.decode(token)

            # yield text, if any
            if text:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

            # break if stop sequence found
            if token == eos_token_id or (stop is not None and text in stop):
                break
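
    # Illustrative note on streaming parameters: ``_stream`` reads its sampling
    # controls from a ``model_kwargs`` mapping passed at call time, falling
    # back to ``self.llm.pipeline_kwargs``. Given a ``ChatMLX`` instance
    # ``chat``, they might be supplied through the public ``stream`` API like
    # this (the prompt text and values below are assumptions, not defaults):
    #
    #     from langchain_core.messages import HumanMessage
    #
    #     for chunk in chat.stream(
    #         [HumanMessage(content="Tell me a joke.")],
    #         model_kwargs={
    #             "temp": 0.7,
    #             "max_tokens": 256,
    #             "repetition_penalty": 1.1,
    #             "repetition_context_size": 20,
    #         },
    #     ):
    #         print(chunk.content, end="", flush=True)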

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None and tool_choice:
            if len(formatted_tools) != 1:
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, str):
                if tool_choice not in ("auto", "none"):
                    tool_choice = {
                        "type": "function",
                        "function": {"name": tool_choice},
                    }
            elif isinstance(tool_choice, bool):
                tool_choice = formatted_tools[0]
            elif isinstance(tool_choice, dict):
                if (
                    formatted_tools[0]["function"]["name"]
                    != tool_choice["function"]["name"]
                ):
                    raise ValueError(
                        f"Tool choice {tool_choice} was specified, but the only "
                        f"provided tool was {formatted_tools[0]['function']['name']}."
                    )
            else:
                raise ValueError(
                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
                    f"Received: {tool_choice}"
                )
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
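

# The block below is a hedged usage sketch for the wrapper defined above. It
# assumes ``mlx-lm`` is installed and that the ``quantized-gemma-2b-it`` model
# id from the class docstring can be resolved locally; the prompt text and
# variable names are illustrative.
if __name__ == "__main__":
    demo_llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b-it",
    )
    demo_chat = ChatMLX(llm=demo_llm)

    # Blocking generation: ``invoke`` returns a single AIMessage.
    reply = demo_chat.invoke([HumanMessage(content="What is Apple MLX?")])
    print(reply.content)

    # Streaming generation: ``stream`` yields AIMessageChunk objects as tokens
    # are produced by ``_stream`` above.
    for demo_chunk in demo_chat.stream([HumanMessage(content="What is Apple MLX?")]):
        print(demo_chunk.content, end="", flush=True)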