Source code for langchain_community.llms.databricks
import os
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
    BaseModel,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)

__all__ = ["Databricks"]


class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    def _get(self, url: str) -> Any:
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...

    @property
    def llm(self) -> bool:
        return False


def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]


def _transform_llama2_chat(response: Dict[str, Any]) -> str:
    return response["candidates"][0]["text"]


def _transform_chat(response: Dict[str, Any]) -> str:
    return response["choices"][0]["message"]["content"]


class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
            if self.task == "llama2/chat":
                return _transform_llama2_chat(pred)
            return transform_output_fn(pred) if transform_output_fn else pred


class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        resp = self._post(self.api_url, request)
        return transform_output_fn(resp) if transform_output_fn else resp
def get_repl_context() -> Any:
    """Get the notebook REPL context if running inside a Databricks notebook.

    Raises an error otherwise.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError:
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        )
def get_default_host() -> str:
    """Get the default Databricks workspace hostname.

    Raises an error if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # Strip the URL scheme and any trailing slash. Note that str.lstrip removes
    # individual characters rather than a prefix, so a regex is used instead.
    host = re.sub(r"^https?://", "", host).rstrip("/")
    return host
def get_default_api_token() -> str:
    """Get the default Databricks personal access token.

    Raises an error if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token
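
# Minimal sketch of the credential fallback implemented by the two helpers
# above: with DATABRICKS_HOST and DATABRICKS_TOKEN exported they are read
# directly; inside a Databricks notebook they would instead come from the REPL
# context. The host and token values are hypothetical.
def _example_default_credentials() -> None:
    os.environ["DATABRICKS_HOST"] = "https://my-workspace.cloud.databricks.com"
    os.environ["DATABRICKS_TOKEN"] = "dapi-hypothetical-token"
    print(get_default_host())  # -> "my-workspace.cloud.databricks.com"
    print(get_default_api_token())
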
def _is_hex_string(data: str) -> bool:
    """Checks if the data is a valid hexadecimal string using a regular expression."""
    if not isinstance(data, str):
        return False
    pattern = r"^[0-9a-fA-F]+$"
    return bool(re.match(pattern, data))


def _load_pickled_fn_from_hex_string(
    data: str, allow_dangerous_deserialization: Optional[bool]
) -> Callable:
    """Loads a pickled function from a hexadecimal string."""
    if not allow_dangerous_deserialization:
        raise ValueError(
            "This code relies on the pickle module. "
            "You will need to set allow_dangerous_deserialization=True "
            "if you want to opt-in to allow deserialization of data using pickle. "
            "Data can be compromised by a malicious actor if "
            "not handled properly to include "
            "a malicious payload that when deserialized with "
            "pickle can execute arbitrary code on your machine."
        )
    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.loads(bytes.fromhex(data))  # ignore[pickle]: explicit-opt-in
    except Exception as e:
        raise ValueError(
            f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
        )


def _pickle_fn_to_hex_string(fn: Callable) -> str:
    """Pickles a function and returns the hexadecimal string."""
    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.dumps(fn).hex()
    except Exception as e:
        raise ValueError(f"Failed to pickle the function: {e}")
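
# Sketch of how the pickle helpers round-trip a callable: serialize it to a hex
# string with cloudpickle, then restore it with the explicit opt-in flag.
# Requires cloudpickle>=2.0.0; the function being pickled is a toy example.
def _example_fn_round_trip() -> None:
    def add_exclamation(text: str) -> str:
        return text + "!"

    data = _pickle_fn_to_hex_string(add_exclamation)
    assert _is_hex_string(data)
    restored = _load_pickled_fn_from_hex_string(
        data, allow_dangerous_deserialization=True
    )
    assert restored("hi") == "hi!"
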
class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected model
      signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response from
      the endpoint is automatically transformed to the expected format unless
      ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using HTTP POST method
      with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server listen on
      the driver IP address or simply ``0.0.0.0`` instead of localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra params,
    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply necessary
    transformations before and after the query.
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
    inside a Databricks notebook attached to an interactive cluster in "single user"
    or "no isolation shared" mode, the current cluster ID is used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver node.
    The server should listen on the driver IP address or simply ``0.0.0.0`` to connect.
    We recommend using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """
    Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to
    the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    allow_dangerous_deserialization: bool = False
    """Whether to allow dangerous deserialization of the data which
    involves loading data using pickle.

    If the data has been modified by a malicious actor, it can deliver a
    malicious payload that results in execution of arbitrary code on the
    target machine.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = "forbid"
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set, and the "
                    f"cluster_id cannot be automatically determined. Received"
                    f" error: {e}"
                )

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif values["endpoint_name"]:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        if v:
            assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
            assert "stop" not in v, "model_kwargs must not contain key 'stop'"
        return v

    def __init__(self, **data: Any):
        if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
            data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_input_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )
        if "transform_output_fn" in data and _is_hex_string(
            data["transform_output_fn"]
        ):
            data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_output_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )

        super().__init__(**data)
        if self.model_kwargs is not None and self.extra_params:
            raise ValueError("Cannot set both model_kwargs and extra_params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use extra_params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
                task=self.task,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(  # type: ignore[call-arg]
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            "transform_input_fn": None
            if self.transform_input_fn is None
            else _pickle_fn_to_hex_string(self.transform_input_fn),
            "transform_output_fn": None
            if self.transform_output_fn is None
            else _pickle_fn_to_hex_string(self.transform_output_fn),
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""
        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)