from __future__ import annotations

import hashlib
import json
import logging
import os
import re
import ssl
import uuid
from operator import itemgetter
from typing import (
    Any,
    AsyncContextManager,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import certifi
import httpx
from httpx_sse import EventSource, aconnect_sse, connect_sse
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    InvalidToolCall,
    SystemMessage,
    SystemMessageChunk,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
    make_invalid_tool_call,
    parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import get_pydantic_field_names, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import _build_model_kwargs
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)
from typing_extensions import Self

logger = logging.getLogger(__name__)

# Mistral enforces a specific pattern for tool call IDs
TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")

# This SSL context is equivalent to the default `verify=True`.
# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
global_ssl_context = ssl.create_default_context(cafile=certifi.where())


def _create_retry_decorator(
    llm: ChatMistralAI,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Returns a tenacity retry decorator, preconfigured to handle exceptions."""

    errors = [httpx.RequestError, httpx.StreamError]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )


def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
    """Check if the tool call ID is a nine-character string of a-z, A-Z, 0-9."""
    return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))


def _base62_encode(num: int) -> str:
    """Encodes a number in base62 and ensures result is of a specified length."""
    base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    if num == 0:
        return base62[0]
    arr = []
    base = len(base62)
    while num:
        num, rem = divmod(num, base)
        arr.append(base62[rem])
    arr.reverse()
    return "".join(arr)


def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
    """Convert a tool call ID to a Mistral-compatible format."""
    if _is_valid_mistral_tool_call_id(tool_call_id):
        return tool_call_id
    else:
        hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
        hash_int = int.from_bytes(hash_bytes, byteorder="big")
        base62_str = _base62_encode(hash_int)
        if len(base62_str) >= 9:
            return base62_str[:9]
        else:
            return base62_str.rjust(9, "0")


def _convert_mistral_chat_message_to_message(
    _message: Dict,
) -> BaseMessage:
    role = _message["role"]
    assert role == "assistant", f"Expected role to be 'assistant', got {role}"
    content = cast(str, _message["content"])

    additional_kwargs: Dict = {}
    tool_calls = []
    invalid_tool_calls = []
    if raw_tool_calls := _message.get("tool_calls"):
        additional_kwargs["tool_calls"] = raw_tool_calls
        for raw_tool_call in raw_tool_calls:
            try:
                parsed: dict = cast(
                    dict, parse_tool_call(raw_tool_call, return_id=True)
                )
                if not parsed["id"]:
                    parsed["id"] = uuid.uuid4().hex[:]
                tool_calls.append(parsed)
            except Exception as e:
                invalid_tool_calls.append(
                    make_invalid_tool_call(raw_tool_call, str(e))
                )
    return AIMessage(
        content=content,
        additional_kwargs=additional_kwargs,
        tool_calls=tool_calls,
        invalid_tool_calls=invalid_tool_calls,
    )


def _raise_on_error(response: httpx.Response) -> None:
    """Raise an error if the response is an error."""
    if httpx.codes.is_error(response.status_code):
        error_message = response.read().decode("utf-8")
        raise httpx.HTTPStatusError(
            f"Error response {response.status_code} "
            f"while fetching {response.url}: {error_message}",
            request=response.request,
            response=response,
        )


async def _araise_on_error(response: httpx.Response) -> None:
    """Raise an error if the response is an error."""
    if httpx.codes.is_error(response.status_code):
        error_message = (await response.aread()).decode("utf-8")
        raise httpx.HTTPStatusError(
            f"Error response {response.status_code} "
            f"while fetching {response.url}: {error_message}",
            request=response.request,
            response=response,
        )


async def _aiter_sse(
    event_source_mgr: AsyncContextManager[EventSource],
) -> AsyncIterator[Dict]:
    """Iterate over the server-sent events."""
    async with event_source_mgr as event_source:
        await _araise_on_error(event_source.response)
        async for event in event_source.aiter_sse():
            if event.data == "[DONE]":
                return
            yield event.json()


async def acompletion_with_retry(
    llm: ChatMistralAI,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        if "stream" not in kwargs:
            kwargs["stream"] = False
        stream = kwargs["stream"]
        if stream:
            event_source = aconnect_sse(
                llm.async_client, "POST", "/chat/completions", json=kwargs
            )
            return _aiter_sse(event_source)
        else:
            response = await llm.async_client.post(
                url="/chat/completions", json=kwargs
            )
            await _araise_on_error(response)
            return response.json()

    return await _completion_with_retry(**kwargs)


def _convert_chunk_to_message_chunk(
    chunk: Dict, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    _choice = chunk["choices"][0]
    _delta = _choice["delta"]
    role = _delta.get("role")
    content = _delta.get("content") or ""
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        additional_kwargs: Dict = {}
        response_metadata = {}
        if raw_tool_calls := _delta.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            try:
                tool_call_chunks = []
                for raw_tool_call in raw_tool_calls:
                    if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
                        tool_call_id = uuid.uuid4().hex[:]
                    else:
                        tool_call_id = raw_tool_call.get("id")
                    tool_call_chunks.append(
                        tool_call_chunk(
                            name=raw_tool_call["function"].get("name"),
                            args=raw_tool_call["function"].get("arguments"),
                            id=tool_call_id,
                            index=raw_tool_call.get("index"),
                        )
                    )
            except KeyError:
                pass
        else:
            tool_call_chunks = []
        if token_usage := chunk.get("usage"):
            usage_metadata = {
                "input_tokens": token_usage.get("prompt_tokens", 0),
                "output_tokens": token_usage.get("completion_tokens", 0),
                "total_tokens": token_usage.get("total_tokens", 0),
            }
        else:
            usage_metadata = None
        if _choice.get("finish_reason") is not None and isinstance(
            chunk.get("model"), str
        ):
            response_metadata["model_name"] = chunk.get("model")
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
            usage_metadata=usage_metadata,  # type: ignore[arg-type]
            response_metadata=response_metadata,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)  # type: ignore[call-arg]


def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
    """Format Langchain ToolCall to dict expected by Mistral."""
    result: Dict[str, Any] = {
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        }
    }
    if _id := tool_call.get("id"):
        result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)

    return result


def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
    """Format Langchain InvalidToolCall to dict expected by Mistral."""
    result: Dict[str, Any] = {
        "function": {
            "name": invalid_tool_call["name"],
            "arguments": invalid_tool_call["args"],
        }
    }
    if _id := invalid_tool_call.get("id"):
        result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)

    return result


def _convert_message_to_mistral_chat_message(
    message: BaseMessage,
) -> Dict:
    if isinstance(message, ChatMessage):
        return dict(role=message.role, content=message.content)
    elif isinstance(message, HumanMessage):
        return dict(role="user", content=message.content)
    elif isinstance(message, AIMessage):
        message_dict: Dict[str, Any] = {"role": "assistant"}
        tool_calls = []
        if message.tool_calls or message.invalid_tool_calls:
            for tool_call in message.tool_calls:
                tool_calls.append(_format_tool_call_for_mistral(tool_call))
            for invalid_tool_call in message.invalid_tool_calls:
                tool_calls.append(
                    _format_invalid_tool_call_for_mistral(invalid_tool_call)
                )
        elif "tool_calls" in message.additional_kwargs:
            for tc in message.additional_kwargs["tool_calls"]:
                chunk = {
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    }
                }
                if _id := tc.get("id"):
                    chunk["id"] = _id
                tool_calls.append(chunk)
        else:
            pass
        if tool_calls:  # do not populate empty list tool_calls
            message_dict["tool_calls"] = tool_calls
        if tool_calls and message.content:
            # Assistant message must have either content or tool_calls, but not both.
            # Some providers may not support tool_calls in the same message as content.
            # This is done to ensure compatibility with messages from other providers.
            message_dict["content"] = ""
        else:
            message_dict["content"] = message.content
        if "prefix" in message.additional_kwargs:
            message_dict["prefix"] = message.additional_kwargs["prefix"]
        return message_dict
    elif isinstance(message, SystemMessage):
        return dict(role="system", content=message.content)
    elif isinstance(message, ToolMessage):
        return {
            "role": "tool",
            "content": message.content,
            "name": message.name,
            "tool_call_id": _convert_tool_call_id_to_mistral_compatible(
                message.tool_call_id
            ),
        }
    else:
        raise ValueError(f"Got unknown type {message}")
class ChatMistralAI(BaseChatModel):
    """A chat model that uses the MistralAI API."""

    # The type for client and async_client is ignored because the type is not
    # an Optional after the model is initialized and the model_validator
    # is run.
    client: httpx.Client = Field(  # type: ignore # : meta private:
        default=None, exclude=True
    )
    async_client: httpx.AsyncClient = Field(  # type: ignore # : meta private:
        default=None, exclude=True
    )  #: :meta private:
    mistral_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
    )
    endpoint: Optional[str] = Field(default=None, alias="base_url")
    max_retries: int = 5
    timeout: int = 120
    max_concurrent_requests: int = 64
    model: str = Field(default="mistral-small", alias="model_name")
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    top_p: float = 1
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    random_seed: Optional[int] = None
    safe_mode: Optional[bool] = None
    streaming: bool = False
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any invocation parameters not explicitly specified."""

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        values = _build_model_kwargs(values, all_required_field_names)
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the API."""
        defaults = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "random_seed": self.random_seed,
            "safe_prompt": self.safe_mode,
            **self.model_kwargs,
        }
        filtered = {k: v for k, v in defaults.items() if v is not None}
        return filtered

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="mistral",
            ls_model_name=self.model,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None):
            ls_params["ls_stop"] = ls_stop
        return ls_params

    @property
    def _client_params(self) -> Dict[str, Any]:
        """Get the parameters used for the client."""
        return self._default_params

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            if "stream" not in kwargs:
                kwargs["stream"] = False
            stream = kwargs["stream"]
            if stream:

                def iter_sse() -> Iterator[Dict]:
                    with connect_sse(
                        self.client, "POST", "/chat/completions", json=kwargs
                    ) as event_source:
                        _raise_on_error(event_source.response)
                        for event in event_source.iter_sse():
                            if event.data == "[DONE]":
                                return
                            yield event.json()

                return iter_sse()
            else:
                response = self.client.post(url="/chat/completions", json=kwargs)
                _raise_on_error(response)
                return response.json()

        rtn = _completion_with_retry(**kwargs)
        return rtn
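
    # Note: when ``stream=True`` the iterator returned above yields the decoded SSE
    # payloads as dicts. Judging from ``_convert_chunk_to_message_chunk``, each payload
    # is expected to look roughly like an OpenAI-style chunk, e.g.
    # ``{"choices": [{"delta": {"role": "assistant", "content": "..."}}], ...}``;
    # the exact fields are determined by the Mistral API, not by this client.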
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
        combined = {"token_usage": overall_token_usage, "model_name": self.model}
        return combined

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate api key, python package exists, temperature, and top_p."""
        if isinstance(self.mistral_api_key, SecretStr):
            api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
        else:
            api_key_str = self.mistral_api_key

        # todo: handle retries
        base_url_str = (
            self.endpoint
            or os.environ.get("MISTRAL_BASE_URL")
            or "https://api.mistral.ai/v1"
        )
        self.endpoint = base_url_str
        if not self.client:
            self.client = httpx.Client(
                base_url=base_url_str,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=self.timeout,
                verify=global_ssl_context,
            )
        # todo: handle retries and max_concurrency
        if not self.async_client:
            self.async_client = httpx.AsyncClient(
                base_url=base_url_str,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=self.timeout,
                verify=global_ssl_context,
            )
        if self.temperature is not None and not 0 <= self.temperature <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if self.top_p is not None and not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        return self

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Dict) -> ChatResult:
        generations = []
        token_usage = response.get("usage", {})
        for res in response["choices"]:
            finish_reason = res.get("finish_reason")
            message = _convert_mistral_chat_message_to_message(res["message"])
            if token_usage and isinstance(message, AIMessage):
                message.usage_metadata = {
                    "input_tokens": token_usage.get("prompt_tokens", 0),
                    "output_tokens": token_usage.get("completion_tokens", 0),
                    "total_tokens": token_usage.get("total_tokens", 0),
                }
            gen = ChatGeneration(
                message=message,
                generation_info={"finish_reason": finish_reason},
            )
            generations.append(gen)

        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model,
            "model": self.model,  # Backwards compatibility
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict], Dict[str, Any]]:
        params = self._client_params
        if stop is not None or "stop" in params:
            if "stop" in params:
                params.pop("stop")
            logger.warning(
                "Parameter `stop` not yet supported (https://docs.mistral.ai/api)"
            )
        message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
        return message_dicts, params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk.get("choices", [])) == 0:
                continue
            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
            # make future chunks same type as first chunk
            default_chunk_class = new_chunk.__class__
            gen_chunk = ChatGenerationChunk(message=new_chunk)
            if run_manager:
                run_manager.on_llm_new_token(
                    token=cast(str, new_chunk.content), chunk=gen_chunk
                )
            yield gen_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk.get("choices", [])) == 0:
                continue
            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
            # make future chunks same type as first chunk
            default_chunk_class = new_chunk.__class__
            gen_chunk = ChatGenerationChunk(message=new_chunk)
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=cast(str, new_chunk.content), chunk=gen_chunk
                )
            yield gen_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice:
            tool_names = []
            for tool in formatted_tools:
                if "function" in tool and (name := tool["function"].get("name")):
                    tool_names.append(name)
                elif name := tool.get("name"):
                    tool_names.append(name)
                else:
                    pass
            if tool_choice in tool_names:
                kwargs["tool_choice"] = {
                    "type": "function",
                    "function": {"name": tool_choice},
                }
            else:
                kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
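
    # Illustrative sketch of ``bind_tools`` usage (the tool and prompt below are
    # hypothetical, not part of the library):
    #
    #     from pydantic import BaseModel, Field
    #
    #     class GetWeather(BaseModel):
    #         """Get the current weather in a given location."""
    #
    #         location: str = Field(..., description="City name")
    #
    #     llm = ChatMistralAI(model="mistral-large-latest")
    #     llm_with_tools = llm.bind_tools([GetWeather], tool_choice="any")
    #     ai_msg = llm_with_tools.invoke("What is the weather in Paris?")
    #     # ai_msg.tool_calls -> [{"name": "GetWeather", "args": {"location": "Paris"}, ...}]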
    def with_structured_output(
        self,
        schema: Optional[Union[Dict, Type]] = None,
        *,
        method: Literal[
            "function_calling", "json_mode", "json_schema"
        ] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema:
                The output schema. Can be passed in as:

                - an OpenAI function/tool schema,
                - a JSON Schema,
                - a TypedDict class (support added in 0.1.12),
                - or a Pydantic class.

                If ``schema`` is a Pydantic class then the model output will be a
                Pydantic instance of that class, and the model-generated fields will be
                validated by the Pydantic class. Otherwise the model output will be a
                dict and will not be validated. See
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
                for more on how to properly specify types and descriptions of
                schema fields when specifying a Pydantic or TypedDict class.

                .. versionchanged:: 0.1.12

                    Added support for TypedDict class.

            method: The method for steering model generation, one of:

                - "function_calling": Uses Mistral's
                  `function-calling feature <https://docs.mistral.ai/capabilities/function_calling/>`_.
                - "json_schema": Uses Mistral's
                  `structured output feature <https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/>`_.
                - "json_mode": Uses Mistral's
                  `JSON mode <https://docs.mistral.ai/capabilities/structured-output/json_mode/>`_.
                  Note that if using JSON mode then you must include instructions for
                  formatting the output into the desired schema into the model call.

                .. versionchanged:: 0.2.5

                    Added method="json_schema"

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

        Returns:
            A Runnable that takes same inputs as a
            :class:`langchain_core.language_models.chat.BaseChatModel`.

            If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable
            outputs an instance of ``schema`` (i.e., a Pydantic object).

            Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

            If ``include_raw`` is True, then Runnable outputs a dict with keys:

            - ``"raw"``: BaseMessage
            - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
            - ``"parsing_error"``: Optional[BaseException]

        Example: schema=Pydantic class, method="function_calling", include_raw=False:
            .. code-block:: python

                from typing import Optional

                from langchain_mistralai import ChatMistralAI
                from pydantic import BaseModel, Field


                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    # If we provide default values and/or descriptions for fields, these will be passed
                    # to the model. This is an important part of improving a model's ability to
                    # correctly return structured outputs.
                    justification: Optional[str] = Field(
                        default=None, description="A justification for the answer."
                    )


                llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke(
                    "What weighs more a pound of bricks or a pound of feathers"
                )

                # -> AnswerWithJustification(
                #     answer='They weigh the same',
                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
                # )

        Example: schema=Pydantic class, method="function_calling", include_raw=True:
            .. code-block:: python

                from langchain_mistralai import ChatMistralAI
                from pydantic import BaseModel


                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str


                llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
                structured_llm = llm.with_structured_output(
                    AnswerWithJustification, include_raw=True
                )

                structured_llm.invoke(
                    "What weighs more a pound of bricks or a pound of feathers"
                )
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }

        Example: schema=TypedDict class, method="function_calling", include_raw=False:
            .. code-block:: python

                # IMPORTANT: If you are using Python <=3.8, you need to import Annotated
                # from typing_extensions, not from typing.
                from typing_extensions import Annotated, TypedDict

                from langchain_mistralai import ChatMistralAI


                class AnswerWithJustification(TypedDict):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: Annotated[
                        Optional[str], None, "A justification for the answer."
                    ]


                llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke(
                    "What weighs more a pound of bricks or a pound of feathers"
                )
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }

        Example: schema=OpenAI function schema, method="function_calling", include_raw=False:
            .. code-block:: python

                from langchain_mistralai import ChatMistralAI

                oai_schema = {
                    'name': 'AnswerWithJustification',
                    'description': 'An answer to the user question along with justification for the answer.',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'answer': {'type': 'string'},
                            'justification': {'description': 'A justification for the answer.', 'type': 'string'}
                        },
                        'required': ['answer']
                    }
                }

                llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
                structured_llm = llm.with_structured_output(oai_schema)

                structured_llm.invoke(
                    "What weighs more a pound of bricks or a pound of feathers"
                )
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }

        Example: schema=Pydantic class, method="json_mode", include_raw=True:
            .. code-block::

                from langchain_mistralai import ChatMistralAI
                from pydantic import BaseModel

                class AnswerWithJustification(BaseModel):
                    answer: str
                    justification: str

                llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
                structured_llm = llm.with_structured_output(
                    AnswerWithJustification,
                    method="json_mode",
                    include_raw=True
                )

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\\n    "answer": "They are both the same weight.",\\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
                #     'parsing_error': None
                # }

        Example: schema=None, method="json_mode", include_raw=True:
            .. code-block::

                structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\\n    "answer": "They are both the same weight.",\\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                #     'parsed': {
                #         'answer': 'They are both the same weight.',
                #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
                #     },
                #     'parsing_error': None
                # }
        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            # TODO: Update to pass in tool name as tool_choice if/when Mistral supports
            # specifying a tool.
            llm = self.bind_tools(
                [schema],
                tool_choice="any",
                ls_structured_output_format={
                    "kwargs": {"method": "function_calling"},
                    "schema": schema,
                },
            )
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema],  # type: ignore[list-item]
                    first_tool_only=True,  # type: ignore[list-item]
                )
            else:
                key_name = convert_to_openai_tool(schema)["function"]["name"]
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(
                response_format={"type": "json_object"},
                ls_structured_output_format={
                    "kwargs": {
                        # this is correct - name difference with mistral api
                        "method": "json_mode"
                    },
                    "schema": schema,
                },
            )
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        elif method == "json_schema":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'json_schema'. "
                    "Received None."
                )
            response_format = _convert_to_openai_response_format(schema, strict=True)
            llm = self.bind(
                response_format=response_format,
                ls_structured_output_format={
                    "kwargs": {"method": "json_schema"},
                    "schema": schema,
                },
            )
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mistralai-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"mistral_api_key": "MISTRAL_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "mistralai"]
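
# Basic usage sketch for the class above (assumes MISTRAL_API_KEY is set in the
# environment; the model name and prompts are illustrative):
#
#     from langchain_core.messages import HumanMessage
#
#     llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
#     response = llm.invoke([HumanMessage(content="Say hello in French.")])
#     # response is an AIMessage; response.usage_metadata carries the token counts
#     # populated by _create_chat_result.
#
#     for chunk in llm.stream("Count to three."):
#         print(chunk.content, end="")  # AIMessageChunk pieces produced by _stream
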
def _convert_to_openai_response_format(
    schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
) -> Dict:
    """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
    if (
        isinstance(schema, dict)
        and "json_schema" in schema
        and schema.get("type") == "json_schema"
    ):
        response_format = schema
    elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
        response_format = {"type": "json_schema", "json_schema": schema}
    else:
        if strict is None:
            if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
                strict = schema["strict"]
            else:
                strict = False
        function = convert_to_openai_tool(schema, strict=strict)["function"]
        function["schema"] = function.pop("parameters")
        response_format = {"type": "json_schema", "json_schema": function}

    if strict is not None and strict is not response_format["json_schema"].get(
        "strict"
    ):
        msg = (
            f"Output schema already has 'strict' value set to "
            f"{schema['json_schema']['strict']} but 'strict' also passed in to "
            f"with_structured_output as {strict}. Please make sure that "
            f"'strict' is only specified in one place."
        )
        raise ValueError(msg)
    return response_format
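
# Rough sketch of what ``_convert_to_openai_response_format`` is expected to produce
# for a plain JSON Schema dict (the schema below is hypothetical, and the exact keys
# in the wrapper depend on ``convert_to_openai_tool``):
#
#     schema = {
#         "title": "Person",
#         "type": "object",
#         "properties": {"name": {"type": "string"}},
#         "required": ["name"],
#     }
#     _convert_to_openai_response_format(schema, strict=True)
#     # -> {
#     #     "type": "json_schema",
#     #     "json_schema": {"name": "Person", ..., "schema": {...}, "strict": True},
#     # }
#     # i.e. convert_to_openai_tool builds the function wrapper and its "parameters"
#     # key is renamed to "schema", matching the response_format payload sent to
#     # Mistral by with_structured_output(method="json_schema").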