"""Source code for langchain_community.chat_models.llama_edge."""
import json
import logging
import re
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import get_pydantic_field_names

logger = logging.getLogger(__name__)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Turn an API-style message dict into the matching LangChain message.

    ``user`` maps to ``HumanMessage``, ``assistant`` to ``AIMessage`` (with a
    missing/None content coerced to the empty string), and any other role is
    wrapped in a generic ``ChatMessage``.
    """
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    if role == "assistant":
        return AIMessage(content=_dict.get("content", "") or "")
    return ChatMessage(content=_dict["content"], role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialize a LangChain message into the ``{"role", "content"}`` wire dict.

    Raises:
        TypeError: if the message is not one of the supported message classes.
    """
    # ChatMessage carries its own role string, so it is handled first.
    if isinstance(message, ChatMessage):
        return {"role": message.role, "content": message.content}
    for klass, role in (
        (SystemMessage, "system"),
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
    ):
        if isinstance(message, klass):
            return {"role": role, "content": message.content}
    raise TypeError(f"Got unknown type {message}")


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Map a streamed delta dict onto a message-chunk class.

    The role in the delta wins; when the delta carries no role, fall back to
    ``default_class`` (the class of the previous chunk in the stream).
    """
    role = _dict.get("role")
    content = _dict.get("content") or ""

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    return default_class(content=content)  # type: ignore[call-arg]
class LlamaEdgeChatService(BaseChatModel):
    """Chat with LLMs via `llama-api-server`

    For the information about `llama-api-server`, visit
    https://github.com/second-state/LlamaEdge
    """

    request_timeout: int = 60
    """request timeout for chat http requests"""
    service_url: Optional[str] = None
    """URL of WasmChat service"""
    model: str = "NA"
    """model name, default is `NA`."""
    streaming: bool = False
    """Whether to stream the results or not."""

    class Config:
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        # Any kwarg that is not a declared field is moved into model_kwargs,
        # with a warning so the caller can confirm the spelling was intended.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        # Declared fields must never be smuggled in through model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run a chat completion; delegates to streaming when enabled."""
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        res = self._chat(messages, **kwargs)
        if res.status_code != 200:
            raise ValueError(f"Error code: {res.status_code}, reason: {res.reason}")

        return self._create_chat_result(res.json())

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield chat-generation chunks parsed from the server's stream.

        The server may pack several JSON chunk objects onto one line, so each
        line is split on the closing marker of a chunk object before parsing.
        """
        res = self._chat(messages, **kwargs)

        default_chunk_class = AIMessageChunk
        substring = '"object":"chat.completion.chunk"}'
        for line in res.iter_lines():
            chunks = []
            if line:
                json_string = line.decode("utf-8")

                # Find every end-of-chunk marker; a sentinel before the first
                # marker lets the slicing below also capture the first chunk.
                positions = [m.start() for m in re.finditer(substring, json_string)]
                positions = [-1 * len(substring)] + positions

                for i in range(len(positions) - 1):
                    # Each slice runs from just after one marker to just after
                    # the next, i.e. exactly one JSON chunk object.
                    start = positions[i] + len(substring)
                    end = positions[i + 1] + len(substring)
                    chunks.append(json.loads(json_string[start:end]))

            for chunk in chunks:
                if not isinstance(chunk, dict):
                    chunk = chunk.dict()
                if len(chunk["choices"]) == 0:
                    continue

                choice = chunk["choices"][0]
                chunk = _convert_delta_to_message_chunk(
                    choice["delta"], default_chunk_class
                )

                # A "stop" finish reason ends this line's chunk batch.
                if (
                    choice.get("finish_reason") is not None
                    and choice.get("finish_reason") == "stop"
                ):
                    break

                finish_reason = choice.get("finish_reason")
                generation_info = (
                    dict(finish_reason=finish_reason)
                    if finish_reason is not None
                    else None
                )
                default_chunk_class = chunk.__class__
                cg_chunk = ChatGenerationChunk(
                    message=chunk, generation_info=generation_info
                )
                if run_manager:
                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
                yield cg_chunk

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        """POST the messages to the service's chat-completions endpoint.

        Returns a synthetic 503 response (no network call) when no service
        URL has been configured.
        """
        if self.service_url is None:
            res = requests.models.Response()
            res.status_code = 503
            res.reason = "The IP address or port of the chat service is incorrect."
            return res

        service_url = f"{self.service_url}/v1/chat/completions"

        payload = {
            "model": self.model,
            "messages": [_convert_message_to_dict(m) for m in messages],
        }
        if self.streaming:
            payload["stream"] = self.streaming

        return requests.post(
            url=service_url,
            timeout=self.request_timeout,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json",
            },
            data=json.dumps(payload),
        )

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Wrap the service's JSON response in a ChatResult."""
        message = _convert_dict_to_message(response["choices"][0].get("message"))
        generations = [ChatGeneration(message=message)]
        llm_output = {"token_usage": response["usage"], "model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        return "wasm-chat"