Source code for langchain_community.chat_models.mlx
"""MLX Chat Wrapper."""fromtypingimportAny,Iterator,List,Optionalfromlangchain_core.callbacks.managerimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_models.chat_modelsimportBaseChatModelfromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,HumanMessage,SystemMessage,)fromlangchain_core.outputsimport(ChatGeneration,ChatGenerationChunk,ChatResult,LLMResult,)fromlangchain_community.llms.mlx_pipelineimportMLXPipelineDEFAULT_SYSTEM_PROMPT="""You are a helpful, respectful, and honest assistant."""
class ChatMLX(BaseChatModel):
    """MLX chat models.

    Works with `MLXPipeline` LLM.

    To use, you should have the ``mlx-lm`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMLX
            from langchain_community.llms import MLXPipeline

            llm = MLXPipeline.from_model_id(
                model_id="mlx-community/quantized-gemma-2b-it",
            )
            chat = ChatMLX(llm=llm)

    """

    llm: MLXPipeline
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.tokenizer = self.llm.tokenizer

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
        tokenize: bool = False,
        return_tensors: Optional[str] = None,
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM."""
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts,
            tokenize=tokenize,
            add_generation_prompt=True,
            return_tensors=return_tensors,
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format."""
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        chat_generations = []

        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    @property
    def _llm_type(self) -> str:
        return "mlx-chat-wrapper"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        try:
            import mlx.core as mx
            from mlx_lm.utils import generate_step
        except ImportError:
            raise ImportError(
                "Could not import mlx_lm python package. "
""Please install it with `pip install mlx_lm`.")model_kwargs=kwargs.get("model_kwargs",self.llm.pipeline_kwargs)temp:float=model_kwargs.get("temp",0.0)max_new_tokens:int=model_kwargs.get("max_tokens",100)repetition_penalty:Optional[float]=model_kwargs.get("repetition_penalty",None)repetition_context_size:Optional[int]=model_kwargs.get("repetition_context_size",None)llm_input=self._to_chat_prompt(messages,tokenize=True,return_tensors="np")prompt_tokens=mx.array(llm_input[0])eos_token_id=self.tokenizer.eos_token_idfor(token,prob),ninzip(generate_step(prompt_tokens,self.llm.model,temp,repetition_penalty,repetition_context_size,),range(max_new_tokens),):# identify text to yieldtext:Optional[str]=Nonetext=self.tokenizer.decode(token.item())# yield text, if anyiftext:chunk=ChatGenerationChunk(message=AIMessageChunk(content=text))ifrun_manager:run_manager.on_llm_new_token(text,chunk=chunk)yieldchunk# break if stop sequence foundiftoken==eos_token_idor(stopisnotNoneandtextinstop):break