Source code for langchain_databricks.embeddings
from typing import Any, Dict, Iterator, List
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, PrivateAttr
from langchain_databricks.utils import get_deployment_client
[docs]
class DatabricksEmbeddings(Embeddings, BaseModel):
    """Databricks embedding model integration.

    Setup:
        Install ``langchain-databricks``.

        .. code-block:: bash

            pip install -U langchain-databricks

        If you are outside Databricks, set the Databricks workspace
        hostname and personal access token to environment variables:

        .. code-block:: bash

            export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
            export DATABRICKS_TOKEN="your-personal-access-token"

    Key init args — embedding params:
        endpoint: str
            Name of Databricks Model Serving endpoint to query.
        target_uri: str
            The target URI to use. Defaults to ``databricks``.
        query_params: Dict[str, Any]
            The parameters to use for queries.
        documents_params: Dict[str, Any]
            The parameters to use for documents.

    Instantiate:
        .. code-block:: python

            from langchain_databricks import DatabricksEmbeddings

            embed = DatabricksEmbeddings(
                endpoint="databricks-bge-large-en",
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            embed.embed_query(input_text)

        .. code-block:: python

            [
                0.01605224609375,
                -0.0298309326171875,
                ...
            ]
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str = "databricks"
    """The target URI to use."""
    query_params: Dict[str, Any] = {}
    """The parameters to use for queries."""
    documents_params: Dict[str, Any] = {}
    """The parameters to use for documents."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        """Validate the model fields, then create the deployment client.

        Args:
            **kwargs: Field values forwarded to the pydantic ``BaseModel``
                initializer (``endpoint``, ``target_uri``, ``query_params``,
                ``documents_params``).
        """
        super().__init__(**kwargs)
        # The client can only be built once ``target_uri`` has been populated
        # by the pydantic initializer above.
        self._client = get_deployment_client(self.target_uri)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents, sending ``documents_params`` with each request.

        Args:
            texts: The document texts to embed.

        Returns:
            One embedding per input text, in input order.
        """
        return self._embed(texts, params=self.documents_params)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text, sending ``query_params`` with the request.

        Args:
            text: The query text to embed.

        Returns:
            The embedding of ``text``.
        """
        return self._embed([text], params=self.query_params)[0]

    def _embed(self, texts: List[str], params: Dict[str, Any]) -> List[List[float]]:
        """Query the endpoint in batches and collect the returned embeddings.

        Args:
            texts: The texts to embed.
            params: Extra entries merged into the ``inputs`` payload of every
                request, next to the ``"input"`` key.

        Returns:
            The ``"embedding"`` value of every item in each response's
            ``"data"`` list, concatenated in request order.
        """
        embeddings: List[List[float]] = []
        # Send at most 20 texts per request.
        for txt in _chunk(texts, 20):
            resp = self._client.predict(
                endpoint=self.endpoint,
                inputs={"input": txt, **params},  # type: ignore[arg-type]
            )
            embeddings.extend(r["embedding"] for r in resp["data"])
        return embeddings
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
for i in range(0, len(texts), size):
yield texts[i : i + size]