Source code for langchain_azure_ai.chat_models.inference
"""Azure AI Inference Chat Models API."""importjsonimportloggingfromoperatorimportitemgetterfromtypingimport(Any,AsyncIterator,Callable,Dict,Iterable,Iterator,List,Literal,Optional,Sequence,Type,Union,cast,)fromazure.ai.inferenceimportChatCompletionsClientfromazure.ai.inference.aioimportChatCompletionsClientasChatCompletionsClientAsyncfromazure.ai.inference.modelsimport(ChatCompletions,ChatRequestMessage,ChatResponseMessage,JsonSchemaFormat,StreamingChatCompletionsUpdate,)fromazure.core.credentialsimportAzureKeyCredential,TokenCredentialfromazure.core.exceptionsimportHttpResponseErrorfromlangchain_core.callbacksimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_modelsimportLanguageModelInputfromlangchain_core.language_models.chat_modelsimportBaseChatModel,ChatGenerationfromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,BaseMessageChunk,ChatMessage,ChatMessageChunk,FunctionMessageChunk,HumanMessage,HumanMessageChunk,InvalidToolCall,SystemMessage,SystemMessageChunk,ToolCall,ToolCallChunk,ToolMessage,ToolMessageChunk,)fromlangchain_core.messages.toolimporttool_call_chunkfromlangchain_core.output_parsersimportJsonOutputParser,PydanticOutputParserfromlangchain_core.output_parsers.openai_toolsimportmake_invalid_tool_callfromlangchain_core.outputsimportChatGenerationChunk,ChatResultfromlangchain_core.runnablesimportRunnable,RunnableMap,RunnablePassthroughfromlangchain_core.toolsimportBaseToolfromlangchain_core.utilsimportget_from_dict_or_env,pre_initfromlangchain_core.utils.function_callingimportconvert_to_openai_toolfromlangchain_core.utils.pydanticimportis_basemodel_subclassfrompydanticimportBaseModel,PrivateAttr,model_validatorfromlangchain_azure_ai.utils.utilsimportget_endpoint_from_projectlogger=logging.getLogger(__name__)
[docs]
def to_inference_message(
    messages: List[BaseMessage],
) -> List[ChatRequestMessage]:
    """Converts a sequence of ``BaseMessage`` to ``ChatRequestMessage``.

    Args:
        messages (Sequence[BaseMessage]): The messages to convert.

    Returns:
        List[ChatRequestMessage]: The converted messages.
    """
    new_messages = []
    for m in messages:
        message_dict: Dict[str, Any] = {}
        if isinstance(m, ChatMessage):
            message_dict = {
                "role": m.type,
                "content": m.content,
            }
        elif isinstance(m, HumanMessage):
            message_dict = {
                "role": "user",
                "content": m.content,
            }
        elif isinstance(m, AIMessage):
            message_dict = {
                "role": "assistant",
                "content": m.content,
            }
            tool_calls = []
            if m.tool_calls:
                for tool_call in m.tool_calls:
                    tool_calls.append(_format_tool_call_for_azure_inference(tool_call))
            elif "tool_calls" in m.additional_kwargs:
                for tc in m.additional_kwargs["tool_calls"]:
                    chunk = {
                        "function": {
                            "name": tc["function"]["name"],
                            "arguments": tc["function"]["arguments"],
                        }
                    }
                    if _id := tc.get("id"):
                        chunk["id"] = _id
                    tool_calls.append(chunk)
            else:
                pass
            if tool_calls:
                message_dict["tool_calls"] = tool_calls
        elif isinstance(m, SystemMessage):
            message_dict = {
                "role": "system",
                "content": m.content,
            }
        elif isinstance(m, ToolMessage):
            message_dict = {
                "role": "tool",
                "content": m.content,
                "name": m.name,
                "tool_call_id": m.tool_call_id,
            }
        new_messages.append(ChatRequestMessage(message_dict))
    return new_messages
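For reference, a minimal sketch of how the converter above might be exercised on its own; the conversation content is made up for illustration and is not taken from the library's tests:

.. code-block:: python

    from langchain_core.messages import HumanMessage, SystemMessage

    request_messages = to_inference_message(
        [
            SystemMessage(content="You are a terse assistant."),
            HumanMessage(content="What is the capital of France?"),
        ]
    )
    # request_messages is a list of ChatRequestMessage objects, each wrapping a
    # dict such as {"role": "user", "content": "What is the capital of France?"}.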
[docs]
def from_inference_message(message: ChatResponseMessage) -> BaseMessage:
    """Convert an inference message dict to generic message."""
    if message.role == "user":
        return HumanMessage(content=message.content)
    elif message.role == "assistant":
        tool_calls: List[ToolCall] = []
        invalid_tool_calls: List[InvalidToolCall] = []
        additional_kwargs: Dict = {}
        if message.tool_calls:
            for tool_call in message.tool_calls:
                try:
                    tool_calls.append(
                        ToolCall(
                            id=tool_call.get("id"),
                            name=tool_call.function.name,
                            args=json.loads(tool_call.function.arguments),
                        )
                    )
                except json.JSONDecodeError as e:
                    invalid_tool_calls.append(
                        make_invalid_tool_call(tool_call.as_dict(), str(e))
                    )
            additional_kwargs.update(tool_calls=tool_calls)
        if audio := message.get("audio"):
            additional_kwargs.update(audio=audio)
        return AIMessage(
            id=message.get("id"),
            content=message.content or "",
            additional_kwargs=additional_kwargs,
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
        )
    elif message.role == "system":
        return SystemMessage(content=message.content)
    elif message.role == "tool":
        additional_kwargs = {}
        if tool_name := message.get("name"):
            additional_kwargs["name"] = tool_name
        return ToolMessage(
            content=message.content,
            tool_call_id=cast(str, message.get("tool_call_id")),
            additional_kwargs=additional_kwargs,
            name=tool_name,
            id=message.get("id"),
        )
    else:
        return ChatMessage(content=message.content, role=message.role)
def _convert_streaming_result_to_message_chunk(
    chunk: StreamingChatCompletionsUpdate,
    default_class: Type[BaseMessageChunk],
) -> Iterable[ChatGenerationChunk]:
    token_usage = chunk.get("usage", {})
    for res in chunk["choices"]:
        finish_reason = res.get("finish_reason")
        message = _convert_delta_to_message_chunk(res.delta, default_class)
        if token_usage and isinstance(message, AIMessage):
            message.usage_metadata = {
                "input_tokens": token_usage.get("prompt_tokens", 0),
                "output_tokens": token_usage.get("completion_tokens", 0),
                "total_tokens": token_usage.get("total_tokens", 0),
            }
        gen = ChatGenerationChunk(
            message=message,
            generation_info={"finish_reason": finish_reason},
        )
        yield gen


def _convert_delta_to_message_chunk(
    _dict: Any, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a delta response to a message chunk."""
    id = _dict.get("id", None)
    role = _dict.role
    content = _dict.content or ""
    additional_kwargs: Dict = {}
    tool_call_chunks: List[ToolCallChunk] = []
    if raw_tool_calls := _dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = raw_tool_calls
        try:
            tool_call_chunks = [
                tool_call_chunk(
                    name=rtc["function"].get("name"),
                    args=rtc["function"].get("arguments"),
                    id=rtc.get("id"),
                    index=rtc["index"],
                )
                for rtc in raw_tool_calls
            ]
        except KeyError:
            pass
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(
            id=id,
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict.name)
    elif role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(
            content=content, tool_call_id=_dict["tool_call_id"], id=id
        )
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)  # type: ignore[call-arg]


def _format_tool_call_for_azure_inference(tool_call: ToolCall) -> dict:
    """Format LangChain ToolCall to dict expected by Azure AI Inference."""
    result: Dict[str, Any] = {
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        },
        "type": "function",
    }
    if _id := tool_call.get("id"):
        result["id"] = _id
    return result
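The tool-call formatting helper above produces the OpenAI-style dictionary that the Azure AI Inference SDK expects. A small sketch of the shape it returns; the tool name, arguments, and id are hypothetical:

.. code-block:: python

    tool_call = {"name": "get_weather", "args": {"city": "Seville"}, "id": "call_abc123"}

    formatted = _format_tool_call_for_azure_inference(tool_call)
    # {
    #     "function": {"name": "get_weather", "arguments": '{"city": "Seville"}'},
    #     "type": "function",
    #     "id": "call_abc123",
    # }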
[docs]
class AzureAIChatCompletionsModel(BaseChatModel):
    """Azure AI Chat Completions Model.

    The Azure AI model inference API (https://aka.ms/azureai/modelinference)
    provides a common layer to talk with most models deployed to Azure AI. This
    class provides inference for chat completions models that support it. See the
    documentation for the list of models supporting the API.

    Examples:
        .. code-block:: python

            from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
            from langchain_core.messages import HumanMessage, SystemMessage

            model = AzureAIChatCompletionsModel(
                endpoint="https://[your-service].services.ai.azure.com/models",
                credential="your-api-key",
                model_name="mistral-large-2407",
            )

            messages = [
                SystemMessage(
                    content="Translate the following from English into Italian"
                ),
                HumanMessage(content="hi!"),
            ]

            model.invoke(messages)

        For serverless endpoints running a single model, the `model_name` parameter
        can be omitted:

        .. code-block:: python

            from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
            from langchain_core.messages import HumanMessage, SystemMessage

            model = AzureAIChatCompletionsModel(
                endpoint="https://[your-service].inference.ai.azure.com",
                credential="your-api-key",
            )

            messages = [
                SystemMessage(
                    content="Translate the following from English into Italian"
                ),
                HumanMessage(content="hi!"),
            ]

            model.invoke(messages)

        You can pass additional properties to the underlying model, including
        `temperature`, `top_p`, `presence_penalty`, etc.

        .. code-block:: python

            model = AzureAIChatCompletionsModel(
                endpoint="https://[your-service].services.ai.azure.com/models",
                credential="your-api-key",
                model_name="mistral-large-2407",
                temperature=0.5,
                top_p=0.9,
            )

        Certain models may require passing the `api_version` parameter. When not
        indicated, the default version of the Azure AI Inference SDK is used. Check
        the model documentation to know which API version to use.

        .. code-block:: python

            model = AzureAIChatCompletionsModel(
                endpoint="https://[your-service].services.ai.azure.com/models",
                credential="your-api-key",
                model_name="gpt-4o",
                api_version="2024-05-01-preview",
            )

    Troubleshooting:
        To diagnose issues with the model, you can enable debug logging:

        .. code-block:: python

            import sys
            import logging
            from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel

            logger = logging.getLogger("azure")

            # Set the desired logging level.
            logger.setLevel(logging.DEBUG)

            handler = logging.StreamHandler(stream=sys.stdout)
            logger.addHandler(handler)

            model = AzureAIChatCompletionsModel(
                endpoint="https://[your-service].services.ai.azure.com/models",
                credential="your-api-key",
                model_name="mistral-large-2407",
                client_kwargs={"logging_enable": True},
            )
    """

    project_connection_string: Optional[str] = None
    """The connection string to use for the Azure AI project. If this is specified,
    then the `endpoint` parameter becomes optional and `credential` has to be of type
    `TokenCredential`."""

    endpoint: Optional[str] = None
    """The endpoint URI where the model is deployed. Either this or the
    `project_connection_string` parameter must be specified."""

    credential: Optional[Union[str, AzureKeyCredential, TokenCredential]] = None
    """The API key or credential to use for the Azure AI model inference service."""

    api_version: Optional[str] = None
    """The API version to use for the Azure AI model inference API. If None, the
    default version is used."""

    model_name: Optional[str] = None
    """The name of the model to use for inference, if the endpoint is running more
    than one model. If not, this parameter is ignored."""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate in the response. If None, the
    default maximum tokens is used."""

    temperature: Optional[float] = None
    """The temperature to use for sampling from the model. If None, the default
    temperature is used."""

    top_p: Optional[float] = None
    """The top-p value to use for sampling from the model. If None, the default
    top-p value is used."""

    presence_penalty: Optional[float] = None
    """The presence penalty to use for sampling from the model. If None, the default
    presence penalty is used."""

    frequency_penalty: Optional[float] = None
    """The frequency penalty to use for sampling from the model. If None, the
    default frequency penalty is used."""

    stop: Optional[str] = None
    """The stop token to use for stopping generation. If None, the default stop
    token is used."""

    seed: Optional[int] = None
    """The seed to use for random number generation. If None, the default seed is
    used."""

    model_kwargs: Dict[str, Any] = {}
    """Additional kwargs model parameters."""

    client_kwargs: Dict[str, Any] = {}
    """Additional kwargs for the Azure AI client used."""

    _client: ChatCompletionsClient = PrivateAttr()
    _async_client: ChatCompletionsClientAsync = PrivateAttr()
    _model_name: str = PrivateAttr()
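When `project_connection_string` is used instead of an explicit endpoint, the validator below requires the `credential` to be a `TokenCredential`. A minimal sketch of that configuration, assuming the `azure-identity` package is installed and the connection string comes from your Azure AI project:

.. code-block:: python

    from azure.identity import DefaultAzureCredential

    from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel

    model = AzureAIChatCompletionsModel(
        project_connection_string="<your-project-connection-string>",
        credential=DefaultAzureCredential(),
        model_name="mistral-large-2407",
    )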
[docs]
    @pre_init
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the API key and endpoint exist in the environment."""
        values["endpoint"] = get_from_dict_or_env(
            values, "endpoint", "AZURE_INFERENCE_ENDPOINT"
        )
        values["credential"] = get_from_dict_or_env(
            values, "credential", "AZURE_INFERENCE_CREDENTIAL"
        )

        if values["api_version"]:
            values["client_kwargs"]["api_version"] = values["api_version"]

        return values
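Because `endpoint` and `credential` fall back to environment variables, the model can also be constructed without passing them explicitly. A short sketch, assuming the two variables are set (the values below are placeholders):

.. code-block:: python

    import os

    from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel

    os.environ["AZURE_INFERENCE_ENDPOINT"] = "https://<your-service>.services.ai.azure.com/models"
    os.environ["AZURE_INFERENCE_CREDENTIAL"] = "<your-api-key>"

    model = AzureAIChatCompletionsModel(model_name="mistral-large-2407")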
@model_validator(mode="after")definitialize_client(self)->"AzureAIChatCompletionsModel":"""Initialize the Azure AI model inference client."""ifself.project_connection_string:ifnotisinstance(self.credential,TokenCredential):raiseValueError("When using the `project_connection_string` parameter, the ""`credential` parameter must be of type `TokenCredential`.")self.endpoint,self.credential=get_endpoint_from_project(self.project_connection_string,self.credential)credential=(AzureKeyCredential(self.credential)ifisinstance(self.credential,str)elseself.credential)ifnotself.endpoint:raiseValueError("You must provide an endpoint to use the Azure AI model inference ""client. Pass the endpoint as a parameter or set the ""AZURE_INFERENCE_ENDPOINT environment variable.")ifnotself.credential:raiseValueError("You must provide an credential to use the Azure AI model inference.""client. Pass the credential as a parameter or set the ""AZURE_INFERENCE_CREDENTIAL environment variable.")self._client=ChatCompletionsClient(endpoint=self.endpoint,# type: ignore[arg-type]credential=credential,# type: ignore[arg-type]model=self.model_name,user_agent="langchain-azure-ai",**self.client_kwargs,)self._async_client=ChatCompletionsClientAsync(endpoint=self.endpoint,# type: ignore[arg-type]credential=credential,# type: ignore[arg-type]model=self.model_name,user_agent="langchain-azure-ai",**self.client_kwargs,)ifnotself.model_name:try:# Get model info from the endpoint. This method may not be supported# by all endpoints.model_info=self._client.get_model_info()self._model_name=model_info.get("model_name",None)exceptHttpResponseError:logger.warning(f"Endpoint '{self.endpoint}' does not support model metadata ""retrieval. Unable to populate model attributes. If this endpoint ""supports multiple models, you may be forgetting to indicate ""`model_name` parameter.")self._model_name=""else:self._model_name=self.model_namereturnself@propertydef_llm_type(self)->str:"""Return type of 
llm."""return"AzureAIChatCompletionsModel"@propertydef_identifying_params(self)->Dict[str,Any]:params:Dict[str,Any]={}ifself.temperature:params["temperature"]=self.temperatureifself.top_p:params["top_p"]=self.top_pifself.presence_penalty:params["presence_penalty"]=self.presence_penaltyifself.frequency_penalty:params["frequency_penalty"]=self.frequency_penaltyifself.max_tokens:params["max_tokens"]=self.max_tokensifself.seed:params["seed"]=self.seedifself.model_kwargs:params["model_extras"]=self.model_kwargsreturnparamsdef_create_chat_result(self,response:ChatCompletions)->ChatResult:generations=[]token_usage=response.get("usage",{})forresinresponse["choices"]:finish_reason=res.get("finish_reason")message=from_inference_message(res.message)iftoken_usageandisinstance(message,AIMessage):message.usage_metadata={"input_tokens":token_usage.get("prompt_tokens",0),"output_tokens":token_usage.get("completion_tokens",0),"total_tokens":token_usage.get("total_tokens",0),}gen=ChatGeneration(message=message,generation_info={"finish_reason":finish_reason},)generations.append(gen)llm_output:Dict[str,Any]={"model":response.modelorself._model_name}ifisinstance(message,AIMessage):llm_output["token_usage"]=message.usage_metadatareturnChatResult(generations=generations,llm_output=llm_output)def_generate(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->ChatResult:inference_messages=to_inference_message(messages)response=self._client.complete(messages=inference_messages,stop=stoporself.stop,**self._identifying_params,**kwargs,)returnself._create_chat_result(response)# type: ignore[arg-type]asyncdef_agenerate(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->ChatResult:inference_messages=to_inference_message(messages)response=awaitself._async_client.complete(messages=inference_messages,stop=stoporself.stop,**self._identifying_params,**kwargs,)returnself._create_chat_result(response)# type: ignore[arg-type]def_stream(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->Iterator[ChatGenerationChunk]:inference_messages=to_inference_message(messages)default_chunk_class=AIMessageChunkresponse=self._client.complete(messages=inference_messages,stream=True,stop=stoporself.stop,**self._identifying_params,**kwargs,)assertisinstance(response,Iterator)forchunkinresponse:cg_chunks=_convert_streaming_result_to_message_chunk(chunk,default_chunk_class)forcg_chunkincg_chunks:default_chunk_class=cg_chunk.message.__class__# type: ignore[assignment]ifrun_manager:run_manager.on_llm_new_token(cg_chunk.message.content,# type: ignore[arg-type]chunk=cg_chunk,)yieldcg_chunkasyncdef_astream(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->AsyncIterator[ChatGenerationChunk]:inference_messages=to_inference_message(messages)default_chunk_class=AIMessageChunkresponse=awaitself._async_client.complete(messages=inference_messages,stream=True,stop=stoporself.stop,**self._identifying_params,**kwargs,)assertisinstance(response,AsyncIterator)asyncforchunkinresponse:cg_chunks=_convert_streaming_result_to_message_chunk(chunk,default_chunk_class)forcg_chunkincg_chunks:default_chunk_class=cg_chunk.message.__class__# type: ignore[assignment]ifrun_manager:awaitrun_manager.on_llm_new_token(cg_chunk.message.content,# type: 
ignore[arg-type]chunk=cg_chunk,)yieldcg_chunk
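The streaming methods yield `ChatGenerationChunk` objects whose message chunks can be consumed as they arrive through the standard `stream` interface. A short usage sketch; the endpoint, key, and model name are placeholders:

.. code-block:: python

    from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel

    model = AzureAIChatCompletionsModel(
        endpoint="https://<your-service>.services.ai.azure.com/models",
        credential="<your-api-key>",
        model_name="mistral-large-2407",
    )

    # Print tokens as they are produced by the endpoint.
    for chunk in model.stream("Write a haiku about the sea."):
        print(chunk.content, end="", flush=True)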
[docs]
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: Which tool to require the model to call. Must be the name
                of the single provided function or "auto" to automatically determine
                which function to call (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
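A minimal tool-calling sketch built on the method above, using the standard `@tool` decorator from `langchain_core`; the tool, endpoint, key, and model name are illustrative:

.. code-block:: python

    from langchain_core.tools import tool

    from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel


    @tool
    def get_weather(city: str) -> str:
        """Return the current weather for a city."""
        return f"It is sunny in {city}."


    model = AzureAIChatCompletionsModel(
        endpoint="https://<your-service>.services.ai.azure.com/models",
        credential="<your-api-key>",
        model_name="mistral-large-2407",
    )

    model_with_tools = model.bind_tools([get_weather])
    response = model_with_tools.invoke("What is the weather in Seville?")
    # Requested calls appear in response.tool_calls, e.g.
    # [{"name": "get_weather", "args": {"city": "Seville"}, "id": "..."}]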
[docs]
    def with_structured_output(
        self,
        schema: Union[Dict, type],  # noqa: UP006
        method: Literal["function_calling", "json_mode", "json_schema"] = "function_calling",
        strict: Optional[bool] = None,
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:  # noqa: UP006
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The schema to use for the output. If a Pydantic model is
                provided, it will be used as the output type. If a dict is provided,
                it will be used as the schema for the output.
            method: The method to use for structured output. Can be
                "function_calling", "json_mode", or "json_schema".
            strict: Whether to enforce strict mode for "json_schema".
            include_raw: Whether to include the raw response from the model in the
                output.
            kwargs: Any additional parameters are passed directly to
                ``self.with_structured_output(**kwargs)``.
        """
        if strict is not None and method == "json_mode":
            raise ValueError(
                "Argument `strict` is not supported with `method`='json_mode'"
            )
        if method == "json_schema" and schema is None:
            raise ValueError(
                "Argument `schema` must be specified when method is 'json_schema'."
            )

        if method in ["json_mode", "json_schema"]:
            if method == "json_mode":
                llm = self.bind(response_format="json_object")
                # Parse the raw JSON text into a dict, or into the Pydantic model
                # when one was provided as the schema.
                output_parser = (
                    PydanticOutputParser(pydantic_object=schema)
                    if schema is not None and is_basemodel_subclass(schema)
                    else JsonOutputParser()
                )
            elif method == "json_schema":
                if isinstance(schema, dict):
                    json_schema = schema.copy()
                    schema_name = json_schema.pop("name", None)
                    output_parser = JsonOutputParser()
                elif is_basemodel_subclass(schema):
                    json_schema = schema.model_json_schema()  # type: ignore[attr-defined]
                    schema_name = json_schema.pop("title", None)
                    output_parser = PydanticOutputParser(pydantic_object=schema)
                else:
                    raise ValueError("Invalid schema type. Must be dict or BaseModel.")

                llm = self.bind(
                    response_format=JsonSchemaFormat(
                        name=schema_name,
                        schema=json_schema,
                        description=json_schema.pop("description", None),
                        strict=strict,
                    )
                )

            if include_raw:
                parser_assign = RunnablePassthrough.assign(
                    parsed=itemgetter("raw") | output_parser,
                    parsing_error=lambda _: None,
                )
                parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
                parser_with_fallback = parser_assign.with_fallbacks(
                    [parser_none], exception_key="parsing_error"
                )
                return RunnableMap(raw=llm) | parser_with_fallback
            else:
                return llm | output_parser
        else:
            return super().with_structured_output(
                schema, include_raw=include_raw, **kwargs
            )
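A sketch of structured output with a Pydantic schema via the "json_schema" method handled above, assuming the deployed model supports JSON schema response formats; the schema, prompt, endpoint, key, and model name are illustrative:

.. code-block:: python

    from pydantic import BaseModel, Field

    from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel


    class Translation(BaseModel):
        """A sentence translated into Italian."""

        text: str = Field(description="The translated sentence.")


    model = AzureAIChatCompletionsModel(
        endpoint="https://<your-service>.services.ai.azure.com/models",
        credential="<your-api-key>",
        model_name="mistral-large-2407",
    )

    structured_model = model.with_structured_output(Translation, method="json_schema")
    result = structured_model.invoke("Translate 'hi!' into Italian.")
    # result is a Translation instance, e.g. Translation(text="ciao!")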
    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "azure_inference"]