Source code for langchain_community.chat_models.mlx
"""MLX Chat Wrapper."""fromtypingimport(Any,Callable,Dict,Iterator,List,Literal,Optional,Sequence,Type,Union,)fromlangchain_core.callbacks.managerimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_modelsimportLanguageModelInputfromlangchain_core.language_models.chat_modelsimportBaseChatModelfromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,HumanMessage,SystemMessage,)fromlangchain_core.outputsimport(ChatGeneration,ChatGenerationChunk,ChatResult,LLMResult,)fromlangchain_core.runnablesimportRunnablefromlangchain_core.toolsimportBaseToolfromlangchain_core.utils.function_callingimportconvert_to_openai_toolfromlangchain_community.llms.mlx_pipelineimportMLXPipelineDEFAULT_SYSTEM_PROMPT="""You are a helpful, respectful, and honest assistant."""
class ChatMLX(BaseChatModel):
    """MLX chat models.

    Works with `MLXPipeline` LLM.

    To use, you should have the ``mlx-lm`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMLX
            from langchain_community.llms import MLXPipeline

            llm = MLXPipeline.from_model_id(
                model_id="mlx-community/quantized-gemma-2b-it",
            )
            chat = ChatMLX(llm=llm)

    """

    llm: MLXPipeline
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.tokenizer = self.llm.tokenizer

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
        tokenize: bool = False,
        return_tensors: Optional[str] = None,
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM."""
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts,
            tokenize=tokenize,
            add_generation_prompt=True,
            return_tensors=return_tensors,
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format."""
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        chat_generations = []

        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    @property
    def _llm_type(self) -> str:
        return "mlx-chat-wrapper"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        try:
            import mlx.core as mx
            from mlx_lm.sample_utils import make_logits_processors, make_sampler
            from mlx_lm.utils import generate_step
        except ImportError:
            raise ImportError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )
        model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
        temp: float = model_kwargs.get("temp", 0.0)
        max_new_tokens: int = model_kwargs.get("max_tokens", 100)
        repetition_penalty: Optional[float] = model_kwargs.get(
            "repetition_penalty", None
        )
        repetition_context_size: Optional[int] = model_kwargs.get(
            "repetition_context_size", None
        )
        top_p: float = model_kwargs.get("top_p", 1.0)
        min_p: float = model_kwargs.get("min_p", 0.0)
        min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)

        llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")

        prompt_tokens = mx.array(llm_input[0])

        eos_token_id = self.tokenizer.eos_token_id

        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)

        logits_processors = make_logits_processors(
            None, repetition_penalty, repetition_context_size
        )

        for (token, prob), n in zip(
            generate_step(
                prompt_tokens,
                self.llm.model,
                sampler=sampler,
                logits_processors=logits_processors,
            ),
            range(max_new_tokens),
        ):
            # identify text to yield
            text: Optional[str] = None
            if not isinstance(token, int):
                text = self.tokenizer.decode(token.item())
            else:
                text = self.tokenizer.decode(token)

            # yield text, if any
            if text:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

            # break if stop sequence found
            if token == eos_token_id or (stop is not None and text in stop):
                break
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None and tool_choice:
            if len(formatted_tools) != 1:
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, str):
                if tool_choice not in ("auto", "none"):
                    tool_choice = {
                        "type": "function",
                        "function": {"name": tool_choice},
                    }
            elif isinstance(tool_choice, bool):
                tool_choice = formatted_tools[0]
            elif isinstance(tool_choice, dict):
                if (
                    formatted_tools[0]["function"]["name"]
                    != tool_choice["function"]["name"]
                ):
                    raise ValueError(
                        f"Tool choice {tool_choice} was specified, but the only "
                        f"provided tool was {formatted_tools[0]['function']['name']}."
                    )
            else:
                raise ValueError(
                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
                    f"Received: {tool_choice}"
                )
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
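Example usage (a minimal sketch, not part of the module source; it assumes an Apple-silicon machine with ``mlx-lm`` installed and uses the ``mlx-community/quantized-gemma-2b-it`` model id from the class docstring — the prompts and pipeline settings below are illustrative only):

# Minimal usage sketch for ChatMLX, assuming `mlx-lm` is installed locally.
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.messages import HumanMessage

llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={"max_tokens": 64, "temp": 0.1},
)
chat = ChatMLX(llm=llm)

# Synchronous call: messages are converted to ChatML dicts, rendered with the
# tokenizer's chat template, and generated through the wrapped MLXPipeline.
response = chat.invoke([HumanMessage(content="What is MLX?")])
print(response.content)

# Streaming: `_stream` decodes and yields one chunk per generated token.
for chunk in chat.stream([HumanMessage(content="Name three prime numbers.")]):
    print(chunk.content, end="", flush=True)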