"""
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""
import datetime
import inspect
import logging
from importlib.metadata import version
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.vectorstores import VectorStoreRetriever
from pydantic import ConfigDict, Field, validator
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
SUPPORTED_VECTORSTORES,
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
ChainInfo,
Framework,
Model,
SemanticContext,
VectorDB,
)
from langchain_community.chains.pebblo_retrieval.utilities import (
PLUGIN_VERSION,
PebbloRetrievalAPIWrapper,
get_runtime,
)
logger = logging.getLogger(__name__)
[docs]
class PebbloRetrievalQA(Chain):
"""
Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""
combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents or not."""
retriever: VectorStoreRetriever = Field(exclude=True)
"""VectorStore to use for retrieval."""
auth_context_key: str = "auth_context" #: :meta private:
"""Authentication context for identity enforcement."""
semantic_context_key: str = "semantic_context" #: :meta private:
"""Semantic context for semantic enforcement."""
app_name: str #: :meta private:
"""App name."""
owner: str #: :meta private:
"""Owner of app."""
description: str #: :meta private:
"""Description of app."""
api_key: Optional[str] = None #: :meta private:
"""Pebblo cloud API key for app."""
classifier_url: Optional[str] = None #: :meta private:
"""Classifier endpoint."""
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent: bool = False #: :meta private:
"""Flag to check if discover payload has been sent."""
enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not"""
pb_client: PebbloRetrievalAPIWrapper = Field(
default_factory=PebbloRetrievalAPIWrapper
)
"""Pebblo Retrieval API client"""
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
_, prompt_entities = self.pb_client.check_prompt_validity(question)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
self.pb_client.send_prompt(
self.app_name,
self.retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
self.enable_prompt_gov,
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
_, prompt_entities = await self.pb_client.acheck_prompt_validity(question)
if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
await self.pb_client.asend_prompt(
self.app_name,
self.retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
self.enable_prompt_gov,
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
extra="forbid",
)
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key, self.auth_context_key, self.semantic_context_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
if self.return_source_documents:
_output_keys += ["source_documents"]
return _output_keys
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "pebblo_retrieval_qa"
[docs]
@classmethod
def from_chain_type(
cls,
llm: BaseLanguageModel,
app_name: str,
description: str,
owner: str,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: Optional[str] = None,
classifier_location: str = "local",
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
from langchain.chains.question_answering import load_qa_chain
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(
llm, chain_type=chain_type, **_chain_type_kwargs
)
# generate app
app: App = PebbloRetrievalQA._get_app_details(
app_name=app_name,
description=description,
owner=owner,
llm=llm,
**kwargs,
)
# initialize Pebblo API client
pb_client = PebbloRetrievalAPIWrapper(
api_key=api_key,
classifier_location=classifier_location,
classifier_url=classifier_url,
)
# send app discovery request
pb_client.send_app_discover(app)
return cls(
combine_documents_chain=combine_documents_chain,
app_name=app_name,
owner=owner,
description=description,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
pb_client=pb_client,
**kwargs,
)
[docs]
@validator("retriever", pre=True, always=True)
def validate_vectorstore(
cls, retriever: VectorStoreRetriever
) -> VectorStoreRetriever:
"""
Validate that the vectorstore of the retriever is supported vectorstores.
"""
if retriever.vectorstore.__class__.__name__ not in SUPPORTED_VECTORSTORES:
raise ValueError(
f"Vectorstore must be an instance of one of the supported "
f"vectorstores: {SUPPORTED_VECTORSTORES}. "
f"Got '{retriever.vectorstore.__class__.__name__}' instead."
)
return retriever
def _get_docs(
self,
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
return self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
async def _aget_docs(
self,
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
@staticmethod
def _get_app_details( # type: ignore
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs
) -> App:
"""Fetch app details. Internal method.
Returns:
App: App details.
"""
framework, runtime = get_runtime()
chains = PebbloRetrievalQA.get_chain_details(llm, **kwargs)
app = App(
name=app_name,
owner=owner,
description=description,
runtime=runtime,
framework=framework,
chains=chains,
plugin_version=PLUGIN_VERSION,
client_version=Framework(
name="langchain_community",
version=version("langchain_community"),
),
)
return app
[docs]
@classmethod
def set_discover_sent(cls) -> None:
cls._discover_sent = True
[docs]
@classmethod
def get_chain_details(
cls, llm: BaseLanguageModel, **kwargs: Any
) -> List[ChainInfo]:
"""
Get chain details.
Args:
llm (BaseLanguageModel): Language model instance.
**kwargs: Additional keyword arguments.
Returns:
List[ChainInfo]: Chain details.
"""
llm_dict = llm.__dict__
chains = [
ChainInfo(
name=cls.__name__,
model=Model(
name=llm_dict.get("model_name", llm_dict.get("model")),
vendor=llm.__class__.__name__,
),
vector_dbs=[
VectorDB(
name=kwargs["retriever"].vectorstore.__class__.__name__,
embedding_model=str(
kwargs["retriever"].vectorstore._embeddings.model
)
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
else (
str(kwargs["retriever"].vectorstore._embedding.model)
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
else None
),
)
],
),
]
return chains