# Source code for langchain_community.chat_models.huggingface
"""Hugging Face Chat Wrapper."""fromtypingimportAny,AsyncIterator,Iterator,List,Optionalfromlangchain_core._api.deprecationimportdeprecatedfromlangchain_core.callbacks.managerimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_models.chat_modelsimport(BaseChatModel,agenerate_from_stream,generate_from_stream,)fromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,HumanMessage,SystemMessage,)fromlangchain_core.outputsimport(ChatGeneration,ChatGenerationChunk,ChatResult,LLMResult,)frompydanticimportmodel_validatorfromtyping_extensionsimportSelffromlangchain_community.llms.huggingface_endpointimportHuggingFaceEndpointfromlangchain_community.llms.huggingface_hubimportHuggingFaceHubfromlangchain_community.llms.huggingface_text_gen_inferenceimport(HuggingFaceTextGenInference,)DEFAULT_SYSTEM_PROMPT="""You are a helpful, respectful, and honest assistant."""
@deprecated(
    since="0.0.37",
    removal="1.0",
    alternative_import="langchain_huggingface.ChatHuggingFace",
)
class ChatHuggingFace(BaseChatModel):
    """Wrapper for using Hugging Face LLM's as ChatModels.

    Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
    and `HuggingFaceHub` LLMs.

    Upon instantiating this class, the model_id is resolved from the url
    provided to the LLM, and the appropriate tokenizer is loaded from
    the HuggingFace Hub.

    Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
    """

    llm: Any
    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint,
    or HuggingFaceHub."""
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None
    model_id: Optional[str] = None
    streaming: bool = False

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

        # Imported lazily so `transformers` is only required when this
        # wrapper is actually instantiated.
        from transformers import AutoTokenizer

        self._resolve_model_id()

        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )

    @model_validator(mode="after")
    def validate_llm(self) -> Self:
        """Ensure ``llm`` is one of the supported Hugging Face LLM wrappers.

        Raises:
            TypeError: If ``llm`` is not a HuggingFaceTextGenInference,
                HuggingFaceEndpoint, or HuggingFaceHub instance.
        """
        if not isinstance(
            self.llm,
            (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
        ):
            raise TypeError(
                "Expected llm to be one of HuggingFaceTextGenInference, "
                f"HuggingFaceEndpoint, HuggingFaceHub, received {type(self.llm)}"
            )
        return self

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat generation chunks from the wrapped LLM."""
        request = self._to_chat_prompt(messages)

        for delta in self.llm.stream(request, **kwargs):
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of :meth:`_stream`."""
        request = self._to_chat_prompt(messages)

        async for delta in self.llm.astream(request, **kwargs):
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                await run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result, delegating to streaming when enabled."""
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of :meth:`_generate`."""
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM.

        Raises:
            ValueError: If ``messages`` is empty or its last entry is not a
                HumanMessage.
        """
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts, tokenize=False, add_generation_prompt=True
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format.

        Raises:
            ValueError: If ``message`` is not a System/AI/Human message.
        """
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        """Wrap a single-prompt LLMResult into a ChatResult of AIMessages."""
        chat_generations = []

        # Only the first prompt's generations exist: _generate/_agenerate
        # always pass exactly one prompt.
        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    def _resolve_model_id(self) -> None:
        """Resolve the model_id from the LLM's inference_server_url.

        Raises:
            ValueError: If no deployed inference endpoint matches the LLM's URL.
        """
        from huggingface_hub import list_inference_endpoints

        # Fast path: the model id is carried directly on the LLM; no need to
        # query the Hub API at all.
        if isinstance(self.llm, HuggingFaceHub) or (
            hasattr(self.llm, "repo_id") and self.llm.repo_id
        ):
            self.model_id = self.llm.repo_id
            return
        elif isinstance(self.llm, HuggingFaceTextGenInference):
            endpoint_url: Optional[str] = self.llm.inference_server_url
        else:
            endpoint_url = self.llm.endpoint_url

        # Hit the Hub API only when we actually need to match an endpoint URL
        # (the original queried it unconditionally, even on the fast path).
        available_endpoints = list_inference_endpoints("*")
        for endpoint in available_endpoints:
            if endpoint.url == endpoint_url:
                self.model_id = endpoint.repository
                break

        if not self.model_id:
            raise ValueError(
                "Failed to resolve model_id: "
                f"Could not find model id for inference server: {endpoint_url} "
                "Make sure that your Hugging Face token has access to the endpoint."
            )

    @property
    def _llm_type(self) -> str:
        return "huggingface-chat-wrapper"