from __future__ import annotations

from typing import Any, Callable, Iterator, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from pydantic import ConfigDict

# Default model id, matching the example in the class docstring.
DEFAULT_MODEL_ID = "mlx-community/quantized-gemma-2b"


class MLXPipeline(LLM):
    """MLX Pipeline API.

    To use, you should have the ``mlx-lm`` python package installed.

    Example using from_model_id:
        .. code-block:: python

            from langchain_community.llms import MLXPipeline

            pipe = MLXPipeline.from_model_id(
                model_id="mlx-community/quantized-gemma-2b",
                pipeline_kwargs={"max_tokens": 10, "temp": 0.7},
            )

    Example passing model and tokenizer in directly:
        .. code-block:: python

            from langchain_community.llms import MLXPipeline
            from mlx_lm import load

            model_id = "mlx-community/quantized-gemma-2b"
            model, tokenizer = load(model_id)
            pipe = MLXPipeline(model=model, tokenizer=tokenizer)
    """

    model_id: str = DEFAULT_MODEL_ID
    """Model name to use."""
    model: Any = None  #: :meta private:
    """Model."""
    tokenizer: Any = None  #: :meta private:
    """Tokenizer."""
    tokenizer_config: Optional[dict] = None
    """
    Configuration parameters specifically for the tokenizer.
    Defaults to an empty dictionary.
    """
    adapter_file: Optional[str] = None
    """
    Path to the adapter file. If provided, applies LoRA layers to the model.
    Defaults to None.
    """
    lazy: bool = False
    """
    If False, eval the model parameters to make sure they are
    loaded in memory before returning; otherwise they will be loaded
    when needed. Default: ``False``
    """
    pipeline_kwargs: Optional[dict] = None
    """
    Keyword arguments passed to the pipeline. Defaults include:
        - temp (float): Temperature for generation, default is 0.0.
        - max_tokens (int): Maximum tokens to generate, default is 100.
        - verbose (bool): Whether to output verbose logging, default is False.
        - formatter (Optional[Callable]): A callable to format the output.
          Default is None.
        - repetition_penalty (Optional[float]): The penalty factor for
          repeated sequences, default is None.
        - repetition_context_size (Optional[int]): Size of the context window
          for applying the repetition penalty, default is None.
        - top_p (float): The cumulative probability threshold for
          top-p filtering, default is 1.0.
    """
    model_config = ConfigDict(
        extra="forbid",
    )
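    # Illustrative sketch (not executed): the documented ``pipeline_kwargs``
    # map directly onto mlx-lm sampling options, and keys beyond the
    # docstring example (e.g. ``top_p``, ``repetition_penalty``) are passed
    # the same way:
    #
    #     pipe = MLXPipeline.from_model_id(
    #         model_id="mlx-community/quantized-gemma-2b",
    #         pipeline_kwargs={
    #             "max_tokens": 50,
    #             "temp": 0.7,
    #             "top_p": 0.9,
    #             "repetition_penalty": 1.1,
    #         },
    #     )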
    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        tokenizer_config: Optional[dict] = None,
        adapter_file: Optional[str] = None,
        lazy: bool = False,
        pipeline_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> MLXPipeline:
        """Construct the pipeline object from model_id and task."""
        try:
            from mlx_lm import load
        except ImportError:
            raise ImportError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        tokenizer_config = tokenizer_config or {}
        if adapter_file:
            model, tokenizer = load(
                model_id, tokenizer_config, adapter_path=adapter_file, lazy=lazy
            )
        else:
            model, tokenizer = load(model_id, tokenizer_config, lazy=lazy)

        _pipeline_kwargs = pipeline_kwargs or {}

        return cls(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            tokenizer_config=tokenizer_config,
            adapter_file=adapter_file,
            lazy=lazy,
            pipeline_kwargs=_pipeline_kwargs,
            **kwargs,
        )
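    # Illustrative sketch (not executed): ``adapter_file`` is forwarded to
    # ``mlx_lm.load`` as ``adapter_path``, so it should point at adapter
    # weights produced by mlx-lm fine-tuning. The path below is hypothetical:
    #
    #     pipe = MLXPipeline.from_model_id(
    #         model_id="mlx-community/quantized-gemma-2b",
    #         adapter_file="path/to/adapters",
    #         pipeline_kwargs={"max_tokens": 64},
    #     )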
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "tokenizer_config": self.tokenizer_config,
            "adapter_file": self.adapter_file,
            "lazy": self.lazy,
            "pipeline_kwargs": self.pipeline_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        return "mlx_pipeline"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            from mlx_lm import generate
            from mlx_lm.sample_utils import make_logits_processors, make_sampler
        except ImportError:
            raise ImportError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        pipeline_kwargs = kwargs.get("pipeline_kwargs", self.pipeline_kwargs)

        # Pull generation settings out of pipeline_kwargs, falling back to the
        # defaults documented on the class.
        temp: float = pipeline_kwargs.get("temp", 0.0)
        max_tokens: int = pipeline_kwargs.get("max_tokens", 100)
        verbose: bool = pipeline_kwargs.get("verbose", False)
        formatter: Optional[Callable] = pipeline_kwargs.get("formatter", None)
        repetition_penalty: Optional[float] = pipeline_kwargs.get(
            "repetition_penalty", None
        )
        repetition_context_size: Optional[int] = pipeline_kwargs.get(
            "repetition_context_size", None
        )
        top_p: float = pipeline_kwargs.get("top_p", 1.0)
        min_p: float = pipeline_kwargs.get("min_p", 0.0)
        min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)

        # Build the mlx-lm sampler and logits processors from those settings.
        sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
        logits_processors = make_logits_processors(
            None, repetition_penalty, repetition_context_size
        )

        return generate(
            model=self.model,
            tokenizer=self.tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            verbose=verbose,
            formatter=formatter,
            sampler=sampler,
            logits_processors=logits_processors,
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        try:
            import mlx.core as mx
            from mlx_lm.sample_utils import make_logits_processors, make_sampler
            from mlx_lm.utils import generate_step
        except ImportError:
            raise ImportError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        pipeline_kwargs = kwargs.get("pipeline_kwargs", self.pipeline_kwargs)

        temp: float = pipeline_kwargs.get("temp", 0.0)
        max_new_tokens: int = pipeline_kwargs.get("max_tokens", 100)
        repetition_penalty: Optional[float] = pipeline_kwargs.get(
            "repetition_penalty", None
        )
        repetition_context_size: Optional[int] = pipeline_kwargs.get(
            "repetition_context_size", None
        )
        top_p: float = pipeline_kwargs.get("top_p", 1.0)
        min_p: float = pipeline_kwargs.get("min_p", 0.0)
        min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)

        # Tokenize the prompt and set up incremental detokenization.
        prompt = self.tokenizer.encode(prompt, return_tensors="np")
        prompt_tokens = mx.array(prompt[0])

        eos_token_id = self.tokenizer.eos_token_id
        detokenizer = self.tokenizer.detokenizer
        detokenizer.reset()

        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
        logits_processors = make_logits_processors(
            None, repetition_penalty, repetition_context_size
        )

        for (token, prob), n in zip(
            generate_step(
                prompt=prompt_tokens,
                model=self.model,
                sampler=sampler,
                logits_processors=logits_processors,
            ),
            range(max_new_tokens),
        ):
            # identify text to yield
            text: Optional[str] = None
            detokenizer.add_token(token)
            detokenizer.finalize()
            text = detokenizer.last_segment

            # yield text, if any
            if text:
                chunk = GenerationChunk(text=text)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk

            # break if stop sequence found
            if token == eos_token_id or (stop is not None and text in stop):
                break
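

# Minimal usage sketch (not part of the library code): exercises the batch
# path (``invoke`` -> ``_call`` -> ``mlx_lm.generate``) and the streaming path
# (``stream`` -> ``_stream`` -> ``generate_step``). Assumes ``mlx-lm`` is
# installed, an Apple Silicon machine, and that the
# ``mlx-community/quantized-gemma-2b`` weights can be fetched from the
# Hugging Face Hub; the prompts are arbitrary.
if __name__ == "__main__":
    pipe = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b",
        pipeline_kwargs={"max_tokens": 50, "temp": 0.7, "top_p": 0.9},
    )

    # Single-shot generation.
    print(pipe.invoke("Briefly explain what MLX is."))

    # Token-by-token streaming; LLM.stream yields plain string chunks.
    for chunk in pipe.stream("Write a haiku about Apple Silicon."):
        print(chunk, end="", flush=True)
    print()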