Source code for langchain_community.tools.zenguard.tool

import os
from enum import Enum
from typing import Any, Dict, List, Optional

import requests
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError, validator
from langchain_core.tools import BaseTool


[docs]class Detector(str, Enum): ALLOWED_TOPICS = "allowed_subjects" BANNED_TOPICS = "banned_subjects" PROMPT_INJECTION = "prompt_injection" KEYWORDS = "keywords" PII = "pii" SECRETS = "secrets" TOXICITY = "toxicity"
[docs]class DetectorAPI(str, Enum): ALLOWED_TOPICS = "v1/detect/topics/allowed" BANNED_TOPICS = "v1/detect/topics/banned" PROMPT_INJECTION = "v1/detect/prompt_injection" KEYWORDS = "v1/detect/keywords" PII = "v1/detect/pii" SECRETS = "v1/detect/secrets" TOXICITY = "v1/detect/toxicity"
[docs]class ZenGuardInput(BaseModel): prompts: List[str] = Field( ..., min_items=1, min_length=1, description="Prompt to check", ) detectors: List[Detector] = Field( ..., min_items=1, description="List of detectors by which you want to check the prompt", ) in_parallel: bool = Field( default=True, description="Run prompt detection by the detector in parallel or sequentially", )
[docs]class ZenGuardTool(BaseTool): name: str = "ZenGuard" description: str = ( "ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails." ) args_schema = ZenGuardInput return_direct: bool = True zenguard_api_key: Optional[str] = Field(default=None) _ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/" _ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY" @validator("zenguard_api_key", pre=True, always=True, check_fields=False) def set_api_key(cls, v: str) -> str: if v is None: v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME) if v is None: raise ValidationError( "The zenguard_api_key tool option must be set either " "by passing zenguard_api_key to the tool or by setting " f"the f{cls._ZENGUARD_API_KEY_ENV_NAME} environment variable" ) return v def _run( self, prompts: List[str], detectors: List[Detector], in_parallel: bool = True, ) -> Dict[str, Any]: try: postfix = None json: Optional[Dict[str, Any]] = None if len(detectors) == 1: postfix = self._convert_detector_to_api(detectors[0]) json = {"messages": prompts} else: postfix = "v1/detect" json = { "messages": prompts, "in_parallel": in_parallel, "detectors": detectors, } response = requests.post( self._ZENGUARD_API_URL_ROOT + postfix, json=json, headers={"x-api-key": self.zenguard_api_key}, timeout=5, ) response.raise_for_status() return response.json() except (requests.HTTPError, requests.Timeout) as e: return {"error": str(e)} def _convert_detector_to_api(self, detector: Detector) -> str: return DetectorAPI[detector.name].value