import json
import logging
from typing import Any, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from pydantic import Field

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)
HEADERS = {"Content-Type": "application/json"}
DEFAULT_TIMEOUT = 30


def _convert_message_to_dict(message: BaseMessage) -> dict:
    if isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_dict
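# Illustrative sketch (not part of the module): _convert_message_to_dict maps
# LangChain message objects onto the OpenAI-style role/content dicts that the
# ChatGLM3 endpoint expects, e.g.:
#
#     _convert_message_to_dict(SystemMessage(content="You are a helpful assistant."))
#     # -> {"role": "system", "content": "You are a helpful assistant."}
#     _convert_message_to_dict(HumanMessage(content="Hi"))
#     # -> {"role": "user", "content": "Hi"}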
class ChatGLM3(LLM):
    """ChatGLM3 LLM service."""

    model_name: str = Field(default="chatglm3-6b", alias="model")
    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_tokens: int = 20000
    """Max tokens allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature from 0 to 10."""
    top_p: float = 0.7
    """Top P for nucleus sampling from 0 to 1."""
    prefix_messages: List[BaseMessage] = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    http_client: Union[Any, None] = None
    timeout: int = DEFAULT_TIMEOUT

    @property
    def _llm_type(self) -> str:
        return "chat_glm_3"

    @property
    def _invocation_params(self) -> dict:
        """Get the parameters used to invoke the model."""
        params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "stream": self.streaming,
        }
        return {**params, **(self.model_kwargs or {})}

    @property
    def client(self) -> Any:
        import httpx

        return self.http_client or httpx.Client(timeout=self.timeout)

    def _get_payload(self, prompt: str) -> dict:
        params = self._invocation_params
        messages = self.prefix_messages + [HumanMessage(content=prompt)]
        params.update(
            {
                "messages": [_convert_message_to_dict(m) for m in messages],
            }
        )
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM3 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = chatglm_llm.invoke("Who are you?")
        """
        import httpx

        payload = self._get_payload(prompt)
        logger.debug(f"ChatGLM3 payload: {payload}")

        try:
            response = self.client.post(
                self.endpoint_url, headers=HEADERS, json=payload
            )
        except httpx.NetworkError as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        logger.debug(f"ChatGLM3 response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()

            if isinstance(parsed_response, dict):
                content_keys = "choices"
                if content_keys in parsed_response:
                    choices = parsed_response[content_keys]
                    if len(choices):
                        text = choices[0]["message"]["content"]
                    else:
                        # Guard against an empty choices list so `text` is never unbound.
                        raise ValueError(f"No choices in response: {parsed_response}")
                else:
                    raise ValueError(f"No content in response: {parsed_response}")
            else:
                raise ValueError(f"Unexpected response type: {parsed_response}")

        except json.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
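# Usage sketch (illustrative only; assumes a ChatGLM3-compatible OpenAI-style
# server is reachable at the default endpoint_url -- adjust host/port and
# parameter values for your deployment):
#
#     from langchain_core.messages import SystemMessage
#
#     llm = ChatGLM3(
#         endpoint_url="http://127.0.0.1:8000/v1/chat/completions",
#         prefix_messages=[SystemMessage(content="You are a helpful assistant.")],
#         temperature=0.1,
#     )
#     response = llm.invoke("Who are you?")
#
#     # `stop` tokens are enforced client-side via enforce_stop_tokens:
#     response = llm.invoke("Count to ten.", stop=["5"])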