def __init__(
    self,
    endpoint_url: str,
    endpoint_api_key: str,
    deployment_name: str = "",
    timeout: int = DEFAULT_TIMEOUT,
) -> None:
    """Store the endpoint coordinates, failing fast on missing credentials.

    Args:
        endpoint_url: Full invocation URL of the deployed endpoint.
        endpoint_api_key: Key/token used to authorize calls to the endpoint.
        deployment_name: Optional specific deployment to target; empty means
            the endpoint's own traffic rules decide.
        timeout: Request timeout (seconds) used when none is given per call.
    """
    # Both the URL and the credential are mandatory; reject early rather
    # than failing later inside an HTTP call.
    if not (endpoint_url and endpoint_api_key):
        raise ValueError(
            "A key/token and REST endpoint should be provided to invoke the endpoint"
        )
    self.timeout = timeout
    self.deployment_name = deployment_name
    self.endpoint_api_key = endpoint_api_key
    self.endpoint_url = endpoint_url
def call(
    self,
    body: bytes,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> bytes:
    """POST `body` to the endpoint and return the raw response bytes.

    Args:
        body: Encoded request payload to send as-is.
        run_manager: Accepted for interface compatibility; not used here.
        **kwargs: May carry a per-call ``timeout`` overriding the default.

    Returns:
        The raw bytes read from the HTTP response.
    """
    # The azureml-model-deployment header forces the request to a specific
    # deployment. It is only sent when a deployment name was configured, so
    # by default the request observes the endpoint traffic rules.
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.endpoint_api_key,
    }
    if self.deployment_name != "":
        headers["azureml-model-deployment"] = self.deployment_name

    request = urllib.request.Request(self.endpoint_url, body, headers)
    timeout = kwargs.get("timeout", self.timeout)
    return urllib.request.urlopen(request, timeout=timeout).read()
class AzureMLEndpointApiType(str, Enum):
    """Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
    infrastructure (also known as Online Endpoints in Azure Machine Learning), or
    `serverless` for models deployed as a service with a pay-as-you-go billing or PTU.
    """

    # Values are str (the class subclasses str) so members compare equal to the
    # plain strings read from the AZUREML_ENDPOINT_API_TYPE environment variable.
    dedicated = "dedicated"
    realtime = "realtime"  #: Deprecated — treated the same as `dedicated` by the formatters
    serverless = "serverless"
class ContentFormatterBase:
    """Transform request and response of AzureML endpoint to match with
    required schema.

    Example:
        .. code-block:: python

            class ContentFormatter(ContentFormatterBase):
                content_type = "application/json"
                accepts = "application/json"

                def format_request_payload(
                    self,
                    prompt: str,
                    model_kwargs: Dict,
                    api_type: AzureMLEndpointApiType,
                ) -> bytes:
                    input_str = json.dumps(
                        {
                            "inputs": {"input_string": [prompt]},
                            "parameters": model_kwargs,
                        }
                    )
                    return str.encode(input_str)

                def format_response_payload(
                    self, output: str, api_type: AzureMLEndpointApiType
                ) -> str:
                    response_json = json.loads(output)
                    return response_json[0]["0"]
    """

    # FIX: the Example block above used to be a second, standalone string
    # literal following the docstring — a no-op expression statement that never
    # became part of __doc__. It is now merged into the class docstring.

    content_type: Optional[str] = "application/json"
    """The MIME type of the input data passed to the endpoint"""

    accepts: Optional[str] = "application/json"
    """The MIME type of the response data returned from the endpoint"""

    # Template for the ValueError raised by formatters when a response payload
    # does not have the shape expected for the given api_type.
    format_error_msg: str = (
        "Error while formatting response payload for chat model of type "
        " `{api_type}`. Are you using the right formatter for the deployed "
        " model and endpoint type?"
    )

    @staticmethod
    def escape_special_characters(prompt: str) -> str:
        """Escapes any special characters in `prompt` so it can be embedded
        verbatim inside a JSON string literal."""
        escape_map = {
            "\\": "\\\\",
            '"': '\\"',
            "\b": "\\b",
            "\f": "\\f",
            "\n": "\\n",
            "\r": "\\r",
            "\t": "\\t",
        }
        # Replace each occurrence of the specified characters with escaped
        # versions. Backslash is first so already-inserted escapes are not
        # re-escaped by later replacements.
        for escape_sequence, escaped_sequence in escape_map.items():
            prompt = prompt.replace(escape_sequence, escaped_sequence)
        return prompt

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Supported APIs for the given formatter. Azure ML supports
        deploying models using different hosting methods. Each method may have
        a different API structure."""
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Any:
        """Formats the request body according to the input schema of
        the model. Returns bytes or seekable file like object in the
        format specified in the content_type request header.
        """
        raise NotImplementedError()

    # NOTE(review): @abstractmethod has no enforcement here because the class
    # does not use an ABC metaclass — subclasses are expected to override.
    @abstractmethod
    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Generation:
        """Formats the response body according to the output
        schema of the model. Returns the data type that is
        received from the response.
        """
class GPT2ContentFormatter(ContentFormatterBase):
    """Content handler for GPT2"""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        # GPT2 deployments are only available on hosted (dedicated) endpoints.
        supported = [AzureMLEndpointApiType.dedicated]
        return supported
class OSSContentFormatter(GPT2ContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Content handler for LLMs from the OSS catalog."""

    content_formatter: Any = None

    def __init__(self) -> None:
        """Emit a deprecation notice, then behave exactly like
        `GPT2ContentFormatter`."""
        super().__init__()
        # stacklevel=2 attributes the warning to the code instantiating this
        # class rather than to this __init__ line itself.
        warnings.warn(
            "`OSSContentFormatter` will be deprecated in the future. "
            "Please use `GPT2ContentFormatter` instead.",
            stacklevel=2,
        )
class HFContentFormatter(ContentFormatterBase):
    """Content handler for LLMs from the HuggingFace catalog."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        # HuggingFace catalog models are served from dedicated endpoints only.
        supported = [AzureMLEndpointApiType.dedicated]
        return supported
class DollyContentFormatter(ContentFormatterBase):
    """Content handler for the Dolly-v2-12b model"""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        # Dolly deployments are hosted-infrastructure (dedicated) only.
        supported = [AzureMLEndpointApiType.dedicated]
        return supported
class CustomOpenAIContentFormatter(ContentFormatterBase):
    """Content formatter for models that use the OpenAI like API scheme."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Formats the request according to the chosen api.

        Dedicated/realtime endpoints expect the Azure ML `input_data` schema;
        serverless endpoints expect an OpenAI-style completions body.
        """
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": [f'"{prompt}"'],
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Formats response.

        Raises:
            ValueError: if the payload does not match the schema expected for
                `api_type`, or if `api_type` is unsupported.
        """
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            try:
                choice = json.loads(output)[0]["0"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
            return Generation(text=choice)
        if api_type == AzureMLEndpointApiType.serverless:
            try:
                choice = json.loads(output)["choices"][0]
                if not isinstance(choice, dict):
                    # FIX: this message previously was a plain string, so the
                    # literal text "{type(choice)}" was shown; it is now an
                    # f-string that interpolates the actual type.
                    raise TypeError(
                        "Endpoint response is not well formed for a chat "
                        f"model. Expected `dict` but `{type(choice)}` was "
                        "received."
                    )
            except (KeyError, IndexError, TypeError) as e:
                # The TypeError above is also funneled through this handler
                # and surfaced as the standard formatter ValueError.
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
            return Generation(
                text=choice["text"].strip(),
                generation_info=dict(
                    finish_reason=choice.get("finish_reason"),
                    logprobs=choice.get("logprobs"),
                ),
            )
        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class LlamaContentFormatter(CustomOpenAIContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Content formatter for Llama."""

    content_formatter: Any = None

    def __init__(self) -> None:
        """Emit a deprecation notice, then behave exactly like
        `CustomOpenAIContentFormatter`."""
        super().__init__()
        # stacklevel=2 attributes the warning to the code instantiating this
        # class rather than to this __init__ line itself.
        warnings.warn(
            "`LlamaContentFormatter` will be deprecated in the future. "
            "Please use `CustomOpenAIContentFormatter` instead.",
            stacklevel=2,
        )
class AzureMLBaseEndpoint(BaseModel):
    """Azure ML Online Endpoint models."""

    endpoint_url: str = ""
    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
        env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
    """Type of the endpoint being consumed. Possible values are `serverless` for
        pay-as-you-go and `dedicated` for dedicated endpoints. """

    endpoint_api_key: SecretStr = convert_to_secret_str("")
    """Authentication Key for Endpoint. Should be passed to constructor or specified as
        env var `AZUREML_ENDPOINT_API_KEY`."""

    deployment_name: str = ""
    """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
        to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""

    timeout: int = DEFAULT_TIMEOUT
    """Request timeout for calls to the endpoint"""

    http_client: Any = None  #: :meta private:

    max_retries: int = 1

    content_formatter: Any = None
    """The content formatter that provides an input and output
    transform function to handle formats between the LLM and
    the endpoint"""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    @root_validator(pre=True)
    def validate_environ(cls, values: Dict) -> Dict:
        """Resolve every configurable field from constructor kwargs first,
        falling back to the corresponding AZUREML_* environment variable."""
        values["endpoint_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        values["deployment_name"] = get_from_dict_or_env(
            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
        )
        values["endpoint_api_type"] = get_from_dict_or_env(
            values,
            "endpoint_api_type",
            "AZUREML_ENDPOINT_API_TYPE",
            AzureMLEndpointApiType.dedicated,
        )
        # Env vars are strings; pydantic coerces this to int on assignment.
        values["timeout"] = get_from_dict_or_env(
            values,
            "timeout",
            "AZUREML_TIMEOUT",
            str(DEFAULT_TIMEOUT),
        )
        return values

    @validator("content_formatter")
    def validate_content_formatter(
        cls, field_value: Any, values: Dict
    ) -> ContentFormatterBase:
        """Validate that content formatter is supported by endpoint type."""
        endpoint_api_type = values.get("endpoint_api_type")
        if endpoint_api_type not in field_value.supported_api_types:
            # FIX: removed a stray literal "f" that previously appeared before
            # {type(field_value)} in this message.
            raise ValueError(
                f"Content formatter {type(field_value)} is not supported by this "
                f"endpoint. Supported types are {field_value.supported_api_types} "
                f"but endpoint is {endpoint_api_type}."
            )
        return field_value

    @validator("endpoint_url")
    def validate_endpoint_url(cls, field_value: Any) -> str:
        """Validate that endpoint url is complete."""
        # Normalize away a single trailing slash before suffix checks.
        if field_value.endswith("/"):
            field_value = field_value[:-1]
        if field_value.endswith("inference.ml.azure.com"):
            raise ValueError(
                "`endpoint_url` should contain the full invocation URL including "
                "`/score` for `endpoint_api_type='dedicated'` or `/completions` "
                "or `/chat/completions` for `endpoint_api_type='serverless'`"
            )
        return field_value

    @validator("endpoint_api_type")
    def validate_endpoint_api_type(
        cls, field_value: Any, values: Dict
    ) -> AzureMLEndpointApiType:
        """Validate that endpoint api type is compatible with the URL format."""
        endpoint_url = values.get("endpoint_url")
        # `realtime` is the deprecated alias of `dedicated` — same URL rules.
        if (
            field_value == AzureMLEndpointApiType.dedicated
            or field_value == AzureMLEndpointApiType.realtime
        ) and not endpoint_url.endswith("/score"):  # type: ignore[union-attr]
            # FIX: added the missing space between "or" and "`/chat/completions`".
            raise ValueError(
                "Endpoints of type `dedicated` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
                " If your endpoint URL ends with `/completions` or "
                "`/chat/completions`, use `endpoint_api_type='serverless'` instead."
            )
        if field_value == AzureMLEndpointApiType.serverless and not (
            endpoint_url.endswith("/completions")  # type: ignore[union-attr]
            or endpoint_url.endswith("/chat/completions")  # type: ignore[union-attr]
        ):
            # FIX: the message used to list the `/chat/completions` URL twice;
            # the first alternative is now the `/completions` form.
            raise ValueError(
                "Endpoints of type `serverless` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/completions`"
                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/chat/completions`"
            )
        return field_value

    @validator("http_client", always=True)
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Validate that api key and python package exists in environment."""
        endpoint_url = values.get("endpoint_url")
        endpoint_key = values.get("endpoint_api_key")
        deployment_name = values.get("deployment_name")
        timeout = values.get("timeout", DEFAULT_TIMEOUT)
        http_client = AzureMLEndpointClient(
            endpoint_url,  # type: ignore
            endpoint_key.get_secret_value(),  # type: ignore
            deployment_name,  # type: ignore
            timeout,  # type: ignore
        )
        return http_client
class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
    """Azure ML Online Endpoint models.

    Example:
        .. code-block:: python

            azure_llm = AzureMLOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_type=AzureMLApiType.dedicated,
                endpoint_api_key="my-api-key",
                timeout=120,
                content_formatter=content_formatter,
            )
    """

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"deployment_name": self.deployment_name},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_endpoint"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts.

        Args:
            prompts: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = azureml_model.invoke("Tell me a joke.")
        """
        # FIX: copy the dict. Previously `self.model_kwargs or {}` returned the
        # shared instance dict when non-empty, and the update()/["stop"]
        # mutations below leaked per-call kwargs into self.model_kwargs,
        # contaminating every subsequent call.
        _model_kwargs = dict(self.model_kwargs or {})
        _model_kwargs.update(kwargs)
        if stop:
            _model_kwargs["stop"] = stop
        generations = []
        for prompt in prompts:
            request_payload = self.content_formatter.format_request_payload(
                prompt, _model_kwargs, self.endpoint_api_type
            )
            response_payload = self.http_client.call(
                body=request_payload, run_manager=run_manager
            )
            generated_text = self.content_formatter.format_response_payload(
                response_payload, self.endpoint_api_type
            )
            generations.append([generated_text])
        return LLMResult(generations=generations)