class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
    """A handler class to transform input from LLM and BaseChatModel to a
    format that SageMaker endpoint expects.

    Similarly, the class handles transforming output from the SageMaker
    endpoint to a format that LLM & BaseChatModel class expects.

    Subclasses are expected to override ``transform_input`` and
    ``transform_output`` and, where needed, the ``content_type`` and
    ``accepts`` class attributes.

    Example:
        .. code-block:: python

            class ContentHandler(ContentHandlerBase):
                content_type = "application/json"
                accepts = "application/json"

                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
                    input_str = json.dumps({"prompt": prompt, **model_kwargs})
                    return input_str.encode("utf-8")

                def transform_output(self, output: bytes) -> str:
                    response_json = json.loads(output.read().decode("utf-8"))
                    return response_json[0]["generated_text"]
    """

    content_type: Optional[str] = "text/plain"
    """The MIME type of the input data passed to endpoint"""

    accepts: Optional[str] = "text/plain"
    """The MIME type of the response data returned from endpoint"""

    # NOTE(review): the only visible base is ``Generic``; unless an ABC
    # metaclass is applied elsewhere, ``@abstractmethod`` below documents
    # intent but is not enforced at instantiation time.

    @abstractmethod
    def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
        """Transforms the input to a format that model can accept
        as the request Body.

        Should return bytes or seekable file like object in the format
        specified in the content_type request header.

        Args:
            prompt: The input to send to the model.
            model_kwargs: Additional model parameters to include in the
                request.

        Returns:
            The encoded request body.
        """

    @abstractmethod
    def transform_output(self, output: bytes) -> OUTPUT_TYPE:
        """Transforms the output from the model to string that
        the LLM class expects.

        Args:
            output: The raw response body returned by the endpoint.

        Returns:
            The response converted to the handler's output type.
        """
def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur.

    Stop words are matched literally: regex metacharacters such as ``.``,
    ``(`` or ``|`` inside a stop word carry no special meaning.

    Args:
        text: The text to truncate.
        stop: Stop words; the text is cut at the earliest occurrence of any
            of them. Empty strings are ignored.

    Returns:
        The portion of ``text`` preceding the first stop word, or ``text``
        unchanged when no stop word occurs or ``stop`` has no usable entry.
    """
    # Escape each token so it is matched verbatim, and drop empty tokens: an
    # empty alternative matches at position 0 and would truncate everything.
    escaped_tokens = [re.escape(token) for token in stop if token]
    if not escaped_tokens:
        return text
    return re.split("|".join(escaped_tokens), text, maxsplit=1)[0]
def anthropic_tokens_supported() -> bool:
    """Report whether the Anthropic ``count_tokens()`` requirements are met.

    Returns:
        ``False`` when ``anthropic`` cannot be imported or its version is
        newer than 0.38.0; ``True`` when both the ``anthropic`` and ``httpx``
        version checks pass.

    Raises:
        ImportError: If ``anthropic`` passes its check but ``httpx`` is
            missing or newer than 0.27.2.
    """
    try:
        import anthropic
    except ImportError:
        return False

    anthropic_version = version.parse(anthropic.__version__)
    if anthropic_version > version.parse("0.38.0"):
        return False

    try:
        import httpx

        httpx_version = version.parse(httpx.__version__)
        httpx_too_new = httpx_version > version.parse("0.27.2")
        if httpx_too_new:
            # Funnel the version failure through the same handler as a
            # missing package so both produce the same message.
            raise ImportError()
    except ImportError:
        raise ImportError("httpx<=0.27.2 is required.")

    return True
def get_num_tokens_anthropic(text: str) -> int:
    """Count the tokens in ``text`` using the Anthropic client.

    Args:
        text: The string to measure.

    Returns:
        The value returned by the client's ``count_tokens`` call.
    """
    return _get_anthropic_client().count_tokens(text=text)
def get_token_ids_anthropic(text: str) -> List[int]:
    """Encode ``text`` with the Anthropic client's tokenizer.

    Args:
        text: The string to tokenize.

    Returns:
        The ``ids`` attribute of the tokenizer's encoding of ``text``.
    """
    anthropic_tokenizer = _get_anthropic_client().get_tokenizer()
    return anthropic_tokenizer.encode(text).ids
def create_aws_client(
    service_name: str,
    region_name: Optional[str] = None,
    credentials_profile_name: Optional[str] = None,
    aws_access_key_id: Optional[SecretStr] = None,
    aws_secret_access_key: Optional[SecretStr] = None,
    aws_session_token: Optional[SecretStr] = None,
    endpoint_url: Optional[str] = None,
    config: Any = None,
):
    """Validate the supplied AWS credentials and build a service client.

    When no profile or key material is given, the client is created with
    ``boto3.client``. Otherwise a ``boto3.Session`` is built first, from the
    profile if one is named, else from the access key pair, and the client
    is created from that session.

    Args:
        service_name: The name of the AWS service to create a client for.
        region_name: AWS region name. Falls back to the ``AWS_REGION`` and
            then ``AWS_DEFAULT_REGION`` environment variables, and finally
            to the session's own region when a session is used.
        credentials_profile_name: The name of the AWS credentials profile
            to use. Takes precedence over explicit keys.
        aws_access_key_id: AWS access key ID.
        aws_secret_access_key: AWS secret access key.
        aws_session_token: AWS session token.
        endpoint_url: The complete URL to use for the constructed client.
        config: Advanced client configuration options.

    Returns:
        boto3.client: An AWS service client instance.

    Raises:
        ModuleNotFoundError: If the service is unknown to the installed
            boto3/botocore.
        ValueError: For any other failure, including botocore errors and
            incomplete explicit credentials.
    """
    try:
        import boto3

        resolved_region = (
            region_name
            or os.getenv("AWS_REGION")
            or os.getenv("AWS_DEFAULT_REGION")
        )

        # Only forward options that actually carry a value.
        candidate_kwargs = {
            "service_name": service_name,
            "region_name": resolved_region,
            "endpoint_url": endpoint_url,
            "config": config,
        }
        client_kwargs = {
            option: value for option, value in candidate_kwargs.items() if value
        }

        explicit_credentials = (
            credentials_profile_name,
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
        )
        if not any(explicit_credentials):
            return boto3.client(**client_kwargs)

        if credentials_profile_name:
            session = boto3.Session(profile_name=credentials_profile_name)
        elif not (aws_access_key_id and aws_secret_access_key):
            # Raised inside the try block, so the generic handler below
            # re-wraps it like any other error.
            raise ValueError(
                "If providing credentials, both aws_access_key_id and "
                "aws_secret_access_key must be specified."
            )
        else:
            session_kwargs = {
                "aws_access_key_id": aws_access_key_id.get_secret_value(),
                "aws_secret_access_key": aws_secret_access_key.get_secret_value(),
            }
            if aws_session_token:
                session_kwargs["aws_session_token"] = (
                    aws_session_token.get_secret_value()
                )
            session = boto3.Session(**session_kwargs)

        # Falsy values were filtered out above, so a missing key means no
        # region has been resolved yet.
        if "region_name" not in client_kwargs and session.region_name:
            client_kwargs["region_name"] = session.region_name

        return session.client(**client_kwargs)

    except UnknownServiceError as err:
        raise ModuleNotFoundError(
            "Ensure that you have installed the latest boto3 package "
            f"that contains the API for `{service_name}`."
        ) from err
    except BotoCoreError as err:
        raise ValueError(
            "Could not load credentials to authenticate with AWS client. "
            "Please check that the specified profile name and/or its "
            "credentials are valid. "
            f"Service error: {err}"
        ) from err
    except Exception as err:
        raise ValueError(f"Error raised by service:\n\n{err}") from err
def thinking_in_params(params: dict) -> bool:
    """Check if the thinking parameter is enabled in the request.

    Args:
        params: Request parameters; may contain a ``"thinking"`` mapping
            with a ``"type"`` entry.

    Returns:
        ``True`` only when ``params["thinking"]["type"] == "enabled"``;
        ``False`` when the key is missing or its value is ``None``/empty.
    """
    # ``.get("thinking", {})`` would still return None when the key is
    # present with an explicit None value, so normalise falsy values here.
    thinking = params.get("thinking") or {}
    return thinking.get("type") == "enabled"