Source code for langchain_experimental.prompt_injection_identifier.hugging_face_identifier
"""Tool for the identification of prompt injection attacks."""from__future__importannotationsfromtypingimportTYPE_CHECKING,Any,Unionfromlangchain_core.toolsimportBaseToolfrompydanticimportField,model_validatorifTYPE_CHECKING:fromtransformersimportPipeline
[docs]classPromptInjectionException(ValueError):"""Exception raised when prompt injection attack is detected."""def__init__(self,message:str="Prompt injection attack detected",score:float=1.0):self.message=messageself.score=scoresuper().__init__(self.message)
def_model_default_factory(model_name:str="protectai/deberta-v3-base-prompt-injection-v2",)->Pipeline:try:fromtransformersimport(AutoModelForSequenceClassification,AutoTokenizer,pipeline,)exceptImportErrorase:raiseImportError("Cannot import transformers, please install with ""`pip install transformers`.")frometokenizer=AutoTokenizer.from_pretrained(model_name)model=AutoModelForSequenceClassification.from_pretrained(model_name)returnpipeline("text-classification",model=model,tokenizer=tokenizer,max_length=512,# default length of BERT modelstruncation=True,# otherwise it will fail on long prompts)
[docs]classHuggingFaceInjectionIdentifier(BaseTool):"""Tool that uses HuggingFace Prompt Injection model to detect prompt injection attacks."""name:str="hugging_face_injection_identifier"description:str=("A wrapper around HuggingFace Prompt Injection security model. ""Useful for when you need to ensure that prompt is free of injection attacks. ""Input should be any message from the user.")model:Union[Pipeline,str,None]=Field(default_factory=_model_default_factory)"""Model to use for prompt injection detection. Can be specified as transformers Pipeline or string. String should correspond to the model name of a text-classification transformers model. Defaults to ``protectai/deberta-v3-base-prompt-injection-v2`` model. """threshold:float=Field(description="Threshold for prompt injection detection.",default=0.5)"""Threshold for prompt injection detection. Defaults to 0.5."""injection_label:str=Field(description="Label of the injection for prompt injection detection.",default="INJECTION",)"""Label for prompt injection detection model. Defaults to ``INJECTION``. Value depends on the model used."""@model_validator(mode="before")@classmethoddefvalidate_environment(cls,values:dict)->Any:ifisinstance(values.get("model"),str):values["model"]=_model_default_factory(model_name=values["model"])returnvaluesdef_run(self,query:str)->str:"""Use the tool."""result=self.model(query)# type: ignorescore=(result[0]["score"]ifresult[0]["label"]==self.injection_labelelse1-result[0]["score"])ifscore>self.threshold:raisePromptInjectionException("Prompt injection attack detected",score)returnquery