[docs]@asynccontextmanagerasyncdefaconnect_httpx_sse(client:Any,method:str,url:str,**kwargs:Any)->AsyncIterator:"""Async context manager for connecting to an SSE stream. Args: client: The httpx client. method: The HTTP method. url: The URL to connect to. kwargs: Additional keyword arguments to pass to the client. Yields: An EventSource object. """fromhttpx_sseimportEventSourceasyncwithclient.stream(method,url,**kwargs)asresponse:yieldEventSource(response)
[docs]classChatBaichuan(BaseChatModel):"""Baichuan chat model integration. Setup: To use, you should have the environment variable``BAICHUAN_API_KEY`` set with your API KEY. .. code-block:: bash export BAICHUAN_API_KEY="your-api-key" Key init args — completion params: model: Optional[str] Name of Baichuan model to use. max_tokens: Optional[int] Max number of tokens to generate. streaming: Optional[bool] Whether to stream the results or not. temperature: Optional[float] Sampling temperature. top_p: Optional[float] What probability mass to use. top_k: Optional[int] What search sampling control to use. Key init args — client params: api_key: Optional[str] Baichuan API key. If not passed in will be read from env var BAICHUAN_API_KEY. base_url: Optional[str] Base URL for API requests. See full list of supported init args and their descriptions in the params section. Instantiate: .. code-block:: python from langchain_community.chat_models import ChatBaichuan chat = ChatBaichuan( api_key=api_key, model='Baichuan4', # temperature=..., # other params... ) Invoke: .. code-block:: python messages = [ ("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"), ("human", "我喜欢编程。"), ] chat.invoke(messages) .. code-block:: python AIMessage( content='I enjoy programming.', response_metadata={ 'token_usage': { 'prompt_tokens': 93, 'completion_tokens': 5, 'total_tokens': 98 }, 'model': 'Baichuan4' }, id='run-944ff552-6a93-44cf-a861-4e4d849746f9-0' ) Stream: .. code-block:: python for chunk in chat.stream(messages): print(chunk) .. code-block:: python content='I' id='run-f99fcd6f-dd31-46d5-be8f-0b6a22bf77d8' content=' enjoy programming.' id='run-f99fcd6f-dd31-46d5-be8f-0b6a22bf77d8 .. code-block:: python stream = chat.stream(messages) full = next(stream) for chunk in stream: full += chunk full .. code-block:: python AIMessageChunk( content='I like programming.', id='run-74689970-dc31-461d-b729-3b6aa93508d2' ) Async: .. code-block:: python await chat.ainvoke(messages) # stream # async for chunk in chat.astream(messages): # print(chunk) # batch # await chat.abatch([messages]) .. code-block:: python AIMessage( content='I enjoy programming.', response_metadata={ 'token_usage': { 'prompt_tokens': 93, 'completion_tokens': 5, 'total_tokens': 98 }, 'model': 'Baichuan4' }, id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0' ) Tool calling: .. code-block:: python class get_current_weather(BaseModel): '''Get current weather.''' location: str = Field('City or province, such as Shanghai') llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather]) llm_with_tools.invoke('How is the weather today?') .. code-block:: python [{'name': 'get_current_weather', 'args': {'location': 'New York'}, 'id': '3951017OF8doB0A', 'type': 'tool_call'}] Response metadata .. code-block:: python ai_msg = chat.invoke(messages) ai_msg.response_metadata .. code-block:: python { 'token_usage': { 'prompt_tokens': 93, 'completion_tokens': 5, 'total_tokens': 98 }, 'model': 'Baichuan4' } """# noqa: E501@propertydeflc_secrets(self)->Dict[str,str]:return{"baichuan_api_key":"BAICHUAN_API_KEY",}@propertydeflc_serializable(self)->bool:returnTruebaichuan_api_base:str=Field(default=DEFAULT_API_BASE,alias="base_url")"""Baichuan custom endpoints"""baichuan_api_key:SecretStr=Field(alias="api_key")"""Baichuan API Key"""baichuan_secret_key:Optional[SecretStr]=None"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""streaming:bool=False"""Whether to stream the results or not."""max_tokens:Optional[int]=None"""Maximum number of tokens to generate."""request_timeout:int=Field(default=60,alias="timeout")"""request timeout for chat http requests"""model:str="Baichuan2-Turbo-192K""""model name of Baichuan, default is `Baichuan2-Turbo-192K`, other options include `Baichuan2-Turbo`"""temperature:Optional[float]=Field(default=0.3)"""What sampling temperature to use."""top_k:int=5"""What search sampling control to use."""top_p:float=0.85"""What probability mass to use."""with_search_enhance:bool=False"""[DEPRECATED, keeping it for for backward compatibility], Whether to use search enhance, default is False."""model_kwargs:Dict[str,Any]=Field(default_factory=dict)"""Holds any model parameters valid for API call not explicitly specified."""model_config=ConfigDict(populate_by_name=True,)@model_validator(mode="before")@classmethoddefbuild_extra(cls,values:Dict[str,Any])->Any:"""Build extra kwargs from additional params that were passed in."""all_required_field_names=get_pydantic_field_names(cls)extra=values.get("model_kwargs",{})forfield_nameinlist(values):iffield_nameinextra:raiseValueError(f"Found {field_name} supplied twice.")iffield_namenotinall_required_field_names:logger.warning(f"""WARNING! {field_name} is not default parameter.{field_name} was transferred to model_kwargs. Please confirm that {field_name} is what you intended.""")extra[field_name]=values.pop(field_name)invalid_model_kwargs=all_required_field_names.intersection(extra.keys())ifinvalid_model_kwargs:raiseValueError(f"Parameters {invalid_model_kwargs} should be specified explicitly. "f"Instead they were passed in as part of `model_kwargs` parameter.")values["model_kwargs"]=extrareturnvalues@model_validator(mode="before")@classmethoddefvalidate_environment(cls,values:Dict)->Any:values["baichuan_api_base"]=get_from_dict_or_env(values,"baichuan_api_base","BAICHUAN_API_BASE",DEFAULT_API_BASE,)values["baichuan_api_key"]=convert_to_secret_str(get_from_dict_or_env(values,["baichuan_api_key","api_key"],"BAICHUAN_API_KEY",))returnvalues@propertydef_default_params(self)->Dict[str,Any]:"""Get the default parameters for calling Baichuan API."""normal_params={"model":self.model,"temperature":self.temperature,"top_p":self.top_p,"top_k":self.top_k,"stream":self.streaming,"max_tokens":self.max_tokens,}return{**normal_params,**self.model_kwargs}def_generate(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->ChatResult:ifself.streaming:stream_iter=self._stream(messages=messages,stop=stop,run_manager=run_manager,**kwargs)returngenerate_from_stream(stream_iter)res=self._chat(messages,**kwargs)ifres.status_code!=200:raiseValueError(f"Error from Baichuan api response: {res}")response=res.json()returnself._create_chat_result(response)def_stream(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->Iterator[ChatGenerationChunk]:res=self._chat(messages,stream=True,**kwargs)ifres.status_code!=200:raiseValueError(f"Error from Baichuan api response: {res}")default_chunk_class=AIMessageChunkforchunkinres.iter_lines():chunk=chunk.decode("utf-8").strip("\r\n")parts=chunk.split("data: ",1)chunk=parts[1]iflen(parts)>1elseNoneifchunkisNone:continueifchunk=="[DONE]":breakresponse=json.loads(chunk)forminresponse.get("choices"):chunk=_convert_delta_to_message_chunk(m.get("delta"),default_chunk_class)default_chunk_class=chunk.__class__cg_chunk=ChatGenerationChunk(message=chunk)ifrun_manager:run_manager.on_llm_new_token(chunk.content,chunk=cg_chunk)yieldcg_chunkasyncdef_agenerate(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,stream:Optional[bool]=None,**kwargs:Any,)->ChatResult:should_stream=streamifstreamisnotNoneelseself.streamingifshould_stream:stream_iter=self._astream(messages,stop=stop,run_manager=run_manager,**kwargs)returnawaitagenerate_from_stream(stream_iter)headers=self._create_headers_parameters(**kwargs)payload=self._create_payload_parameters(messages,**kwargs)importhttpxasyncwithhttpx.AsyncClient(headers=headers,timeout=self.request_timeout)asclient:response=awaitclient.post(self.baichuan_api_base,json=payload)response.raise_for_status()returnself._create_chat_result(response.json())asyncdef_astream(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->AsyncIterator[ChatGenerationChunk]:headers=self._create_headers_parameters(**kwargs)payload=self._create_payload_parameters(messages,stream=True,**kwargs)importhttpxasyncwithhttpx.AsyncClient(headers=headers,timeout=self.request_timeout)asclient:asyncwithaconnect_httpx_sse(client,"POST",self.baichuan_api_base,json=payload)asevent_source:asyncforsseinevent_source.aiter_sse():chunk=json.loads(sse.data)iflen(chunk["choices"])==0:continuechoice=chunk["choices"][0]chunk=_convert_delta_to_message_chunk(choice["delta"],AIMessageChunk)finish_reason=choice.get("finish_reason",None)generation_info=({"finish_reason":finish_reason}iffinish_reasonisnotNoneelseNone)chunk=ChatGenerationChunk(message=chunk,generation_info=generation_info)ifrun_manager:awaitrun_manager.on_llm_new_token(chunk.text,chunk=chunk)yieldchunkiffinish_reasonisnotNone:breakdef_chat(self,messages:List[BaseMessage],**kwargs:Any)->requests.Response:payload=self._create_payload_parameters(messages,**kwargs)url=self.baichuan_api_baseheaders=self._create_headers_parameters(**kwargs)res=requests.post(url=url,timeout=self.request_timeout,headers=headers,json=payload,stream=self.streaming,)returnresdef_create_payload_parameters(# type: ignore[no-untyped-def]self,messages:List[BaseMessage],**kwargs)->Dict[str,Any]:parameters={**self._default_params,**kwargs}temperature=parameters.pop("temperature",0.3)top_k=parameters.pop("top_k",5)top_p=parameters.pop("top_p",0.85)model=parameters.pop("model")with_search_enhance=parameters.pop("with_search_enhance",False)stream=parameters.pop("stream",False)tools=parameters.pop("tools",[])payload={"model":model,"messages":[_convert_message_to_dict(m)forminmessages],"top_k":top_k,"top_p":top_p,"temperature":temperature,"with_search_enhance":with_search_enhance,"stream":stream,"tools":tools,}returnpayloaddef_create_headers_parameters(self,**kwargs)->Dict[str,Any]:# type: ignore[no-untyped-def]parameters={**self._default_params,**kwargs}default_headers=parameters.pop("headers",{})api_key=""ifself.baichuan_api_key:api_key=self.baichuan_api_key.get_secret_value()headers={"Content-Type":"application/json","Authorization":f"Bearer {api_key}",**default_headers,}returnheadersdef_create_chat_result(self,response:Mapping[str,Any])->ChatResult:generations=[]forcinresponse["choices"]:message=_convert_dict_to_message(c["message"])gen=ChatGeneration(message=message)generations.append(gen)token_usage=response["usage"]llm_output={"token_usage":token_usage,"model":self.model}returnChatResult(generations=generations,llm_output=llm_output)@propertydef_llm_type(self)->str:return"baichuan-chat"
[docs]defbind_tools(self,tools:Sequence[Union[Dict[str,Any],Type[BaseModel],Callable,BaseTool]],**kwargs:Any,)->Runnable[LanguageModelInput,BaseMessage]:"""Bind tool-like objects to this chat model. Args: tools: A list of tool definitions to bind to this chat model. Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic models, callables, and BaseTools will be automatically converted to their schema dictionary representation. **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. """formatted_tools=[convert_to_openai_tool(tool)fortoolintools]returnsuper().bind(tools=formatted_tools,**kwargs)