import inspect
import json
import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import from_env, get_pydantic_field_names
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

logger = logging.getLogger(__name__)


class HuggingFaceEndpoint(LLM):
    """HuggingFace Endpoint.

    To use this class, you should have installed the ``huggingface_hub`` package, and
    the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
    or given as a named parameter to the constructor.

    Example:
        .. code-block:: python

            # Basic Example (no streaming)
            llm = HuggingFaceEndpoint(
                endpoint_url="http://localhost:8010/",
                max_new_tokens=512,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=0.01,
                repetition_penalty=1.03,
                huggingfacehub_api_token="my-api-key"
            )
            print(llm.invoke("What is Deep Learning?"))

            # Streaming response example
            from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

            callbacks = [StreamingStdOutCallbackHandler()]
            llm = HuggingFaceEndpoint(
                endpoint_url="http://localhost:8010/",
                max_new_tokens=512,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=0.01,
                repetition_penalty=1.03,
                callbacks=callbacks,
                streaming=True,
                huggingfacehub_api_token="my-api-key"
            )
            print(llm.invoke("What is Deep Learning?"))

    """  # noqa: E501

    endpoint_url: Optional[str] = None
    """Endpoint URL to use. If repo_id is not specified then this needs to be given
    or should be passed as the env variable `HF_INFERENCE_ENDPOINT`."""
    repo_id: Optional[str] = None
    """Repo to use. If endpoint_url is not specified then this needs to be given."""
    huggingfacehub_api_token: Optional[str] = Field(
        default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
    )
    max_new_tokens: int = 512
    """Maximum number of generated tokens."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for
    top-k-filtering."""
    top_p: Optional[float] = 0.95
    """If set to < 1, only the smallest set of most probable tokens with probabilities
    that add up to `top_p` or higher are kept for generation."""
    typical_p: Optional[float] = 0.95
    """Typical Decoding mass. See [Typical Decoding for Natural Language
    Generation](https://arxiv.org/abs/2202.00666) for more information."""
    temperature: Optional[float] = 0.8
    """The value used to modulate the logits distribution."""
    repetition_penalty: Optional[float] = None
    """The parameter for repetition penalty. 1.0 means no penalty.
    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
    return_full_text: bool = False
    """Whether to prepend the prompt to the generated text."""
    truncate: Optional[int] = None
    """Truncate input tokens to the given size."""
    stop_sequences: List[str] = Field(default_factory=list)
    """Stop generating tokens if a member of `stop_sequences` is generated."""
    seed: Optional[int] = None
    """Random sampling seed."""
    inference_server_url: str = ""
    """text-generation-inference instance base url."""
    timeout: int = 120
    """Timeout in seconds."""
    streaming: bool = False
    """Whether to generate a stream of tokens asynchronously."""
    do_sample: bool = False
    """Activate logits sampling."""
    watermark: bool = False
    """Watermarking with [A Watermark for Large Language Models]
    (https://arxiv.org/abs/2301.10226)."""
    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any text-generation-inference server parameters not explicitly
    specified."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `call` not explicitly specified."""
    model: str
    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    task: Optional[str] = None
    """Task to call the model with.
    Should be a task that returns `generated_text` or `summary_text`."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please make sure that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra

        # To correctly create the InferenceClient and AsyncInferenceClient in
        # validate_environment, we need to populate values["model"].
        # From the InferenceClient docstring:
        #   model (`str`, `optional`):
        #       The model to run inference with. Can be a model id hosted on the
        #       Hugging Face Hub, e.g. `bigcode/starcoder`, or a URL to a deployed
        #       Inference Endpoint. Defaults to None, in which case a recommended
        #       model is automatically selected for the task.
        # This string could be in 2 places of descending priority:
        # 1. values["model"] or values["endpoint_url"] or values["repo_id"]
        #    (equal priority - don't allow more than one to be set)
        # 2. values["HF_INFERENCE_ENDPOINT"] (if none of the above are set)
        model = values.get("model")
        endpoint_url = values.get("endpoint_url")
        repo_id = values.get("repo_id")

        if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
            raise ValueError(
                "Please specify either a `model` OR an `endpoint_url` OR a `repo_id`, "
                "not more than one."
            )
        values["model"] = (
            model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
        )
        if not values["model"]:
            raise ValueError(
                "Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
                "model."
            )
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that package is installed and that the API token is valid."""
        try:
            from huggingface_hub import login  # type: ignore[import]
        except ImportError:
            raise ImportError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )

        huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
            "HF_TOKEN"
        )

        if huggingfacehub_api_token is not None:
            try:
                login(token=huggingfacehub_api_token)
            except Exception as e:
                raise ValueError(
                    "Could not authenticate with huggingface_hub. "
                    "Please check your API token."
                ) from e

        from huggingface_hub import AsyncInferenceClient, InferenceClient

        # Instantiate clients with supported kwargs
        sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
        self.client = InferenceClient(
            model=self.model,
            timeout=self.timeout,
            token=huggingfacehub_api_token,
            **{
                key: value
                for key, value in self.server_kwargs.items()
                if key in sync_supported_kwargs
            },
        )

        async_supported_kwargs = set(
            inspect.signature(AsyncInferenceClient).parameters
        )
        self.async_client = AsyncInferenceClient(
            model=self.model,
            timeout=self.timeout,
            token=huggingfacehub_api_token,
            **{
                key: value
                for key, value in self.server_kwargs.items()
                if key in async_supported_kwargs
            },
        )

        ignored_kwargs = (
            set(self.server_kwargs.keys())
            - sync_supported_kwargs
            - async_supported_kwargs
        )
        if len(ignored_kwargs) > 0:
            logger.warning(
                f"Ignoring following parameters as they are not supported by the "
                f"InferenceClient or AsyncInferenceClient: {ignored_kwargs}."
            )

        return self

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling text generation inference API."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "return_full_text": self.return_full_text,
            "truncate": self.truncate,
            "stop_sequences": self.stop_sequences,
            "seed": self.seed,
            "do_sample": self.do_sample,
            "watermark": self.watermark,
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self.endpoint_url, "task": self.task},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "huggingface_endpoint"

    def _invocation_params(
        self, runtime_stop: Optional[List[str]], **kwargs: Any
    ) -> Dict[str, Any]:
        params = {**self._default_params, **kwargs}
        params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to HuggingFace Hub's inference endpoint."""
        invocation_params = self._invocation_params(stop, **kwargs)
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **invocation_params):
                completion += chunk.text
            return completion
        else:
            # porting 'stop_sequences' into the 'stop' argument
            invocation_params["stop"] = invocation_params["stop_sequences"]
            response = self.client.post(
                json={"inputs": prompt, "parameters": invocation_params},
                stream=False,
                task=self.task,
            )
            response_text = json.loads(response.decode())[0]["generated_text"]

            # Maybe the generation has stopped at one of the stop sequences:
            # then we remove this stop sequence from the end of the generated text
            for stop_seq in invocation_params["stop_sequences"]:
                if response_text[-len(stop_seq) :] == stop_seq:
                    response_text = response_text[: -len(stop_seq)]
            return response_text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        invocation_params = self._invocation_params(stop, **kwargs)
        if self.streaming:
            completion = ""
            async for chunk in self._astream(
                prompt, stop, run_manager, **invocation_params
            ):
                completion += chunk.text
            return completion
        else:
            invocation_params["stop"] = invocation_params["stop_sequences"]
            response = await self.async_client.post(
                json={"inputs": prompt, "parameters": invocation_params},
                stream=False,
                task=self.task,
            )
            response_text = json.loads(response.decode())[0]["generated_text"]

            # Maybe the generation has stopped at one of the stop sequences:
            # then remove this stop sequence from the end of the generated text
            for stop_seq in invocation_params["stop_sequences"]:
                if response_text[-len(stop_seq) :] == stop_seq:
                    response_text = response_text[: -len(stop_seq)]
            return response_text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        invocation_params = self._invocation_params(stop, **kwargs)
        for response in self.client.text_generation(
            prompt, **invocation_params, stream=True
        ):
            # identify stop sequence in generated text, if any
            stop_seq_found: Optional[str] = None
            for stop_seq in invocation_params["stop_sequences"]:
                if stop_seq in response:
                    stop_seq_found = stop_seq

            # identify text to yield
            text: Optional[str] = None
            if stop_seq_found:
                text = response[: response.index(stop_seq_found)]
            else:
                text = response

            # yield text, if any
            if text:
                chunk = GenerationChunk(text=text)

                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk

            # break if stop sequence found
            if stop_seq_found:
                break

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        invocation_params = self._invocation_params(stop, **kwargs)
        async for response in await self.async_client.text_generation(
            prompt, **invocation_params, stream=True
        ):
            # identify stop sequence in generated text, if any
            stop_seq_found: Optional[str] = None
            for stop_seq in invocation_params["stop_sequences"]:
                if stop_seq in response:
                    stop_seq_found = stop_seq

            # identify text to yield
            text: Optional[str] = None
            if stop_seq_found:
                text = response[: response.index(stop_seq_found)]
            else:
                text = response

            # yield text, if any
            if text:
                chunk = GenerationChunk(text=text)

                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
                yield chunk

            # break if stop sequence found
            if stop_seq_found:
                break
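

# --- Usage sketch (not part of the class above) ---
# A minimal, hedged example of driving the class through the standard LangChain
# Runnable interface. It assumes a text-generation-inference server is already
# running at http://localhost:8010/ (the same placeholder URL used in the class
# docstring); a local endpoint typically needs no API token, and login() above is
# only attempted when a token is present. The example runs only when this module
# is executed directly, not on import.
if __name__ == "__main__":
    llm = HuggingFaceEndpoint(
        endpoint_url="http://localhost:8010/",  # assumed local TGI endpoint
        max_new_tokens=512,
        temperature=0.01,
    )

    # Blocking call: invoke() routes through _call above.
    print(llm.invoke("What is Deep Learning?"))

    # Token-by-token streaming: stream() drives _stream directly, so neither
    # callbacks nor streaming=True are required for this style of streaming.
    for chunk in llm.stream("What is Deep Learning?"):
        print(chunk, end="", flush=True)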