Source code for langchain_community.tools.zenguard.tool

import os
from enum import Enum
from typing import Any, Dict, List, Optional, Type

import requests
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, ValidationError, validator


[docs] class Detector(str, Enum): ALLOWED_TOPICS = "allowed_subjects" BANNED_TOPICS = "banned_subjects" PROMPT_INJECTION = "prompt_injection" KEYWORDS = "keywords" PII = "pii" SECRETS = "secrets" TOXICITY = "toxicity"
[docs] class DetectorAPI(str, Enum): ALLOWED_TOPICS = "v1/detect/topics/allowed" BANNED_TOPICS = "v1/detect/topics/banned" PROMPT_INJECTION = "v1/detect/prompt_injection" KEYWORDS = "v1/detect/keywords" PII = "v1/detect/pii" SECRETS = "v1/detect/secrets" TOXICITY = "v1/detect/toxicity"
[docs] class ZenGuardInput(BaseModel): prompts: List[str] = Field( ..., min_length=1, description="Prompt to check", ) detectors: List[Detector] = Field( ..., min_length=1, description="List of detectors by which you want to check the prompt", ) in_parallel: bool = Field( default=True, description="Run prompt detection by the detector in parallel or sequentially", )
[docs] class ZenGuardTool(BaseTool): # type: ignore[override, override] name: str = "ZenGuard" description: str = ( "ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails." ) args_schema: Type[BaseModel] = ZenGuardInput return_direct: bool = True zenguard_api_key: Optional[str] = Field(default=None) _ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/" _ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY"
[docs] @validator("zenguard_api_key", pre=True, always=True, check_fields=False) def set_api_key(cls, v: str) -> str: if v is None: v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME) if v is None: raise ValidationError( "The zenguard_api_key tool option must be set either " "by passing zenguard_api_key to the tool or by setting " f"the f{cls._ZENGUARD_API_KEY_ENV_NAME} environment variable" ) return v
@property def _api_key(self) -> str: if self.zenguard_api_key is None: raise ValueError( "API key is required for the ZenGuardTool. " "Please provide the API key by either:\n" "1. Manually specifying it when initializing the tool: " "ZenGuardTool(zenguard_api_key='your_api_key')\n" "2. Setting it as an environment variable:" f" {self._ZENGUARD_API_KEY_ENV_NAME}" ) return self.zenguard_api_key def _run( self, prompts: List[str], detectors: List[Detector], in_parallel: bool = True, ) -> Dict[str, Any]: try: postfix = None json: Optional[Dict[str, Any]] = None if len(detectors) == 1: postfix = self._convert_detector_to_api(detectors[0]) json = {"messages": prompts} else: postfix = "v1/detect" json = { "messages": prompts, "in_parallel": in_parallel, "detectors": detectors, } response = requests.post( self._ZENGUARD_API_URL_ROOT + postfix, json=json, headers={"x-api-key": self._api_key}, timeout=5, ) response.raise_for_status() return response.json() except (requests.HTTPError, requests.Timeout) as e: return {"error": str(e)} def _convert_detector_to_api(self, detector: Detector) -> str: return DetectorAPI[detector.name].value